autogen_core.tools._function_tool 源代码
import asyncio
import functools
import warnings
from textwrap import dedent
from typing import Any, Callable, Sequence
from pydantic import BaseModel
from typing_extensions import Self
from .. import CancellationToken
from .._component_config import Component
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
from ._base import BaseTool
class FunctionToolConfig(BaseModel):
    """Serializable configuration for a function tool."""

    # Source text of the wrapped function.
    source_code: str

    # Name under which the tool is exposed.
    name: str

    # Description of the tool.
    description: str

    # Import specifications that accompany the function's source code.
    global_imports: Sequence[Import]

    # Whether the wrapped function supports a cancellation token.
    has_cancellation_support: bool
[文档]
class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
    """
    Create custom tools by wrapping standard Python functions.

    `FunctionTool` offers an interface for executing Python functions either
    asynchronously or synchronously. Each function must include type annotations
    for all parameters and its return type. These annotations enable
    `FunctionTool` to generate the schema needed for input validation,
    serialization, and for informing the LLM about the expected parameters.
    When the LLM prepares a function call, it uses this schema to generate
    arguments that match the function's specification.

    .. note::

        It is the user's responsibility to verify that the tool's output type
        matches the expected type.

    Args:
        func (Callable[..., ReturnT | Awaitable[ReturnT]]): The function to wrap
            and expose as a tool.
        description (str): A description that tells the model what the function
            does and the context in which it should be called.
        name (str, optional): An optional custom name for the tool. Defaults to
            the function's original name if not provided.
        global_imports (Sequence[Import], optional): Import specifications stored
            with the tool and written into its serialized configuration.
            Defaults to an empty list.
        strict (bool, optional): If set to True, the tool schema will only
            contain arguments that are explicitly defined in the function
            signature, and no default values will be allowed. Defaults to False.
            Must be set to True when used in structured output mode.

    Example:

        .. code-block:: python

            import random
            from autogen_core import CancellationToken
            from autogen_core.tools import FunctionTool
            from typing_extensions import Annotated
            import asyncio


            async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float:
                # Simulate a stock price lookup by returning a random float in a fixed range.
                return random.uniform(10, 200)


            async def example():
                # Create a FunctionTool instance for retrieving stock prices.
                stock_price_tool = FunctionTool(get_stock_price, description="Fetch the stock price for a given ticker.")

                # Run the tool with cancellation support.
                cancellation_token = CancellationToken()
                result = await stock_price_tool.run_json({"ticker": "AAPL", "date": "2021/01/01"}, cancellation_token)

                # Format the result as a string for output.
                print(stock_price_tool.return_value_as_string(result))


            asyncio.run(example())
    """

    component_provider_override = "autogen_core.tools.FunctionTool"
    component_config_schema = FunctionToolConfig

    def __init__(
        self,
        func: Callable[..., Any],
        description: str,
        name: str | None = None,
        # NOTE: the shared list default is only ever read, never mutated, here.
        global_imports: Sequence[Import] = [],
        strict: bool = False,
    ) -> None:
        self._func = func
        self._global_imports = global_imports
        self._signature = get_typed_signature(func)

        # A functools.partial has no __name__ of its own; fall back to the
        # function it wraps. An explicit (non-empty) name always wins.
        named_callable = func.func if isinstance(func, functools.partial) else func
        func_name = name or named_callable.__name__

        args_model = args_base_model_from_signature(func_name + "args", self._signature)
        # The function opts in to cancellation by declaring a parameter
        # called "cancellation_token".
        self._has_cancellation_support = "cancellation_token" in self._signature.parameters
        return_type = self._signature.return_annotation
        super().__init__(args_model, return_type, func_name, description, strict)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """Call the wrapped function with the values carried by ``args``.

        Coroutine functions are awaited directly. Plain functions are run in
        the event loop's default executor so they do not block the loop.

        Args:
            args: Model instance holding the call arguments. Every signature
                parameter that exists as an attribute on ``args`` is forwarded
                as a keyword argument.
            cancellation_token: Passed to the function as ``cancellation_token``
                when its signature declares that parameter. Otherwise, for
                plain functions, the executor future is linked to the token.

        Returns:
            Whatever the wrapped function returns.
        """
        kwargs = {
            param: getattr(args, param) for param in self._signature.parameters if hasattr(args, param)
        }

        if asyncio.iscoroutinefunction(self._func):
            if self._has_cancellation_support:
                return await self._func(**kwargs, cancellation_token=cancellation_token)
            return await self._func(**kwargs)

        # We are inside a coroutine, so a loop is guaranteed to be running.
        loop = asyncio.get_running_loop()
        if self._has_cancellation_support:
            # The function receives the token itself and handles cancellation.
            return await loop.run_in_executor(
                None,
                functools.partial(
                    self._func,
                    **kwargs,
                    cancellation_token=cancellation_token,
                ),
            )

        future = loop.run_in_executor(None, functools.partial(self._func, **kwargs))
        cancellation_token.link_future(future)
        return await future

    def _to_config(self) -> FunctionToolConfig:
        """Build the serializable configuration for this tool.

        The wrapped function is converted to source text with ``to_code`` and
        dedented so it can later be executed as top-level code.
        """
        return FunctionToolConfig(
            source_code=dedent(to_code(self._func)),
            global_imports=self._global_imports,
            name=self.name,
            description=self.description,
            has_cancellation_support=self._has_cancellation_support,
        )

    @classmethod
    def _from_config(cls, config: FunctionToolConfig) -> Self:
        """Recreate a tool from its configuration by executing the stored code.

        SECURITY: this runs ``exec`` on the import statements and the function
        source held in ``config``. Only load configurations from trusted
        sources; a ``UserWarning`` is emitted on every call as a reminder.

        Args:
            config: Configuration holding the function source, its imports,
                and the tool's name and description.

        Returns:
            A new tool wrapping the first top-level function defined in
            ``config.source_code``.

        Raises:
            ModuleNotFoundError: A global import names a module that is missing.
            ImportError: A global import failed for another import reason.
            RuntimeError: A global import raised an unexpected exception.
            ValueError: The source code failed to execute or defines no
                top-level function.
            TypeError: The loaded object is not callable.
        """
        import ast  # Local import: only needed when loading from a config.

        warnings.warn(
            "\n⚠️ SECURITY WARNING ⚠️\n"
            "Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
            "Only load configs from TRUSTED sources to prevent arbitrary code execution.",
            UserWarning,
            stacklevel=2,
        )

        exec_globals: dict[str, Any] = {}

        # Execute imports first so the function body can resolve them.
        for import_stmt in config.global_imports:
            import_code = import_to_str(import_stmt)
            try:
                exec(import_code, exec_globals)
            except ModuleNotFoundError as e:
                raise ModuleNotFoundError(
                    f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
                ) from e
            except ImportError as e:
                raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
            except Exception as e:
                raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e

        # Execute the function code, then find the function's name from the
        # parsed syntax tree. Parsing avoids the pitfalls of text splitting on
        # "def " (matches inside comments or decorator lines, odd spacing).
        try:
            exec(config.source_code, exec_globals)
            func_name = next(
                (
                    node.name
                    for node in ast.parse(config.source_code).body
                    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                ),
                None,
            )
            if func_name is None:
                raise ValueError("no top-level function definition found in source code")
        except Exception as e:
            raise ValueError(f"Could not compile and load function: {e}") from e

        # Get the function and verify it is callable (a decorator may have
        # rebound the name to something else).
        func: Callable[..., Any] = exec_globals[func_name]
        if not callable(func):
            raise TypeError(f"Expected function but got {type(func)}")

        return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)