# Credit to original authors
from __future__ import annotations

import asyncio
import os
import tempfile
import time
import warnings
from pathlib import Path
from string import Template
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Protocol, Sequence, Union
from uuid import uuid4

import aiohttp

# async functions shouldn't use open()
from anyio import open_file
from autogen_core import CancellationToken
from autogen_core.code_executor import (
    CodeBlock,
    CodeExecutor,
    CodeResult,
    FunctionWithRequirements,
    FunctionWithRequirementsStr,
)
from typing_extensions import ParamSpec

from .._common import build_python_functions_file, get_required_packages, to_stub

if TYPE_CHECKING:
    from azure.core.credentials import AccessToken
# Public API of this module.
__all__ = ("ACADynamicSessionsCodeExecutor", "TokenProvider")

# Language tags that the executor accepts as spellings of Python.
PYTHON_VARIANTS = ["python", "Python", "py"]

# Parameter specification used to type the user-supplied ``functions``.
A = ParamSpec("A")
class TokenProvider(Protocol):
    """Structural type for the credential object used by the code executor.

    Any object with a compatible ``get_token`` method satisfies this protocol;
    the executor calls ``get_token(scope)`` and reads ``.token`` from the
    returned value to build its ``Authorization`` header.
    """

    def get_token(
        self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
    ) -> AccessToken: ...
