autogen_core._routed_agent 源代码

import logging
from functools import wraps
from typing import (
    Any,
    Callable,
    Coroutine,
    DefaultDict,
    List,
    Literal,
    Protocol,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    cast,
    get_type_hints,
    overload,
    runtime_checkable,
)

from ._base_agent import BaseAgent
from ._message_context import MessageContext
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
from ._type_helpers import AnyType, get_types
from .exceptions import CantHandleException

logger = logging.getLogger("autogen_core")

AgentT = TypeVar("AgentT")
ReceivesT = TypeVar("ReceivesT")
ProducesT = TypeVar("ProducesT", covariant=True)

# TODO: Generic typevar bound binding U to agent type
# Can't do because python doesnt support it


# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
# Revisit this later to see if we can remove the ignore.
@runtime_checkable
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]):  # type: ignore
    target_types: Sequence[type]
    produces_types: Sequence[type]
    is_message_handler: Literal[True]
    router: Callable[[ReceivesT, MessageContext], bool]

    # agent_instance binds to self in the method
    @staticmethod
    async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...


# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocol for the outer function to check checked arg names


@overload
def message_handler(
    func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...


@overload
def message_handler(
    func: None = None,
    *,
    match: None = ...,
    strict: bool = ...,
) -> Callable[
    [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
    MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...


@overload
def message_handler(
    func: None = None,
    *,
    match: Callable[[ReceivesT, MessageContext], bool],
    strict: bool = ...,
) -> Callable[
    [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
    MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...


[文档] def message_handler( func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None, *, strict: bool = True, match: None | Callable[[ReceivesT, MessageContext], bool] = None, ) -> ( Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[AgentT, ReceivesT, ProducesT], ] | MessageHandler[AgentT, ReceivesT, ProducesT] ): """通用消息处理器的装饰器。 将此装饰器添加到:class:`RoutedAgent`类中用于处理事件和RPC消息的方法上。 这些方法必须遵循特定的签名规范才能生效: - 方法必须是`async`异步方法 - 方法必须使用`@message_handler`装饰器装饰 - 方法必须恰好有3个参数: 1. `self` 2. `message`: 要处理的消息,必须使用其目标处理的消息类型进行类型提示 3. `ctx`: 一个:class:`autogen_core.MessageContext`对象 - 方法必须类型提示其可能返回的响应消息类型,如果不返回任何内容则可以返回`None` 处理器可以通过接受消息类型的Union来处理多种消息类型。也可以通过返回消息类型的Union来返回多种消息类型。 Args: func: 要被装饰的函数 strict: 如果为`True`,当消息类型或返回类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告 match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。 """ def decorator( func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], ) -> MessageHandler[AgentT, ReceivesT, ProducesT]: type_hints = get_type_hints(func) if "message" not in type_hints: raise AssertionError("message parameter not found in function signature") if "return" not in type_hints: raise AssertionError("return parameter not found in function signature") # Get the type of the message parameter target_types = get_types(type_hints["message"]) if target_types is None: raise AssertionError("Message type not found") # print(type_hints) return_types = get_types(type_hints["return"]) if return_types is None: raise AssertionError("Return type not found") # Convert target_types to list and stash @wraps(func) async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") return_value = await func(self, message, ctx) if AnyType not in return_types and type(return_value) not in return_types: if strict: raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") return return_value wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True wrapper_handler.router = match or (lambda _message, _ctx: True) return wrapper_handler if func is None and not callable(func): return decorator elif callable(func): return decorator(func) else: raise ValueError("Invalid arguments")
@overload def event( func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]], ) -> MessageHandler[AgentT, ReceivesT, None]: ... @overload def event( func: None = None, *, match: None = ..., strict: bool = ..., ) -> Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], MessageHandler[AgentT, ReceivesT, None], ]: ... @overload def event( func: None = None, *, match: Callable[[ReceivesT, MessageContext], bool], strict: bool = ..., ) -> Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], MessageHandler[AgentT, ReceivesT, None], ]: ...
[文档] def event( func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None, *, strict: bool = True, match: None | Callable[[ReceivesT, MessageContext], bool] = None, ) -> ( Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], MessageHandler[AgentT, ReceivesT, None], ] | MessageHandler[AgentT, ReceivesT, None] ): """事件消息处理器的装饰器。 将此装饰器添加到:class:`RoutedAgent`类中用于处理事件消息的方法上。 这些方法必须遵循特定的签名规范才能生效: - 方法必须是`async`异步方法 - 方法必须使用`@message_handler`装饰器装饰 - 方法必须恰好有3个参数: 1. `self` 2. `message`: 要处理的事件消息,必须使用其目标处理的消息类型进行类型提示 3. `ctx`: 一个:class:`autogen_core.MessageContext`对象 - 方法必须返回`None` 处理器可以通过接受消息类型的Union来处理多种消息类型。 Args: func: 要被装饰的函数 strict: 如果为`True`,当消息类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告 match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。 """ def decorator( func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]], ) -> MessageHandler[AgentT, ReceivesT, None]: type_hints = get_type_hints(func) if "message" not in type_hints: raise AssertionError("message parameter not found in function signature") if "return" not in type_hints: raise AssertionError("return parameter not found in function signature") # Get the type of the message parameter target_types = get_types(type_hints["message"]) if target_types is None: raise AssertionError("Message type not found. Please provide a type hint for the message parameter.") return_types = get_types(type_hints["return"]) if return_types is None: raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.") # Convert target_types to list and stash @wraps(func) async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") return_value = await func(self, message, ctx) # type: ignore if return_value is not None: if strict: raise ValueError(f"Return type {type(return_value)} is not None.") else: logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.") return None wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True # Wrap the match function with a check on the is_rpc flag. wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True) return wrapper_handler if func is None and not callable(func): return decorator elif callable(func): return decorator(func) else: raise ValueError("Invalid arguments")
@overload def rpc( func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], ) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ... @overload def rpc( func: None = None, *, match: None = ..., strict: bool = ..., ) -> Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[AgentT, ReceivesT, ProducesT], ]: ... @overload def rpc( func: None = None, *, match: Callable[[ReceivesT, MessageContext], bool], strict: bool = ..., ) -> Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[AgentT, ReceivesT, ProducesT], ]: ...
[文档] def rpc( func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None, *, strict: bool = True, match: None | Callable[[ReceivesT, MessageContext], bool] = None, ) -> ( Callable[ [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[AgentT, ReceivesT, ProducesT], ] | MessageHandler[AgentT, ReceivesT, ProducesT] ): """RPC消息处理器的装饰器。 将此装饰器添加到:class:`RoutedAgent`类中用于处理RPC消息的方法上。 这些方法必须遵循特定的签名规范才能生效: - 方法必须是`async`异步方法 - 方法必须使用`@message_handler`装饰器装饰 - 方法必须恰好有3个参数: 1. `self` 2. `message`: 要处理的消息,必须使用其目标处理的消息类型进行类型提示 3. `ctx`: 一个:class:`autogen_core.MessageContext`对象 - 方法必须类型提示其可能返回的响应消息类型,如果不返回任何内容则可以返回`None` 处理器可以通过接受消息类型的Union来处理多种消息类型。也可以通过返回消息类型的Union来返回多种消息类型。 Args: func: 要被装饰的函数 strict: 如果为`True`,当消息类型或返回类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告 match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。 """ def decorator( func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], ) -> MessageHandler[AgentT, ReceivesT, ProducesT]: type_hints = get_type_hints(func) if "message" not in type_hints: raise AssertionError("message parameter not found in function signature") if "return" not in type_hints: raise AssertionError("return parameter not found in function signature") # Get the type of the message parameter target_types = get_types(type_hints["message"]) if target_types is None: raise AssertionError("Message type not found") # print(type_hints) return_types = get_types(type_hints["return"]) if return_types is None: raise AssertionError("Return type not found") # Convert target_types to list and stash @wraps(func) async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") return_value = await func(self, message, ctx) if AnyType not in return_types and type(return_value) not in return_types: if strict: raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") return return_value wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True) return wrapper_handler if func is None and not callable(func): return decorator elif callable(func): return decorator(func) else: raise ValueError("Invalid arguments")
[文档] class RoutedAgent(BaseAgent): """一个基类,用于根据消息类型和可选匹配函数将消息路由到相应的处理程序。 要创建路由代理,请继承此类并添加消息处理方法,这些方法需使用 :func:`event` 或 :func:`rpc` 装饰器进行装饰。 示例: .. code-block:: python from dataclasses import dataclass from autogen_core import MessageContext from autogen_core import RoutedAgent, event, rpc @dataclass class Message: pass @dataclass class MessageWithContent: content: str @dataclass class Response: pass class MyAgent(RoutedAgent): def __init__(self): super().__init__("MyAgent") @event async def handle_event_message(self, message: Message, ctx: MessageContext) -> None: assert ctx.topic_id is not None await self.publish_message(MessageWithContent("event handled"), ctx.topic_id) @rpc(match=lambda message, ctx: message.content == "special") # type: ignore async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response: return Response() """ def __init__(self, description: str) -> None: # Self is already bound to the handlers self._handlers: DefaultDict[ Type[Any], List[MessageHandler[RoutedAgent, Any, Any]], ] = DefaultDict(list) handlers = self._discover_handlers() for message_handler in handlers: for target_type in message_handler.target_types: self._handlers[target_type].append(message_handler) super().__init__(description)
[文档] async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: """通过将消息路由到适当的消息处理程序来处理消息。 不要在子类中重写此方法。相反,应添加使用 :func:`event` 或 :func:`rpc` 装饰器装饰的消息处理方法。""" key_type: Type[Any] = type(message) # type: ignore handlers = self._handlers.get(key_type) # type: ignore if handlers is not None: # Iterate over all handlers for this matching message type. # Call the first handler whose router returns True and then return the result. for h in handlers: if h.router(message, ctx): return await h(self, message, ctx) return await self.on_unhandled_message(message, ctx) # type: ignore
[文档] async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: """当接收到没有匹配消息处理程序的消息时调用。 默认实现会记录一条信息日志。""" logger.info(f"Unhandled message: {message}")
@classmethod def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]: handlers: List[MessageHandler[Any, Any, Any]] = [] for attr in dir(cls): if callable(getattr(cls, attr, None)): # Since we are getting it from the class, self is not bound handler = getattr(cls, attr) if hasattr(handler, "is_message_handler"): handlers.append(cast(MessageHandler[Any, Any, Any], handler)) return handlers @classmethod def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]: # TODO handle deduplication handlers = cls._discover_handlers() types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = [] types.extend(cls.internal_extra_handles_types) for handler in handlers: for t in handler.target_types: # TODO: support different serializers serializers = try_get_known_serializers_for_type(t) if len(serializers) == 0: raise ValueError(f"No serializers found for type {t}.") types.append((t, try_get_known_serializers_for_type(t))) return types