# Source code for autogen_core.tool_agent._caller_loop
import asyncio
from typing import List
from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall
from ..models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
)
from ..tools import Tool, ToolSchema
from ._tool_agent import ToolException
# [docs]
async def tool_agent_caller_loop(
    caller: BaseAgent | AgentRuntime,
    tool_agent_id: AgentId,
    model_client: ChatCompletionClient,
    input_messages: List[LLMMessage],
    tool_schema: List[ToolSchema] | List[Tool],
    cancellation_token: CancellationToken | None = None,
    caller_source: str = "assistant",
) -> List[LLMMessage]:
    """Run a caller loop that alternates between the model client and a tool agent.

    The model client is queried first. While its reply consists solely of
    function calls, every call is sent to the tool agent, the execution
    results are appended to the conversation, and the model client is queried
    again. The loop ends as soon as a reply is not a list of function calls.

    Args:
        caller (BaseAgent | AgentRuntime): Object whose ``send_message`` is
            used to deliver each function call to the tool agent.
        tool_agent_id (AgentId): Agent ID of the tool agent that receives the
            function calls.
        model_client (ChatCompletionClient): Model client used to generate
            the replies.
        input_messages (List[LLMMessage]): Messages that start the
            conversation; this list is not modified.
        tool_schema (List[Tool | ToolSchema]): Tools passed to the model
            client on every request.
        cancellation_token (CancellationToken | None): Forwarded to every
            model request and every message sent to the tool agent.
        caller_source (str): Value used as ``source`` for every
            ``AssistantMessage`` created by the loop.

    Returns:
        List[LLMMessage]: The messages created during the loop (assistant
        messages and function execution result messages), in order. The
        input messages are not included.

    Raises:
        BaseException: Any exception produced by a tool call that is not a
            ``ToolException`` is re-raised unchanged.
    """
    produced: List[LLMMessage] = []

    # First request: the conversation is exactly what the caller supplied.
    reply = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token)
    produced.append(AssistantMessage(content=reply.content, source=caller_source))

    while True:
        pending = reply.content

        # Stop once the reply is anything other than a list made up only of
        # function calls.
        # NOTE(review): an empty list also passes this check, so the body
        # below would run with zero tool calls -- confirm model clients never
        # return [].
        if not isinstance(pending, list):
            break
        if any(not isinstance(item, FunctionCall) for item in pending):
            break

        # Create one tool-agent request per function call, then await them
        # all together. return_exceptions=True hands failures back as values
        # so they can be classified below instead of aborting the gather.
        dispatches = []
        for call in pending:
            dispatches.append(
                caller.send_message(
                    message=call,
                    recipient=tool_agent_id,
                    cancellation_token=cancellation_token,
                )
            )
        outcomes: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
            *dispatches,
            return_exceptions=True,
        )

        # Fold the outcomes into a single list of execution results.
        tool_results: List[FunctionExecutionResult] = []
        for outcome in outcomes:
            if isinstance(outcome, FunctionExecutionResult):
                tool_results.append(outcome)
                continue
            if isinstance(outcome, ToolException):
                # A tool-level failure becomes an error result that is sent
                # back to the model with the next request.
                tool_results.append(
                    FunctionExecutionResult(
                        content=f"Error: {outcome}", call_id=outcome.call_id, is_error=True, name=outcome.name
                    )
                )
                continue
            if isinstance(outcome, BaseException):
                # Anything else is unexpected: propagate it to the caller.
                raise outcome
        produced.append(FunctionExecutionResultMessage(content=tool_results))

        # Ask the model again, now with everything generated so far.
        reply = await model_client.create(
            input_messages + produced, tools=tool_schema, cancellation_token=cancellation_token
        )
        produced.append(AssistantMessage(content=reply.content, source=caller_source))

    return produced