from __future__ import annotations
import asyncio
import inspect
import json
import logging
import sys
import uuid
import warnings
from asyncio import CancelledError, Future, Queue, Task
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from opentelemetry.trace import TracerProvider
from .logging import (
AgentConstructionExceptionEvent,
DeliveryStage,
MessageDroppedEvent,
MessageEvent,
MessageHandlerExceptionEvent,
MessageKind,
)
if sys.version_info >= (3, 13):
from asyncio import Queue, QueueShutDown
else:
from ._queue import Queue, QueueShutDown # type: ignore
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._intervention import DropMessage, InterventionHandler
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._runtime_impl_helpers import SubscriptionManager, get_impl
from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry
from ._subscription import Subscription
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
from ._topic import TopicId
from .exceptions import MessageDroppedException
logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
# We use a type parameter in some functions which shadows the built-in `type` function.
# This is a workaround to avoid shadowing the built-in `type` function.
type_func_alias = type
@dataclass(kw_only=True)
class PublishMessageEnvelope:
"""用于向所有能处理类型 T 消息的代理发布消息的消息信封。"""
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
topic_id: TopicId
metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True)
class SendMessageEnvelope:
"""用于向特定能处理类型 T 消息的代理发送消息的消息信封。"""
message: Any
sender: AgentId | None
recipient: AgentId
future: Future[Any]
cancellation_token: CancellationToken
metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
"""用于发送消息回复的消息信封。"""
message: Any
future: Future[Any]
sender: AgentId
recipient: AgentId | None
metadata: EnvelopeMetadata | None = None
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
class RunContext:
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_task = asyncio.create_task(self._run())
self._stopped = asyncio.Event()
async def _run(self) -> None:
while True:
if self._stopped.is_set():
return
await self._runtime._process_next() # type: ignore
async def stop(self) -> None:
self._stopped.set()
self._runtime._message_queue.shutdown(immediate=True) # type: ignore
await self._run_task
async def stop_when_idle(self) -> None:
await self._runtime._message_queue.join() # type: ignore
self._stopped.set()
self._runtime._message_queue.shutdown(immediate=True) # type: ignore
await self._run_task
async def stop_when(self, condition: Callable[[], bool], check_period: float = 1.0) -> None:
async def check_condition() -> None:
while not condition():
await asyncio.sleep(check_period)
await self.stop()
await asyncio.create_task(check_condition())
def _warn_if_none(value: Any, handler_name: str) -> None:
"""
用于检查干预处理程序是否返回 None 并发出警告的实用函数。
Args:
value: 要检查的返回值
handler_name: 警告消息中使用的干预处理方法名称
"""
if value is None:
warnings.warn(
f"Intervention handler {handler_name} returned None. This might be unintentional. "
"Consider returning the original message or DropMessage explicitly.",
RuntimeWarning,
stacklevel=2,
)
[文档]
class SingleThreadedAgentRuntime(AgentRuntime):
"""一个单线程代理运行时,使用单个 asyncio 队列处理所有消息。
消息按接收顺序传递,运行时在单独的 asyncio 任务中并发处理每条消息。
.. note::
此运行时适用于开发和独立应用程序。
不适合高吞吐量或高并发场景。
Args:
intervention_handlers (List[InterventionHandler], optional): 可以在消息发送或发布前拦截的干预处理程序列表。默认为 None。
tracer_provider (TracerProvider, optional): 用于追踪的追踪器提供者。默认为 None。
ignore_unhandled_exceptions (bool, optional): 是否忽略代理事件处理程序中发生的未处理异常。任何后台异常将在下次调用 `process_next` 或等待 `stop`、`stop_when_idle` 或 `stop_when` 时抛出。注意,这不适用于 RPC 处理程序。默认为 True。
Examples:
创建运行时、注册代理、发送消息并停止运行时的简单示例:
.. code-block:: python
import asyncio
from dataclasses import dataclass
from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler
@dataclass
class MyMessage:
content: str
class MyAgent(RoutedAgent):
@message_handler
async def handle_my_message(self, message: MyMessage, ctx: MessageContext) -> None:
print(f"Received message: {message.content}")
async def main() -> None:
# Create a runtime and register the agent
runtime = SingleThreadedAgentRuntime()
await MyAgent.register(runtime, "my_agent", lambda: MyAgent("My agent"))
# Start the runtime, send a message and stop the runtime
runtime.start()
await runtime.send_message(MyMessage("Hello, world!"), recipient=AgentId("my_agent", "default"))
await runtime.stop()
asyncio.run(main())
创建运行时、注册代理、发布消息并停止运行时的示例:
.. code-block:: python
import asyncio
from dataclasses import dataclass
from autogen_core import (
DefaultTopicId,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
default_subscription,
message_handler,
)
@dataclass
class MyMessage:
content: str
# The agent is subscribed to the default topic.
@default_subscription
class MyAgent(RoutedAgent):
@message_handler
async def handle_my_message(self, message: MyMessage, ctx: MessageContext) -> None:
print(f"Received message: {message.content}")
async def main() -> None:
# Create a runtime and register the agent
runtime = SingleThreadedAgentRuntime()
await MyAgent.register(runtime, "my_agent", lambda: MyAgent("My agent"))
# Start the runtime.
runtime.start()
# Publish a message to the default topic that the agent is subscribed to.
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
# Wait for the message to be processed and then stop the runtime.
await runtime.stop_when_idle()
asyncio.run(main())
"""
def __init__(
self,
*,
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
ignore_unhandled_exceptions: bool = True,
) -> None:
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue()
# (namespace, type) -> List[AgentId]
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._intervention_handlers = intervention_handlers
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
self._agent_instance_types: Dict[str, Type[Agent]] = {}
@property
def unprocessed_messages_count(
self,
) -> int:
return self._message_queue.qsize()
@property
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
async def _create_otel_attributes(
self,
sender_agent_id: AgentId | None = None,
recipient_agent_id: AgentId | None = None,
message_context: MessageContext | None = None,
message: Any = None,
) -> Mapping[str, str]:
"""为给定的代理和消息创建 OpenTelemetry 属性。
Args:
sender_agent (Agent, optional): 发送方代理实例。
recipient_agent (Agent, optional): 接收方代理实例。
message (Any): 消息实例。
Returns:
Attributes: OpenTelemetry 属性字典。
"""
if not sender_agent_id and not recipient_agent_id and not message:
return {}
attributes: Dict[str, str] = {}
if sender_agent_id:
sender_agent = await self._get_agent(sender_agent_id)
attributes["sender_agent_type"] = sender_agent.id.type
attributes["sender_agent_class"] = sender_agent.__class__.__name__
if recipient_agent_id:
recipient_agent = await self._get_agent(recipient_agent_id)
attributes["recipient_agent_type"] = recipient_agent.id.type
attributes["recipient_agent_class"] = recipient_agent.__class__.__name__
if message_context:
serialized_message_context = {
"sender": str(message_context.sender),
"topic_id": str(message_context.topic_id),
"is_rpc": message_context.is_rpc,
"message_id": message_context.message_id,
}
attributes["message_context"] = json.dumps(serialized_message_context)
if message:
try:
serialized_message = self._try_serialize(message)
except Exception as e:
serialized_message = str(e)
else:
serialized_message = "No Message"
attributes["message"] = serialized_message
return attributes
# Returns the response of the message
[文档]
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
if message_id is None:
message_id = str(uuid.uuid4())
event_logger.info(
MessageEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.SEND,
)
)
with self._tracer_helper.trace_block(
"create",
recipient,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
future = asyncio.get_event_loop().create_future()
if recipient.type not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))
return await future
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
await self._message_queue.put(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
)
)
cancellation_token.link_future(future)
return await future
[文档]
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> None:
with self._tracer_helper.trace_block(
"create",
topic_id,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
if cancellation_token is None:
cancellation_token = CancellationToken()
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")
if message_id is None:
message_id = str(uuid.uuid4())
event_logger.info(
MessageEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=topic_id,
kind=MessageKind.PUBLISH,
delivery_stage=DeliveryStage.SEND,
)
)
await self._message_queue.put(
PublishMessageEnvelope(
message=message,
cancellation_token=cancellation_token,
sender=sender,
topic_id=topic_id,
metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
)
)
[文档]
async def save_state(self) -> Mapping[str, Any]:
"""保存所有已实例化代理的状态。
此方法调用每个代理的 :meth:`~autogen_core.BaseAgent.save_state` 方法,并返回一个字典,
将代理ID映射到其状态。
.. note::
此方法当前不保存订阅状态。我们将在未来添加此功能。
Returns:
一个将代理ID映射到其状态的字典。
"""
state: Dict[str, Dict[str, Any]] = {}
for agent_id in self._instantiated_agents:
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
return state
[文档]
async def load_state(self, state: Mapping[str, Any]) -> None:
"""加载所有已实例化代理的状态。
此方法使用字典中提供的状态调用每个代理的 :meth:`~autogen_core.BaseAgent.load_state` 方法。
字典的键是代理ID,值是由 :meth:`~autogen_core.BaseAgent.save_state` 方法返回的状态字典。
.. note::
此方法当前不加载订阅状态。我们将在未来添加此功能。
"""
for agent_id_str in state:
agent_id = AgentId.from_str(agent_id_str)
if agent_id.type in self._known_agent_names:
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
recipient = message_envelope.recipient
if recipient.type not in self._known_agent_names:
raise LookupError(f"Agent type '{recipient.type}' does not exist.")
try:
sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}"
)
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=recipient,
kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.DELIVER,
)
)
recipient_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
with self._tracer_helper.trace_block(
"process",
recipient_agent.id,
parent=message_envelope.metadata,
attributes=await self._create_otel_attributes(
sender_agent_id=message_envelope.sender,
recipient_agent_id=recipient,
message_context=message_context,
message=message_envelope.message,
),
):
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
except CancelledError as e:
if not message_envelope.future.cancelled():
message_envelope.future.set_exception(e)
self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return
except BaseException as e:
message_envelope.future.set_exception(e)
self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return
event_logger.info(
MessageEvent(
payload=self._try_serialize(response),
sender=message_envelope.recipient,
receiver=message_envelope.sender,
kind=MessageKind.RESPOND,
delivery_stage=DeliveryStage.SEND,
)
)
await self._message_queue.put(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
metadata=get_telemetry_envelope_metadata(),
)
)
self._message_queue.task_done()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
try:
responses: List[Awaitable[Any]] = []
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
for agent_id in recipients:
# Avoid sending the message back to the sender
if message_envelope.sender is not None and agent_id == message_envelope.sender:
continue
sender_agent = (
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
)
sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown"
logger.info(
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
)
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=None,
kind=MessageKind.PUBLISH,
delivery_stage=DeliveryStage.DELIVER,
)
)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=message_envelope.topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
agent = await self._get_agent(agent_id)
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
with self._tracer_helper.trace_block(
"process",
agent.id,
parent=message_envelope.metadata,
attributes=await self._create_otel_attributes(
sender_agent_id=message_envelope.sender,
recipient_agent_id=agent.id,
message_context=message_context,
message=message_envelope.message,
),
):
with MessageHandlerContext.populate_context(agent.id):
try:
return await agent.on_message(
message_envelope.message,
ctx=message_context,
)
except BaseException as e:
logger.error(f"Error processing publish message for {agent.id}", exc_info=True)
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=agent.id,
exception=e,
)
)
raise e
future = _on_message(agent, message_context)
responses.append(future)
await asyncio.gather(*responses)
except BaseException as e:
if not self._ignore_unhandled_handler_exceptions:
self._background_exception = e
finally:
self._message_queue.task_done()
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
with self._tracer_helper.trace_block(
"ack",
message_envelope.recipient,
parent=message_envelope.metadata,
attributes=await self._create_otel_attributes(
sender_agent_id=message_envelope.sender,
recipient_agent_id=message_envelope.recipient,
message=message_envelope.message,
),
):
content = (
message_envelope.message.__dict__
if hasattr(message_envelope.message, "__dict__")
else message_envelope.message
)
logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
)
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=message_envelope.recipient,
kind=MessageKind.RESPOND,
delivery_stage=DeliveryStage.DELIVER,
)
)
if not message_envelope.future.cancelled():
message_envelope.future.set_result(message_envelope.message)
self._message_queue.task_done()
[文档]
async def process_next(self) -> None:
"""处理队列中的下一条消息。
如果后台任务中存在未处理的异常,将在此处抛出。在抛出未处理异常后,不能再次调用 `process_next`。
"""
await self._process_next()
async def _process_next(self) -> None:
"""处理队列中的下一条消息。"""
if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
self._message_queue.shutdown(immediate=True) # type: ignore
raise e
try:
message_envelope = await self._message_queue.get()
except QueueShutDown:
if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
raise e from None
return
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_send(
message, message_context=message_context, recipient=recipient
)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.DIRECT,
)
)
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
task = asyncio.create_task(self._process_send(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case PublishMessageEnvelope(
message=message,
sender=sender,
topic_id=topic_id,
):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_publish(message, message_context=message_context)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=topic_id,
kind=MessageKind.PUBLISH,
)
)
return
message_envelope.message = temp_message
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
try:
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_response")
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.RESPOND,
)
)
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
task = asyncio.create_task(self._process_response(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
[文档]
def start(self) -> None:
"""启动运行时消息处理循环。此操作在后台任务中运行。
示例:
.. code-block:: python
import asyncio
from autogen_core import SingleThreadedAgentRuntime
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()
# ... do other things ...
await runtime.stop()
asyncio.run(main())
"""
if self._run_context is not None:
raise RuntimeError("Runtime is already started")
self._run_context = RunContext(self)
[文档]
async def close(self) -> None:
"""如果适用则调用 :meth:`stop` 方法,并对所有已实例化的代理调用 :meth:`Agent.close` 方法"""
# stop the runtime if it hasn't been stopped yet
if self._run_context is not None:
await self.stop()
# close all the agents that have been instantiated
for agent_id in self._instantiated_agents:
agent = await self._get_agent(agent_id)
await agent.close()
[文档]
async def stop(self) -> None:
"""立即停止运行时消息处理循环。当前正在处理的消息将会完成,但之后的所有消息都将被丢弃。"""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
try:
await self._run_context.stop()
finally:
self._run_context = None
self._message_queue = Queue()
[文档]
async def stop_when_idle(self) -> None:
"""当没有正在处理或排队中的消息时,停止运行时消息处理循环。这是停止运行时最常用的方式。"""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
try:
await self._run_context.stop_when_idle()
finally:
self._run_context = None
self._message_queue = Queue()
[文档]
async def stop_when(self, condition: Callable[[], bool]) -> None:
"""当满足条件时停止运行时消息处理循环。
.. 警告::
不建议使用此方法,此处仅为向后兼容保留。
该方法会生成一个繁忙循环来持续检查条件。
调用`stop_when_idle`或`stop`会更加高效。
如果需要基于条件停止运行时,建议使用后台任务和asyncio.Event来
在条件满足时发出信号,然后由后台任务调用stop方法。
"""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when(condition)
self._run_context = None
self._message_queue = Queue()
[文档]
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return await (await self._get_agent(agent)).save_state()
[文档]
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
await (await self._get_agent(agent)).load_state(state)
[文档]
async def register_factory(
self,
type: str | AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
*,
expected_class: type[T] | None = None,
) -> AgentType:
if isinstance(type, str):
type = AgentType(type)
if type.type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
agent_instance = await maybe_agent_instance
else:
agent_instance = maybe_agent_instance
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance
self._agent_factories[type.type] = factory_wrapper
return type
[文档]
async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> Agent:
raise RuntimeError(
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
)
if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")
if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
with AgentInstantiationContext.populate_context((self, agent_id)):
try:
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
warnings.warn(
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
agent = cast(T, await agent)
return agent
except BaseException as e:
event_logger.info(
AgentConstructionExceptionEvent(
agent_id=agent_id,
exception=e,
)
)
logger.error(f"Error constructing agent {agent_id}", exc_info=True)
raise
async def _get_agent(self, agent_id: AgentId) -> Agent:
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if agent_id.type not in self._agent_factories:
raise LookupError(f"Agent with name {agent_id.type} not found.")
agent_factory = self._agent_factories[agent_id.type]
agent = await self._invoke_agent_factory(agent_factory, agent_id)
self._instantiated_agents[agent_id] = agent
return agent
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
[文档]
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
if id.type not in self._agent_factories:
raise LookupError(f"Agent with name {id.type} not found.")
# TODO: check if remote
agent_instance = await self._get_agent(id)
if not isinstance(agent_instance, type):
raise TypeError(
f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}"
)
return agent_instance
[文档]
async def add_subscription(self, subscription: Subscription) -> None:
await self._subscription_manager.add_subscription(subscription)
[文档]
async def remove_subscription(self, id: str) -> None:
await self._subscription_manager.remove_subscription(id)
[文档]
async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
) -> AgentId:
return await get_impl(
id_or_type=id_or_type,
key=key,
lazy=lazy,
instance_getter=self._get_agent,
)
[文档]
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)
def _try_serialize(self, message: Any) -> str:
try:
type_name = self._serialization_registry.type_name(message)
return self._serialization_registry.serialize(
message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE
).decode("utf-8")
except ValueError:
return "Message could not be serialized"