import asyncio
import json
import logging
import os
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
cast,
)
from autogen_agentchat import TRACE_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models._types import FunctionExecutionResult
from autogen_core.tools import FunctionTool, Tool
import azure.ai.projects.models as models
from azure.ai.projects import _types
from azure.ai.projects.aio import AIProjectClient
from ._types import AzureAIAgentState, ListToolType
# Logger bound to AutoGen's trace channel (TRACE_LOGGER_NAME); this module's debug output goes through it.
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
class AzureAIAgent(BaseChatAgent):
    """
    Azure AI Assistant agent for AutoGen.

    Installation:

    .. code-block:: bash

        pip install "autogen-ext[azure]"  # For Azure AI Foundry Agent Service

    This agent uses the Azure AI Assistant API to create AI assistants with capabilities such as:

    * Code interpretation and execution
    * Grounding with Bing search
    * File handling and search
    * Custom function calling
    * Multi-turn conversations

    The agent integrates with AutoGen's messaging system, providing a seamless way to use
    Azure AI capabilities within the AutoGen framework. It supports tools such as the code
    interpreter, file search, and various grounding mechanisms.

    The agent name must be a valid Python identifier:

        1. It must start with a letter (A-Z, a-z) or an underscore (_).
        2. It may contain only letters, digits (0-9), or underscores.
        3. It cannot be a Python keyword.
        4. It cannot contain spaces or special characters.
        5. It cannot start with a digit.

    See how to create a secured agent with a user-managed identity:
    https://learn.microsoft.com/en-us/azure/ai-services/agents/how-to/virtual-networks

    Examples:

        Use AzureAIAgent to create an agent grounded with Bing:

        .. code-block:: python

            import asyncio
            import os

            from autogen_agentchat.messages import TextMessage
            from autogen_core import CancellationToken
            from autogen_ext.agents.azure._azure_ai_agent import AzureAIAgent
            from azure.ai.projects.aio import AIProjectClient
            from azure.identity.aio import DefaultAzureCredential
            import azure.ai.projects.models as models
            import dotenv


            async def bing_example():
                credential = DefaultAzureCredential()

                async with AIProjectClient.from_connection_string(  # type: ignore
                    credential=credential, conn_str=os.getenv("AI_PROJECT_CONNECTION_STRING", "")
                ) as project_client:
                    conn = await project_client.connections.get(connection_name=os.getenv("BING_CONNECTION_NAME", ""))

                    bing_tool = models.BingGroundingTool(conn.id)
                    agent_with_bing_grounding = AzureAIAgent(
                        name="bing_agent",
                        description="An AI assistant with Bing grounding",
                        project_client=project_client,
                        deployment_name="gpt-4o",
                        instructions="You are a helpful assistant.",
                        tools=bing_tool.definitions,
                        metadata={"source": "AzureAIAgent"},
                    )

                    # For the Bing grounding tool to return citations, the message must instruct
                    # the model to provide them, e.g. "Please provide citations for your answers".

                    result = await agent_with_bing_grounding.on_messages(
                        messages=[
                            TextMessage(
                                content="What is Microsoft's annual leave policy? Provide citations for your answers.",
                                source="user",
                            )
                        ],
                        cancellation_token=CancellationToken(),
                        message_limit=5,
                    )
                    print(result)


            if __name__ == "__main__":
                dotenv.load_dotenv()
                asyncio.run(bing_example())

        Use AzureAIAgent to create an agent with file search capability:

        .. code-block:: python

            import asyncio
            import os
            import tempfile
            import urllib.request

            import dotenv
            from autogen_agentchat.messages import TextMessage
            from autogen_core import CancellationToken
            from autogen_ext.agents.azure._azure_ai_agent import AzureAIAgent
            from azure.ai.projects.aio import AIProjectClient
            from azure.identity.aio import DefaultAzureCredential


            async def file_search_example():
                # Download README.md from GitHub
                readme_url = "https://raw.githubusercontent.com/microsoft/autogen/refs/heads/main/README.md"
                temp_file = None

                try:
                    # Create a temporary file to store the downloaded README
                    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".md")
                    urllib.request.urlretrieve(readme_url, temp_file.name)
                    print(f"Downloaded README.md to {temp_file.name}")

                    credential = DefaultAzureCredential()
                    async with AIProjectClient.from_connection_string(  # type: ignore
                        credential=credential, conn_str=os.getenv("AI_PROJECT_CONNECTION_STRING", "")
                    ) as project_client:
                        agent_with_file_search = AzureAIAgent(
                            name="file_search_agent",
                            description="An AI assistant with file search capabilities",
                            project_client=project_client,
                            deployment_name="gpt-4o",
                            instructions="You are a helpful assistant.",
                            tools=["file_search"],
                            metadata={"source": "AzureAIAgent"},
                        )

                        ct: CancellationToken = CancellationToken()
                        # Use the downloaded README file for file search
                        await agent_with_file_search.on_upload_for_file_search(
                            file_paths=[temp_file.name],
                            vector_store_name="file_upload_index",
                            vector_store_metadata={"source": "AzureAIAgent"},
                            cancellation_token=ct,
                        )
                        result = await agent_with_file_search.on_messages(
                            messages=[
                                TextMessage(content="Hello, what is AutoGen and what capabilities does it have?", source="user")
                            ],
                            cancellation_token=ct,
                            message_limit=5,
                        )
                        print(result)
                finally:
                    # Clean up the temporary file
                    if temp_file and os.path.exists(temp_file.name):
                        os.unlink(temp_file.name)
                        print(f"Removed temporary file {temp_file.name}")


            if __name__ == "__main__":
                dotenv.load_dotenv()
                asyncio.run(file_search_example())

        Use AzureAIAgent to create an agent with code interpreter capability:

        .. code-block:: python

            import asyncio
            import os

            import dotenv
            from autogen_agentchat.messages import TextMessage
            from autogen_core import CancellationToken
            from autogen_ext.agents.azure._azure_ai_agent import AzureAIAgent
            from azure.ai.projects.aio import AIProjectClient
            from azure.identity.aio import DefaultAzureCredential


            async def code_interpreter_example():
                credential = DefaultAzureCredential()
                async with AIProjectClient.from_connection_string(  # type: ignore
                    credential=credential, conn_str=os.getenv("AI_PROJECT_CONNECTION_STRING", "")
                ) as project_client:
                    agent_with_code_interpreter = AzureAIAgent(
                        name="code_interpreter_agent",
                        description="An AI assistant with code interpreter capabilities",
                        project_client=project_client,
                        deployment_name="gpt-4o",
                        instructions="You are a helpful assistant.",
                        tools=["code_interpreter"],
                        metadata={"source": "AzureAIAgent"},
                    )

                    await agent_with_code_interpreter.on_upload_for_code_interpreter(
                        file_paths="/workspaces/autogen/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/data/nifty_500_quarterly_results.csv",
                        cancellation_token=CancellationToken(),
                    )

                    result = await agent_with_code_interpreter.on_messages(
                        messages=[
                            TextMessage(
                                content="Aggregate the number of stocks per industry and give me a markdown table as a result?",
                                source="user",
                            )
                        ],
                        cancellation_token=CancellationToken(),
                    )

                    print(result)


            if __name__ == "__main__":
                dotenv.load_dotenv()
                asyncio.run(code_interpreter_example())
    """

    def __init__(
        self,
        name: str,
        description: str,
        project_client: AIProjectClient,
        deployment_name: str,
        instructions: str,
        tools: Optional[ListToolType] = None,
        agent_id: Optional[str] = None,
        thread_id: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        response_format: Optional["_types.AgentsApiResponseFormatOption"] = None,
        temperature: Optional[float] = None,
        tool_resources: Optional["models.ToolResources"] = None,
        top_p: Optional[float] = None,
    ) -> None:
        """
        Initialize the Azure AI agent.

        The remote agent and thread are not created here; that happens lazily on first use
        (see ``_ensure_initialized``).

        Args:
            name (str): Agent name; must be a valid Python identifier.
            description (str): Short description of the agent's purpose.
            project_client (AIProjectClient): Azure AI project client used for API calls.
            deployment_name (str): Name of the model deployment the agent uses (e.g. "gpt-4").
            instructions (str): Detailed instructions for the agent's behavior.
            tools (Optional[Iterable[Union[str, ToolDefinition, Tool, Callable]]]): Tools available to the agent.
                Supported string values: "file_search", "code_interpreter", "bing_grounding",
                "azure_ai_search", "azure_function", "sharepoint_grounding".
            agent_id (Optional[str]): ID of an existing agent to reuse instead of creating a new one.
            thread_id (Optional[str]): ID of an existing thread to continue a conversation.
            metadata (Optional[Dict[str, str]]): Additional metadata for the agent.
            response_format (Optional[_types.AgentsApiResponseFormatOption]): Response format option.
            temperature (Optional[float]): Sampling temperature.
            tool_resources (Optional[models.ToolResources]): Resource configuration for the agent's tools.
            top_p (Optional[float]): Nucleus-sampling parameter, an alternative to temperature.

        Raises:
            ValueError: If an unsupported tool type or tool string is provided.
        """
        super().__init__(name, description)

        if tools is None:
            tools = []

        # AutoGen Tool objects (including wrapped plain callables); kept so function calls
        # requested by the remote agent can be executed locally.
        self._original_tools: list[Tool] = []

        converted_tools: List["models.ToolDefinition"] = []
        self._add_tools(tools, converted_tools)

        self._project_client = project_client
        self._agent: Optional["models.Agent"] = None
        self._thread: Optional["models.AgentThread"] = None
        self._init_thread_id = thread_id
        self._deployment_name = deployment_name
        self._instructions = instructions
        self._api_tools = converted_tools
        self._agent_id = agent_id
        self._metadata = metadata
        self._response_format = response_format
        self._temperature = temperature
        self._tool_resources = tool_resources
        self._top_p = top_p
        self._vector_store_id: Optional[str] = None
        self._uploaded_file_ids: List[str] = []

        # IDs of messages that already existed in a reused thread before this instance touched it.
        self._initial_message_ids: Set[str] = set()
        self._initial_state_retrieved: bool = False

    # Properties
    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        """The message types this agent produces."""
        return (TextMessage,)

    @property
    def thread_id(self) -> str:
        """ID of the active thread; raises ValueError if the thread has not been created yet."""
        if self._thread is None:
            raise ValueError("Thread not initialized")
        return self._thread.id

    @property
    def _get_agent_id(self) -> str:
        """ID of the active remote agent; raises ValueError if it has not been created/fetched yet."""
        if self._agent is None:
            raise ValueError("Agent not initialized")
        return self._agent.id

    @property
    def description(self) -> str:
        """The agent description; raises ValueError if it is empty."""
        if not self._description:
            raise ValueError("Description not initialized")
        return self._description

    @property
    def agent_id(self) -> str:
        """The configured agent ID; raises ValueError if none was provided or loaded."""
        if not self._agent_id:
            raise ValueError("Agent not initialized")
        return self._agent_id

    @property
    def deployment_name(self) -> str:
        """The model deployment name; raises ValueError if it is empty."""
        if not self._deployment_name:
            raise ValueError("Deployment name not initialized")
        return self._deployment_name

    @property
    def instructions(self) -> str:
        """The agent instructions; raises ValueError if they are empty."""
        if not self._instructions:
            raise ValueError("Instructions not initialized")
        return self._instructions

    @property
    def tools(self) -> List[models.ToolDefinition]:
        """
        Get the list of tools available to the agent.

        Returns:
            List[models.ToolDefinition]: The tool definitions sent to the Azure AI Agent API.
        """
        return self._api_tools

    def _add_tools(self, tools: Optional[ListToolType], converted_tools: List["models.ToolDefinition"]) -> None:
        """
        Convert tools given in various formats into Azure AI Agent tool definitions.

        AutoGen ``Tool`` objects and plain callables are also recorded in ``self._original_tools``
        so they can be executed locally when the remote agent requests them.

        Args:
            tools: Tools as string identifiers, ToolDefinition objects, Tool objects, or callables.
            converted_tools: Output list that receives the converted tool definitions.

        Raises:
            ValueError: If an unsupported tool string or tool type is provided.
        """
        if tools is None:
            return

        for tool in tools:
            if isinstance(tool, str):
                # Model attributes are resolved lazily, per branch, so a name missing from the
                # installed SDK only fails when that specific tool is requested.
                if tool == "file_search":
                    converted_tools.append(models.FileSearchToolDefinition())
                elif tool == "code_interpreter":
                    converted_tools.append(models.CodeInterpreterToolDefinition())
                elif tool == "bing_grounding":
                    converted_tools.append(models.BingGroundingToolDefinition())  # type: ignore
                elif tool == "azure_ai_search":
                    converted_tools.append(models.AzureAISearchToolDefinition())
                elif tool == "azure_function":
                    converted_tools.append(models.AzureFunctionToolDefinition())  # type: ignore
                elif tool == "sharepoint_grounding":
                    converted_tools.append(models.SharepointToolDefinition())  # type: ignore
                else:
                    raise ValueError(f"Unsupported tool string: {tool}")
            elif isinstance(tool, models.ToolDefinition):
                converted_tools.append(tool)
            elif isinstance(tool, Tool):
                self._original_tools.append(tool)
                converted_tools.append(self._convert_tool_to_function_tool_definition(tool))
            elif callable(tool):
                # Plain callables are wrapped in a FunctionTool, using the docstring (if any) as description.
                description = tool.__doc__ or ""
                function_tool = FunctionTool(tool, description=description)
                self._original_tools.append(function_tool)
                converted_tools.append(self._convert_tool_to_function_tool_definition(function_tool))
            else:
                raise ValueError(f"Unsupported tool type: {type(tool)}")

    def _convert_tool_to_function_tool_definition(self, tool: Tool) -> models.FunctionToolDefinition:
        """
        Convert an AutoGen Tool into an Azure AI Agent function tool definition.

        Only ``type``, ``properties`` and (if present) ``required`` are copied from the tool's
        parameter schema.

        Args:
            tool (Tool): The AutoGen tool to convert.

        Returns:
            models.FunctionToolDefinition: A function tool definition compatible with the Azure AI Agent API.
        """
        schema = tool.schema
        parameters: Dict[str, object] = {}

        if "parameters" in schema:
            parameters = {
                "type": schema["parameters"]["type"],
                "properties": schema["parameters"]["properties"],
            }
            if "required" in schema["parameters"]:
                parameters["required"] = schema["parameters"]["required"]

        func_definition = models.FunctionDefinition(name=tool.name, description=tool.description, parameters=parameters)

        return models.FunctionToolDefinition(
            function=func_definition,
        )

    async def _ensure_initialized(self, create_new_thread: bool = False, create_new_agent: bool = False) -> None:
        """
        Ensure the remote agent and thread exist before performing an operation.

        The agent is fetched by ``agent_id`` when one is configured, otherwise created. The thread
        is fetched by the initial ``thread_id`` when one is configured, otherwise created. When an
        existing thread is fetched, its pre-existing message IDs are recorded once.

        Args:
            create_new_thread (bool): If True, create a new thread even if a thread_id was provided.
            create_new_agent (bool): If True, create a new agent even if an agent_id was provided.

        Raises:
            ValueError: If agent or thread creation fails (as surfaced by the client).
        """
        if self._agent is None or create_new_agent:
            if self._agent_id and not create_new_agent:
                self._agent = await self._project_client.agents.get_agent(agent_id=self._agent_id)
            else:
                self._agent = await self._project_client.agents.create_agent(
                    name=self.name,
                    model=self._deployment_name,
                    description=self.description,
                    instructions=self._instructions,
                    tools=self._api_tools,
                    metadata=self._metadata,
                    response_format=self._response_format if self._response_format else None,  # type: ignore
                    temperature=self._temperature,
                    tool_resources=self._tool_resources if self._tool_resources else None,  # type: ignore
                    top_p=self._top_p,
                )

        if self._thread is None or create_new_thread:
            if self._init_thread_id and not create_new_thread:
                self._thread = await self._project_client.agents.get_thread(thread_id=self._init_thread_id)
                # Retrieve the initial state only once.
                if not self._initial_state_retrieved:
                    await self._retrieve_initial_state()
                    self._initial_state_retrieved = True
            else:
                self._thread = await self._project_client.agents.create_thread()

    async def _retrieve_initial_state(self) -> None:
        """
        Retrieve and store the IDs of messages already present in the thread.

        This records which messages existed before this agent instance started interacting with
        the thread. Results are paged (100 per request) until the API reports no more pages.
        """
        initial_message_ids: Set[str] = set()
        after: str | None = None
        while True:
            msgs: models.OpenAIPageableListOfThreadMessage = await self._project_client.agents.list_messages(
                thread_id=self.thread_id, after=after, order=models.ListSortOrder.ASCENDING, limit=100
            )
            for msg in msgs.data:
                initial_message_ids.add(msg.id)
            if not msgs.has_more:
                break
            # Cursor for the next page is the last ID of the current page.
            after = msgs.data[-1].id
        self._initial_message_ids = initial_message_ids

    async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
        """
        Execute a tool call requested by the Azure AI agent.

        Args:
            tool_call (FunctionCall): The function call, including name and JSON-encoded arguments.
            cancellation_token (CancellationToken): Token used to cancel processing.

        Returns:
            str: The tool's result rendered as a string.

        Raises:
            ValueError: If no tools are registered or the requested tool is not available.
        """
        if not self._original_tools:
            raise ValueError("No tools are available.")
        tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
        if tool is None:
            raise ValueError(f"The tool '{tool_call.name}' is not available.")
        arguments = json.loads(tool_call.arguments)
        result = await tool.run_json(arguments, cancellation_token)
        return tool.return_value_as_string(result)

    async def _upload_files(
        self,
        file_paths: str | Iterable[str],
        purpose: str = "assistant",
        sleep_interval: float = 0.5,
        cancellation_token: Optional[CancellationToken] = None,
    ) -> List[str]:
        """
        Upload files to the Azure AI Assistant API.

        Each uploaded file ID is also appended to ``self._uploaded_file_ids`` so it is part of the
        saved state.

        Args:
            file_paths (str | Iterable[str]): A single path or multiple paths to upload.
            purpose (str): Upload purpose; defaults to "assistant".
            sleep_interval (float): Interval between polls of the file processing status.
            cancellation_token (Optional[CancellationToken]): Token used to cancel the operation.

        Returns:
            List[str]: IDs of the uploaded files.

        Raises:
            ValueError: If a file does not reach the PROCESSED state.
        """
        if cancellation_token is None:
            cancellation_token = CancellationToken()

        await self._ensure_initialized()

        if isinstance(file_paths, str):
            file_paths = [file_paths]

        file_ids: List[str] = []
        for file_path in file_paths:
            file_name = os.path.basename(file_path)

            file: models.OpenAIFile = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._project_client.agents.upload_file_and_poll(
                        file_path=file_path, purpose=purpose, sleep_interval=sleep_interval
                    )
                )
            )

            if file.status != models.FileState.PROCESSED:
                raise ValueError(f"File upload failed with status {file.status}")

            trace_logger.debug(f"File uploaded successfully: {file.id}, {file_name}")

            file_ids.append(file.id)
            self._uploaded_file_ids.append(file.id)

        return file_ids

    # Public Methods
    async def on_messages(
        self,
        messages: Sequence[BaseChatMessage],
        cancellation_token: Optional[CancellationToken] = None,
        message_limit: int = 1,
        sleep_interval: float = 0.5,
    ) -> Response:
        """
        Process incoming messages and return the Azure AI agent's response.

        This is the main entry point for interacting with the agent. It delegates to
        ``on_messages_stream`` and returns the final response.

        Args:
            messages (Sequence[ChatMessage]): Messages to process.
            cancellation_token (CancellationToken): Token used to cancel the operation.
            message_limit (int, optional): Maximum number of messages retrieved from the thread.
            sleep_interval (float, optional): Interval between polls of the run status.

        Returns:
            Response: The agent's response, including the chat message and any inner events.

        Raises:
            AssertionError: If the stream does not produce a final result.
        """
        async for message in self.on_messages_stream(
            messages=messages,
            cancellation_token=cancellation_token,
            message_limit=message_limit,
            sleep_interval=sleep_interval,
        ):
            if isinstance(message, Response):
                return message
        raise AssertionError("The stream should have returned the final result.")

    async def on_messages_stream(
        self,
        messages: Sequence[BaseChatMessage],
        cancellation_token: Optional[CancellationToken] = None,
        message_limit: int = 1,
        sleep_interval: float = 0.5,
    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
        """
        Process incoming messages and stream the Azure AI agent's response.

        The full interaction is:

        1. Post the input messages to the thread.
        2. Create a run and poll it.
        3. Execute requested function tool calls locally and submit their outputs.
        4. Retrieve the agent's final message.

        Tool-call request/execution events are yielded as they happen; the last item yielded is
        a ``Response`` carrying the agent message and those inner events.

        Args:
            messages (Sequence[ChatMessage]): Messages to process.
            cancellation_token (CancellationToken): Token used to cancel the operation.
            message_limit (int, optional): Maximum number of messages retrieved from the thread.
            sleep_interval (float, optional): Interval between polls of the run status.

        Yields:
            AgentEvent | ChatMessage | Response: Events during processing, then the final response.

        Raises:
            ValueError: If the run fails, is cancelled or expires, or no assistant message is received.
        """
        if cancellation_token is None:
            cancellation_token = CancellationToken()

        await self._ensure_initialized()

        # Post every incoming message to the thread, in order.
        for message in messages:
            if isinstance(message, (TextMessage, MultiModalMessage)):
                # NOTE(review): for MultiModalMessage this stringifies the whole content list.
                await self.handle_text_message(str(message.content), cancellation_token)
            elif isinstance(message, (StopMessage, HandoffMessage)):
                await self.handle_text_message(message.content, cancellation_token)

        # Tool-call request/execution events produced while the run is in progress.
        inner_messages: List[AgentEvent | ChatMessage] = []

        run: models.ThreadRun = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._project_client.agents.create_run(
                    thread_id=self.thread_id,
                    agent_id=self._get_agent_id,
                )
            )
        )

        # Poll until the run completes. Any other terminal status is an error; without the
        # CANCELLED/EXPIRED check such a run would be polled forever.
        while True:
            run = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._project_client.agents.get_run(
                        thread_id=self.thread_id,
                        run_id=run.id,
                    )
                )
            )

            if run.status == models.RunStatus.FAILED:
                raise ValueError(f"Run failed: {run.last_error}")

            if run.status in (models.RunStatus.CANCELLED, models.RunStatus.EXPIRED):
                raise ValueError(f"Run did not complete, final status: {run.status}")

            # The run needs function-call results: execute the tools locally and submit outputs.
            if run.status == models.RunStatus.REQUIRES_ACTION and run.required_action is not None:
                tool_calls: List[FunctionCall] = []
                submit_tool_outputs = getattr(run.required_action, "submit_tool_outputs", None)
                if submit_tool_outputs and hasattr(submit_tool_outputs, "tool_calls"):
                    for required_tool_call in submit_tool_outputs.tool_calls:
                        if required_tool_call.type == "function":
                            tool_calls.append(
                                FunctionCall(
                                    id=required_tool_call.id,
                                    name=required_tool_call.function.name,
                                    arguments=required_tool_call.function.arguments,
                                )
                            )

                tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
                inner_messages.append(tool_call_msg)
                trace_logger.debug(tool_call_msg)
                yield tool_call_msg

                tool_outputs: List[FunctionExecutionResult] = []

                # TODO: Support parallel execution of tool calls
                for tool_call in tool_calls:
                    # Tool errors are reported back to the model rather than aborting the run.
                    try:
                        result = await self._execute_tool_call(tool_call, cancellation_token)
                        is_error = False
                    except Exception as e:
                        result = f"Error: {e}"
                        is_error = True
                    tool_outputs.append(
                        FunctionExecutionResult(
                            content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
                        )
                    )

                tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
                inner_messages.append(tool_result_msg)
                trace_logger.debug(tool_result_msg)
                yield tool_result_msg

                run = await cancellation_token.link_future(
                    asyncio.ensure_future(
                        self._project_client.agents.submit_tool_outputs_to_run(
                            thread_id=self.thread_id,
                            run_id=run.id,
                            tool_outputs=[
                                models.ToolOutput(tool_call_id=t.call_id, output=t.content) for t in tool_outputs
                            ],
                        )
                    )
                )
                # Re-poll immediately after submitting outputs.
                continue

            if run.status == models.RunStatus.COMPLETED:
                break

            await asyncio.sleep(sleep_interval)

        trace_logger.debug("Retrieving messages from thread")
        agent_messages: models.OpenAIPageableListOfThreadMessage = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._project_client.agents.list_messages(
                    thread_id=self.thread_id, order=models.ListSortOrder.DESCENDING, limit=message_limit
                )
            )
        )

        if not agent_messages.data:
            raise ValueError("No messages received from assistant")

        last_message: Optional[models.ThreadMessage] = agent_messages.get_last_message_by_role(models.MessageRole.AGENT)

        if not last_message:
            trace_logger.debug("No message with AGENT role found, falling back to first message")
            last_message = agent_messages.data[0]  # Fallback to first message (newest, given DESCENDING order)

        if not last_message.content:
            raise ValueError("No content in the last message")

        message_text = "".join(text_message.text.value for text_message in last_message.text_messages)

        citations = self._extract_citations(last_message)
        trace_logger.debug(f"Total citations extracted: {len(citations)}")

        # Citations travel as a JSON string in the message metadata.
        chat_message = TextMessage(
            source=self.name, content=message_text, metadata={"citations": json.dumps(citations)} if citations else {}
        )

        yield Response(chat_message=chat_message, inner_messages=inner_messages)

    @staticmethod
    def _extract_citations(message: "models.ThreadMessage") -> List[Dict[str, Any]]:
        """
        Collect citation dictionaries from an agent message.

        If the message exposes a non-empty ``annotations`` list, URL citations are taken from it.
        Otherwise, both ``url_citation_annotations`` and ``file_citation_annotations`` are read
        (for backwards compatibility); both are collected so neither kind hides the other.

        Returns:
            List[Dict[str, Any]]: Entries with keys "url"/"title"/"text" for URL citations and
            "file_id"/"title"/"text" for file citations.
        """
        citations: List[Dict[str, Any]] = []

        annotations = getattr(message, "annotations", [])
        if isinstance(annotations, list) and annotations:
            annotations = cast(List[models.MessageTextUrlCitationAnnotation], annotations)
            trace_logger.debug(f"Found {len(annotations)} annotations")
            for annotation in annotations:
                if hasattr(annotation, "url_citation"):  # type: ignore
                    trace_logger.debug(f"Citation found: {annotation.url_citation.url}")
                    citations.append(
                        {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None}  # type: ignore
                    )
            return citations

        # For backwards compatibility
        url_annotations = getattr(message, "url_citation_annotations", None)
        if url_annotations:
            url_annotations = cast(List[Any], url_annotations)
            trace_logger.debug(f"Found {len(url_annotations)} URL citations")
            for annotation in url_annotations:
                citations.append(
                    {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None}  # type: ignore
                )

        file_annotations = getattr(message, "file_citation_annotations", None)
        if file_annotations:
            file_annotations = cast(List[Any], file_annotations)
            trace_logger.debug(f"Found {len(file_annotations)} file citations")
            for annotation in file_annotations:
                citations.append(
                    {"file_id": annotation.file_citation.file_id, "title": None, "text": annotation.file_citation.quote}  # type: ignore
                )

        return citations

    async def handle_text_message(self, content: str, cancellation_token: Optional[CancellationToken] = None) -> None:
        """
        Add a text message to the conversation thread as a USER message.

        Args:
            content (str): The message text.
            cancellation_token (CancellationToken): Token used to cancel processing.

        Returns:
            None
        """
        if cancellation_token is None:
            cancellation_token = CancellationToken()

        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._project_client.agents.create_message(
                    thread_id=self.thread_id,
                    content=content,
                    role=models.MessageRole.USER,
                )
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """
        Reset the agent's conversation by creating a new thread.

        The agent definition and its capabilities are kept; only the thread is replaced.

        Note: the Azure AI Agent API does not currently support deleting messages, so a new
        thread is created instead.

        Args:
            cancellation_token (CancellationToken): Token used to cancel processing (currently unused).
        """
        # Force creation of a new thread.
        await self._ensure_initialized(create_new_thread=True)

    async def save_state(self) -> Mapping[str, Any]:
        """
        Save the agent's current state for later restoration.

        The state includes the agent ID, thread ID, initial message IDs, the vector store ID and
        the IDs of uploaded files.

        Returns:
            Mapping[str, Any]: The serialized state.
        """
        state = AzureAIAgentState(
            agent_id=self._agent.id if self._agent else self._agent_id,
            thread_id=self._thread.id if self._thread else self._init_thread_id,
            initial_message_ids=list(self._initial_message_ids),
            vector_store_id=self._vector_store_id,
            uploaded_file_ids=self._uploaded_file_ids,
        )
        return state.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """
        Load a previously saved state into this agent.

        Only IDs are restored; the agent and thread objects are fetched lazily on next use.

        Args:
            state (Mapping[str, Any]): A state dictionary produced by ``save_state``.
        """
        agent_state = AzureAIAgentState.model_validate(state)
        self._agent_id = agent_state.agent_id
        self._init_thread_id = agent_state.thread_id
        self._initial_message_ids = set(agent_state.initial_message_ids)
        self._vector_store_id = agent_state.vector_store_id
        self._uploaded_file_ids = agent_state.uploaded_file_ids

    async def on_upload_for_code_interpreter(
        self,
        file_paths: str | Iterable[str],
        cancellation_token: Optional[CancellationToken] = None,
        sleep_interval: float = 0.5,
    ) -> None:
        """
        Upload files for use by the code interpreter tool.

        The files are uploaded, then the thread's code-interpreter tool resources are updated to
        include them in addition to any files already attached.

        Args:
            file_paths (str | Iterable[str]): A single path or multiple paths to upload.
            cancellation_token (Optional[CancellationToken]): Token used to cancel processing.
            sleep_interval (float): Interval between polls of the file processing status.

        Raises:
            ValueError: If a file upload fails.
        """
        if cancellation_token is None:
            cancellation_token = CancellationToken()

        await self._ensure_initialized()

        file_ids = await self._upload_files(
            file_paths=file_paths,
            cancellation_token=cancellation_token,
            sleep_interval=sleep_interval,
            purpose=models.FilePurpose.AGENTS,
        )

        thread: models.AgentThread = await cancellation_token.link_future(
            asyncio.ensure_future(self._project_client.agents.get_thread(thread_id=self.thread_id))
        )
        tool_resources: models.ToolResources = thread.tool_resources or models.ToolResources()
        code_interpreter_resource = tool_resources.code_interpreter or models.CodeInterpreterToolResource()
        # Build a new list rather than mutating the one owned by the fetched thread model.
        all_file_ids: List[str] = [*(code_interpreter_resource.file_ids or []), *file_ids]

        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._project_client.agents.update_thread(
                    thread_id=self.thread_id,
                    tool_resources=models.ToolResources(
                        code_interpreter=models.CodeInterpreterToolResource(file_ids=all_file_ids)
                    ),
                )
            )
        )

    async def on_upload_for_file_search(
        self,
        file_paths: str | Iterable[str],
        cancellation_token: CancellationToken,
        vector_store_name: Optional[str] = None,
        data_sources: Optional[List[models.VectorStoreDataSource]] = None,
        expires_after: Optional[models.VectorStoreExpirationPolicy] = None,
        chunking_strategy: Optional[models.VectorStoreChunkingStrategyRequest] = None,
        vector_store_metadata: Optional[Dict[str, str]] = None,
        vector_store_polling_sleep_interval: float = 1,
    ) -> None:
        """
        Upload files for use by the file search tool.

        On first call a vector store is created and the agent is updated to use it; files are then
        uploaded and added to the vector store as one batch.

        Args:
            file_paths (str | Iterable[str]): A single path or multiple paths to upload.
            cancellation_token (CancellationToken): Token used to cancel processing.
            vector_store_name (Optional[str]): Name for a newly created vector store.
            data_sources (Optional[List[models.VectorStoreDataSource]]): Additional vector store data sources.
            expires_after (Optional[models.VectorStoreExpirationPolicy]): Expiration policy for the vector store.
            chunking_strategy (Optional[models.VectorStoreChunkingStrategyRequest]): File chunking strategy.
            vector_store_metadata (Optional[Dict[str, str]]): Additional metadata for the vector store.
            vector_store_polling_sleep_interval (float): Interval between polls of the vector store status.

        Raises:
            ValueError: If file search is not enabled for this agent or a file upload fails.
        """
        await self._ensure_initialized()

        if not any(tool.get("type") == "file_search" for tool in self._api_tools):
            raise ValueError(
                "File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
            )

        if self._vector_store_id is None:
            vector_store: models.VectorStore = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._project_client.agents.create_vector_store_and_poll(
                        file_ids=[],
                        name=vector_store_name,
                        data_sources=data_sources,
                        expires_after=expires_after,
                        chunking_strategy=chunking_strategy,
                        metadata=vector_store_metadata,
                        sleep_interval=vector_store_polling_sleep_interval,
                    )
                )
            )
            self._vector_store_id = vector_store.id

            # Point the agent's file_search tool at the new vector store.
            await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._project_client.agents.update_agent(
                        agent_id=self._get_agent_id,
                        tools=self._api_tools,
                        tool_resources=models.ToolResources(
                            file_search=models.FileSearchToolResource(vector_store_ids=[self._vector_store_id])
                        ),
                    )
                )
            )

        file_ids = await self._upload_files(
            file_paths=file_paths, cancellation_token=cancellation_token, purpose=models.FilePurpose.AGENTS
        )

        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._project_client.agents.create_vector_store_file_batch_and_poll(
                    vector_store_id=self._vector_store_id, file_ids=file_ids
                )
            )
        )
if __name__ == "__main__":
    # No runnable example here; see the examples in the AzureAIAgent class docstring.
    # TODO:
    # [X] Support for file upload
    # [ ] Support for sharepoint grounding
    # [ ] Support for azure function grounding
    # [X] Support for file search
    # [X] Support for custom function calling
    # [X] Add metadata to the thread (agent_id, source="AUTOGEN_AGENT")
    pass