class ACADynamicSessionsCodeExecutor(CodeExecutor):
    """(Experimental) A code executor class that executes code through an Azure
    Container Apps Dynamic Sessions instance.

    .. note::

        This class requires the :code:`azure` extra for the :code:`autogen-ext` package:

        .. code-block:: bash

            pip install "autogen-ext[azure]"

    .. caution::

        **This will execute LLM generated code on an Azure dynamic code container.**

    The execution environment is similar to that of a Jupyter notebook, which
    allows for incremental code execution. The parameter functions are executed
    in order once at the beginning of each session. Each code block is then
    executed serially and in the order they are received. Each environment has a
    statically defined set of available packages which cannot be changed.
    Currently, attempting to use packages beyond what is available in the
    environment will result in an error. To get the list of supported packages,
    call the `get_available_packages` function.
    Currently the only supported language is Python.
    For Python code, use the language "python" for the code block.

    Args:
        pool_management_endpoint (str): The Azure Container Apps dynamic sessions endpoint.
        credential (TokenProvider): An object that implements the get_token function.
        timeout (int): The timeout for the execution of any single code block
            (also used as the total timeout of each HTTP request). Default is 60.
        work_dir (str): The working directory for the code execution. If None,
            a default working directory will be used. The default working
            directory is a temporary directory.
        functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list
            of functions that are available to the code executor. Default is an empty list.
        functions_module (str): Name of the module file (``<functions_module>.py``)
            that the functions are written to inside ``work_dir``. Default is "functions".
        suppress_result_output (bool): By default the executor will attach any
            result info in the execution response to the result output. Set this
            to True to prevent this.
        session_id (str): The session id for the code execution (passed to
            Dynamic Sessions). If None, a new session id will be generated.
            Default is None. Note this value will be reset when calling `restart`.

    .. note::

        Using the current directory (".") as the working directory is deprecated.
        Using it will raise a deprecation warning.
    """

    SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
        "python",
    ]
    FUNCTION_PROMPT_TEMPLATE: ClassVar[str] = """You have access to the following user defined functions.
$functions"""

    _AZURE_API_VER = "2024-02-02-preview"
    # Refresh the cached access token this many seconds before its reported expiry,
    # so a request is not sent with a token that expires mid-flight.
    _TOKEN_REFRESH_MARGIN_SECONDS: ClassVar[float] = 300.0

    def __init__(
        self,
        pool_management_endpoint: str,
        credential: TokenProvider,
        timeout: int = 60,
        work_dir: Union[Path, str, None] = None,
        functions: Sequence[
            Union[
                FunctionWithRequirements[Any, A],
                Callable[..., Any],
                FunctionWithRequirementsStr,
            ]
        ] = [],
        functions_module: str = "functions",
        suppress_result_output: bool = False,
        session_id: Optional[str] = None,
    ):
        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        self._work_dir: Optional[Path] = None
        self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None

        # If a user specifies a working directory, use that
        if work_dir is not None:
            if isinstance(work_dir, str):
                self._work_dir = Path(work_dir)
            else:
                self._work_dir = work_dir
            # Create the directory if it doesn't exist
            self._work_dir.mkdir(exist_ok=True, parents=True)
        # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
        else:
            self._temp_dir = tempfile.TemporaryDirectory()
            temp_dir_path = Path(self._temp_dir.name)
            temp_dir_path.mkdir(exist_ok=True, parents=True)

        self._started = False

        self._functions_module = functions_module
        self._timeout = timeout
        self._functions = functions
        self._func_code: Optional[str] = None
        # Setup could take some time so we intentionally wait for the first code block to do it.
        self._setup_functions_complete = len(functions) == 0

        self._suppress_result_output = suppress_result_output

        self._pool_management_endpoint = pool_management_endpoint
        self._access_token: str | None = None
        # Epoch seconds at which the cached token expires (None when unknown).
        self._access_token_expires_on: Optional[float] = None
        self._session_id: str = session_id or str(uuid4())
        self._available_packages: set[str] | None = None
        self._credential: TokenProvider = credential
        # cwd needs to be set to /mnt/data to properly read uploaded files and download written files
        self._setup_cwd_complete = False

    def _ensure_access_token(self) -> None:
        """Fetch a token for the dynamic sessions scope if none is cached or the
        cached one is within ``_TOKEN_REFRESH_MARGIN_SECONDS`` of expiring."""
        expires_on = self._access_token_expires_on
        about_to_expire = expires_on is not None and time.time() >= expires_on - self._TOKEN_REFRESH_MARGIN_SECONDS
        if self._access_token and not about_to_expire:
            return
        scope = "https://dynamicsessions.io/.default"
        token = self._credential.get_token(scope)
        self._access_token = token.token
        # A token that reports no expiry is reused until restart() clears it.
        new_expiry = getattr(token, "expires_on", None)
        self._access_token_expires_on = float(new_expiry) if new_expiry is not None else None

    @property
    def functions_module(self) -> str:
        """(Experimental) The module name for the functions."""
        return self._functions_module

    @property
    def functions(self) -> List[str]:
        """Not implemented for this executor; always raises NotImplementedError."""
        raise NotImplementedError

    @property
    def timeout(self) -> int:
        """(Experimental) The timeout for code execution."""
        return self._timeout

    @property
    def work_dir(self) -> Path:
        """The local working directory (user supplied, or the temporary directory).

        Raises:
            RuntimeError: If neither directory is available (e.g. after :meth:`stop`
                removed the temporary directory).
        """
        # If a user specifies a working directory, use that
        if self._work_dir is not None:
            # If a user specifies the current directory, warn them that this is deprecated
            if self._work_dir == Path("."):
                warnings.warn(
                    "Using the current directory as work_dir is deprecated",
                    DeprecationWarning,
                    stacklevel=2,
                )
            return self._work_dir
        # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
        elif self._temp_dir is not None:
            return Path(self._temp_dir.name)
        else:
            raise RuntimeError("Working directory not properly initialized")

    def _construct_url(self, path: str) -> str:
        """Build ``<endpoint>/<path>`` with the API version and session id as query parameters."""
        endpoint = self._pool_management_endpoint
        if not endpoint.endswith("/"):
            endpoint += "/"
        url = endpoint + f"{path}?api-version={self._AZURE_API_VER}&identifier={self._session_id}"
        return url

    @staticmethod
    def _parse_package_list(output: str) -> set[str]:
        """Parse the textual repr of a list of package names into a set.

        Text outside the outermost ``[...]`` is ignored; the combined execution
        output places stderr/stdout before the result, so e.g. warnings printed
        while importing would otherwise leak into the first package name.
        Both one-item-per-line and single-line list reprs are accepted.
        """
        text = output.strip()
        start, end = text.find("["), text.rfind("]")
        if start != -1 and end > start:
            text = text[start + 1 : end]
        else:
            text = text.strip("[]")
        names = (item.strip().strip("'\"") for item in text.split(","))
        return {name for name in names if name}

    async def get_available_packages(self, cancellation_token: CancellationToken) -> set[str]:
        """(Experimental) Return the names of the packages installed in the session.

        The result is cached until :meth:`restart` is called.

        Raises:
            ValueError: If listing the packages fails inside the session.
        """
        if self._available_packages is not None:
            return self._available_packages
        avail_pkgs = """
import pkg_resources\n[d.project_name for d in pkg_resources.working_set]
"""
        # The package list is the *result* of the last expression, so the result
        # output must be kept even when suppress_result_output is enabled.
        ret = await self._execute_code_dont_check_setup(
            [CodeBlock(code=avail_pkgs, language="python")],
            cancellation_token,
            suppress_result_output=False,
        )
        if ret.exit_code != 0:
            raise ValueError(f"Failed to get list of available packages: {ret.output.strip()}")
        self._available_packages = self._parse_package_list(ret.output)
        return self._available_packages

    async def _populate_available_packages(self, cancellation_token: CancellationToken) -> None:
        """Fill the cached set of available packages."""
        self._available_packages = await self.get_available_packages(cancellation_token)

    async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
        """Write the user functions to ``work_dir`` and load them into the session.

        Raises:
            ValueError: If a required package is unavailable or the functions fail to load.
        """
        if not self._func_code:
            self._func_code = build_python_functions_file(self._functions)

            # Check required function imports and packages
            lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
            # Should we also be checking the imports?
            flattened_packages = [item for sublist in lists_of_packages for item in sublist]
            required_packages = set(flattened_packages)
            if self._available_packages is None:
                await self._populate_available_packages(cancellation_token)
            if self._available_packages is not None:
                missing_pkgs = required_packages - self._available_packages
                if len(missing_pkgs) > 0:
                    raise ValueError(f"Packages unavailable in environment: {missing_pkgs}")

            func_file = self.work_dir / f"{self._functions_module}.py"
            func_file.write_text(self._func_code)

        # Attempt to load the function file to check for syntax errors, imports etc.
        exec_result = await self._execute_code_dont_check_setup(
            [CodeBlock(code=self._func_code, language="python")], cancellation_token
        )
        if exec_result.exit_code != 0:
            raise ValueError(f"Functions failed to load: {exec_result.output.strip()}")

        self._setup_functions_complete = True

    async def _setup_cwd(self, cancellation_token: CancellationToken) -> None:
        """Change the session's working directory to /mnt/data.

        Raises:
            ValueError: If the directory change fails.
        """
        # Change the cwd to /mnt/data to properly have access to uploaded files
        exec_result = await self._execute_code_dont_check_setup(
            [CodeBlock(code="import os; os.chdir('/mnt/data')", language="python")], cancellation_token
        )
        if exec_result.exit_code != 0:
            raise ValueError("Failed to set up Azure container working directory")
        self._setup_cwd_complete = True

    async def get_file_list(self, cancellation_token: CancellationToken) -> List[str]:
        """(Experimental) Return the names of the files stored in the current session.

        Raises:
            asyncio.TimeoutError: If the request times out.
            asyncio.CancelledError: If the request is cancelled.
            ConnectionError: If the endpoint returns an error status.
        """
        self._ensure_access_token()
        timeout = aiohttp.ClientTimeout(total=float(self._timeout))
        headers = {
            "Authorization": f"Bearer {self._access_token}",
        }
        url = self._construct_url("files")
        async with aiohttp.ClientSession(timeout=timeout) as client:
            task = asyncio.create_task(
                client.get(
                    url,
                    headers=headers,
                )
            )
            cancellation_token.link_future(task)
            try:
                resp = await task
                resp.raise_for_status()
                data = await resp.json()
            except asyncio.TimeoutError as e:
                # e.add_note is only in py 3.11+
                raise asyncio.TimeoutError("Timeout getting file list") from e
            except asyncio.CancelledError as e:
                # e.add_note is only in py 3.11+
                raise asyncio.CancelledError("File list retrieval cancelled") from e
            except aiohttp.ClientResponseError as e:
                raise ConnectionError("Error while getting file list") from e

        values = data["value"]
        file_info_list: List[str] = []
        for value in values:
            file = value["properties"]
            file_info_list.append(file["filename"])
        return file_info_list

    async def upload_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> None:
        """(Experimental) Upload files from ``work_dir`` into the current session.

        Raises:
            FileNotFoundError: If a file does not exist in ``work_dir``.
            asyncio.TimeoutError: If a request times out.
            asyncio.CancelledError: If the upload is cancelled.
            ConnectionError: If the endpoint returns an error status.
        """
        self._ensure_access_token()
        # TODO: Better to use the client auth system rather than headers
        headers = {"Authorization": f"Bearer {self._access_token}"}
        url = self._construct_url("files/upload")
        timeout = aiohttp.ClientTimeout(total=float(self._timeout))
        async with aiohttp.ClientSession(timeout=timeout) as client:
            for file in files:
                file_path = self.work_dir / file
                if not file_path.is_file():
                    # TODO: what to do here?
                    raise FileNotFoundError(f"{file} does not exist")

                # Read the file up front so the multipart field is plain bytes and
                # does not depend on the file handle staying open during the POST.
                async with await open_file(file_path, "rb") as f:
                    content = await f.read()
                data = aiohttp.FormData()
                data.add_field(
                    "file",
                    content,
                    filename=os.path.basename(file_path),
                    content_type="application/octet-stream",
                )

                task = asyncio.create_task(
                    client.post(
                        url,
                        headers=headers,
                        data=data,
                    )
                )
                cancellation_token.link_future(task)
                try:
                    resp = await task
                    resp.raise_for_status()
                except asyncio.TimeoutError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.TimeoutError("Timeout uploading files") from e
                except asyncio.CancelledError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.CancelledError("Uploading files cancelled") from e
                except aiohttp.ClientResponseError as e:
                    raise ConnectionError("Error while uploading files") from e

    async def download_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> List[str]:
        """(Experimental) Download files from the current session into ``work_dir``.

        Args:
            files: Names of the files to download. Each must be listed by
                :meth:`get_file_list`; ``Path`` values are compared in POSIX form.
            cancellation_token: Token used to cancel the pending requests.

        Returns:
            The local paths (as strings) of the downloaded files, in request order.

        Raises:
            FileNotFoundError: If a requested file is not present in the session.
            asyncio.TimeoutError: If a request times out.
            asyncio.CancelledError: If the download is cancelled.
            ConnectionError: If the endpoint returns an error status.
        """
        self._ensure_access_token()
        available_files = set(await self.get_file_list(cancellation_token))
        # TODO: Better to use the client auth system rather than headers
        headers = {"Authorization": f"Bearer {self._access_token}"}
        timeout = aiohttp.ClientTimeout(total=float(self._timeout))
        local_paths: List[str] = []
        async with aiohttp.ClientSession(timeout=timeout) as client:
            for file in files:
                # Path objects never compare equal to the str names returned by
                # get_file_list, so normalize them before the membership test.
                file_name = file.as_posix() if isinstance(file, Path) else file
                if file_name not in available_files:
                    # TODO: what's the right thing to do here?
                    raise FileNotFoundError(f"{file} does not exist")

                url = self._construct_url(f"files/content/{file_name}")
                task = asyncio.create_task(
                    client.get(
                        url,
                        headers=headers,
                    )
                )
                cancellation_token.link_future(task)
                try:
                    resp = await task
                    resp.raise_for_status()
                    content = await resp.read()
                except asyncio.TimeoutError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.TimeoutError("Timeout downloading files") from e
                except asyncio.CancelledError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.CancelledError("Downloading files cancelled") from e
                except aiohttp.ClientResponseError as e:
                    raise ConnectionError("Error while downloading files") from e

                local_path = self.work_dir / file_name
                async with await open_file(local_path, "wb") as f:
                    await f.write(content)
                # Only report paths that were actually written.
                local_paths.append(str(local_path))
        return local_paths

    async def execute_code_blocks(
        self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
    ) -> CodeResult:
        """(Experimental) Execute the code blocks and return the result.

        On first use (and after :meth:`restart`) this also fetches the package
        list, loads the user functions, and changes the session cwd to /mnt/data.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.
            cancellation_token (CancellationToken): A token to cancel the operation.

        Returns:
            CodeResult: The result of the code execution."""
        self._ensure_access_token()
        if self._available_packages is None:
            await self._populate_available_packages(cancellation_token)
        if not self._setup_functions_complete:
            await self._setup_functions(cancellation_token)
        if not self._setup_cwd_complete:
            await self._setup_cwd(cancellation_token)

        return await self._execute_code_dont_check_setup(code_blocks, cancellation_token)

    # The http call here should be replaced by an actual Azure client call once its available
    async def _execute_code_dont_check_setup(
        self,
        code_blocks: List[CodeBlock],
        cancellation_token: CancellationToken,
        *,
        suppress_result_output: Optional[bool] = None,
    ) -> CodeResult:
        """Run ``code_blocks`` serially in the session without any setup steps.

        Execution stops at the first block with an unsupported language or a
        missing package (exit code 1). ``suppress_result_output`` overrides the
        instance setting for this call when not None.

        Raises:
            asyncio.TimeoutError / asyncio.CancelledError: carrying the logs so far.
            ConnectionError: If the endpoint returns an error status.
        """
        suppress = self._suppress_result_output if suppress_result_output is None else suppress_result_output
        logs_all = ""
        exitcode = 0

        # Make sure every caller (including get_available_packages) has a token.
        self._ensure_access_token()
        # TODO: Better to use the client auth system rather than headers
        assert self._access_token is not None
        headers = {
            "Authorization": f"Bearer {self._access_token}",
            "Content-Type": "application/json",
        }
        properties = {
            "codeInputType": "inline",
            "executionType": "synchronous",
            "code": "",  # Filled in later
        }
        url = self._construct_url("code/execute")
        timeout = aiohttp.ClientTimeout(total=float(self._timeout))
        async with aiohttp.ClientSession(timeout=timeout) as client:
            for code_block in code_blocks:
                lang, code = code_block.language, code_block.code
                lang = lang.lower()

                if lang in PYTHON_VARIANTS:
                    lang = "python"

                if lang not in self.SUPPORTED_LANGUAGES:
                    # In case the language is not supported, we return an error message.
                    exitcode = 1
                    logs_all += "\n" + f"unknown language {lang}"
                    break

                if self._available_packages is not None:
                    req_pkgs = get_required_packages(code, lang)
                    missing_pkgs = set(req_pkgs - self._available_packages)
                    if len(missing_pkgs) > 0:
                        # In case the code requires packages that are not available in the environment
                        exitcode = 1
                        logs_all += "\n" + f"Python packages unavailable in environment: {missing_pkgs}"
                        break

                properties["code"] = code_block.code

                task = asyncio.create_task(
                    client.post(
                        url,
                        headers=headers,
                        json={"properties": properties},
                    )
                )
                cancellation_token.link_future(task)
                try:
                    response = await task
                    response.raise_for_status()
                    data = await response.json()
                    data = data["properties"]
                    logs_all += data.get("stderr", "") + data.get("stdout", "")
                    if "Success" in data["status"]:
                        if not suppress:
                            logs_all += str(data["result"])
                    elif "Failure" in data["status"]:
                        exitcode = 1

                except asyncio.TimeoutError as e:
                    logs_all += "\n Timeout"
                    # e.add_note is only in py 3.11+
                    raise asyncio.TimeoutError(logs_all) from e
                except asyncio.CancelledError as e:
                    logs_all += "\n Cancelled"
                    # e.add_note is only in py 3.11+
                    raise asyncio.CancelledError(logs_all) from e
                except aiohttp.ClientResponseError as e:
                    logs_all += "\nError while sending code block to endpoint"
                    raise ConnectionError(logs_all) from e

        return CodeResult(exit_code=exitcode, output=logs_all)

    async def restart(self) -> None:
        """(Experimental) Restart the code executor.

        Resets the internal state of the executor by generating a new session ID
        and resetting the setup variables. This causes the next code execution to
        reinitialize the environment and re-run any setup code.
        """
        self._session_id = str(uuid4())
        self._setup_functions_complete = False
        self._access_token = None
        self._access_token_expires_on = None
        self._available_packages = None
        self._setup_cwd_complete = False

    async def start(self) -> None:
        """(Experimental) Start the code executor.

        Marks the code executor as started."""
        # No setup needed for this executor
        self._started = True

    async def stop(self) -> None:
        """(Experimental) Stop the code executor.

        Stops the code executor after cleaning up the temporary working directory (if it was created)."""
        if self._temp_dir is not None:
            self._temp_dir.cleanup()
            self._temp_dir = None
        self._started = False