import logging
from functools import wraps
from typing import (
Any,
Callable,
Coroutine,
DefaultDict,
List,
Literal,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
cast,
get_type_hints,
overload,
runtime_checkable,
)
from ._base_agent import BaseAgent
from ._message_context import MessageContext
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
from ._type_helpers import AnyType, get_types
from .exceptions import CantHandleException
logger = logging.getLogger("autogen_core")
AgentT = TypeVar("AgentT")
ReceivesT = TypeVar("ReceivesT")
ProducesT = TypeVar("ProducesT", covariant=True)
# TODO: Generic typevar bound binding U to agent type
# Can't do because python doesnt support it
# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
# Revisit this later to see if we can remove the ignore.
@runtime_checkable
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore
target_types: Sequence[type]
produces_types: Sequence[type]
is_message_handler: Literal[True]
router: Callable[[ReceivesT, MessageContext], bool]
# agent_instance binds to self in the method
@staticmethod
async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocol for the outer function to check checked arg names
@overload
def message_handler(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
@overload
def message_handler(
func: None = None,
*,
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
@overload
def message_handler(
func: None = None,
*,
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
[文档]
def message_handler(
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]
| MessageHandler[AgentT, ReceivesT, ProducesT]
):
"""通用消息处理器的装饰器。
将此装饰器添加到:class:`RoutedAgent`类中用于处理事件和RPC消息的方法上。
这些方法必须遵循特定的签名规范才能生效:
- 方法必须是`async`异步方法
- 方法必须使用`@message_handler`装饰器装饰
- 方法必须恰好有3个参数:
1. `self`
2. `message`: 要处理的消息,必须使用其目标处理的消息类型进行类型提示
3. `ctx`: 一个:class:`autogen_core.MessageContext`对象
- 方法必须类型提示其可能返回的响应消息类型,如果不返回任何内容则可以返回`None`
处理器可以通过接受消息类型的Union来处理多种消息类型。也可以通过返回消息类型的Union来返回多种消息类型。
Args:
func: 要被装饰的函数
strict: 如果为`True`,当消息类型或返回类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告
match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。
"""
def decorator(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return parameter not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")
# print(type_hints)
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, ctx)
if AnyType not in return_types and type(return_value) not in return_types:
if strict:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
else:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
wrapper_handler.router = match or (lambda _message, _ctx: True)
return wrapper_handler
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
@overload
def event(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[AgentT, ReceivesT, None]: ...
@overload
def event(
func: None = None,
*,
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]: ...
@overload
def event(
func: None = None,
*,
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]: ...
[文档]
def event(
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]
| MessageHandler[AgentT, ReceivesT, None]
):
"""事件消息处理器的装饰器。
将此装饰器添加到:class:`RoutedAgent`类中用于处理事件消息的方法上。
这些方法必须遵循特定的签名规范才能生效:
- 方法必须是`async`异步方法
- 方法必须使用`@message_handler`装饰器装饰
- 方法必须恰好有3个参数:
1. `self`
2. `message`: 要处理的事件消息,必须使用其目标处理的消息类型进行类型提示
3. `ctx`: 一个:class:`autogen_core.MessageContext`对象
- 方法必须返回`None`
处理器可以通过接受消息类型的Union来处理多种消息类型。
Args:
func: 要被装饰的函数
strict: 如果为`True`,当消息类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告
match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。
"""
def decorator(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[AgentT, ReceivesT, None]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return parameter not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found. Please provide a type hint for the message parameter.")
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, ctx) # type: ignore
if return_value is not None:
if strict:
raise ValueError(f"Return type {type(return_value)} is not None.")
else:
logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.")
return None
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
# Wrap the match function with a check on the is_rpc flag.
wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True)
return wrapper_handler
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
@overload
def rpc(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
@overload
def rpc(
func: None = None,
*,
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
@overload
def rpc(
func: None = None,
*,
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
[文档]
def rpc(
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]
| MessageHandler[AgentT, ReceivesT, ProducesT]
):
"""RPC消息处理器的装饰器。
将此装饰器添加到:class:`RoutedAgent`类中用于处理RPC消息的方法上。
这些方法必须遵循特定的签名规范才能生效:
- 方法必须是`async`异步方法
- 方法必须使用`@message_handler`装饰器装饰
- 方法必须恰好有3个参数:
1. `self`
2. `message`: 要处理的消息,必须使用其目标处理的消息类型进行类型提示
3. `ctx`: 一个:class:`autogen_core.MessageContext`对象
- 方法必须类型提示其可能返回的响应消息类型,如果不返回任何内容则可以返回`None`
处理器可以通过接受消息类型的Union来处理多种消息类型。也可以通过返回消息类型的Union来返回多种消息类型。
Args:
func: 要被装饰的函数
strict: 如果为`True`,当消息类型或返回类型不在目标类型中时会抛出异常。如果为`False`,则只记录警告
match: 一个接收消息和上下文作为参数并返回布尔值的函数。用于消息类型之后的二级路由。对于处理相同消息类型的处理器,匹配函数会按照处理器名称的字母顺序应用,第一个匹配的处理器会被调用,其余跳过。如果为`None`,则按字母顺序调用第一个匹配该消息类型的处理器。
"""
def decorator(
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return parameter not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")
# print(type_hints)
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, ctx)
if AnyType not in return_types and type(return_value) not in return_types:
if strict:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
else:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True)
return wrapper_handler
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
[文档]
class RoutedAgent(BaseAgent):
"""一个基类,用于根据消息类型和可选匹配函数将消息路由到相应的处理程序。
要创建路由代理,请继承此类并添加消息处理方法,这些方法需使用 :func:`event` 或 :func:`rpc` 装饰器进行装饰。
示例:
.. code-block:: python
from dataclasses import dataclass
from autogen_core import MessageContext
from autogen_core import RoutedAgent, event, rpc
@dataclass
class Message:
pass
@dataclass
class MessageWithContent:
content: str
@dataclass
class Response:
pass
class MyAgent(RoutedAgent):
def __init__(self):
super().__init__("MyAgent")
@event
async def handle_event_message(self, message: Message, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(MessageWithContent("event handled"), ctx.topic_id)
@rpc(match=lambda message, ctx: message.content == "special") # type: ignore
async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
return Response()
"""
def __init__(self, description: str) -> None:
# Self is already bound to the handlers
self._handlers: DefaultDict[
Type[Any],
List[MessageHandler[RoutedAgent, Any, Any]],
] = DefaultDict(list)
handlers = self._discover_handlers()
for message_handler in handlers:
for target_type in message_handler.target_types:
self._handlers[target_type].append(message_handler)
super().__init__(description)
[文档]
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
"""通过将消息路由到适当的消息处理程序来处理消息。
不要在子类中重写此方法。相反,应添加使用 :func:`event` 或 :func:`rpc` 装饰器装饰的消息处理方法。"""
key_type: Type[Any] = type(message) # type: ignore
handlers = self._handlers.get(key_type) # type: ignore
if handlers is not None:
# Iterate over all handlers for this matching message type.
# Call the first handler whose router returns True and then return the result.
for h in handlers:
if h.router(message, ctx):
return await h(self, message, ctx)
return await self.on_unhandled_message(message, ctx) # type: ignore
[文档]
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
"""当接收到没有匹配消息处理程序的消息时调用。
默认实现会记录一条信息日志。"""
logger.info(f"Unhandled message: {message}")
@classmethod
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
handlers: List[MessageHandler[Any, Any, Any]] = []
for attr in dir(cls):
if callable(getattr(cls, attr, None)):
# Since we are getting it from the class, self is not bound
handler = getattr(cls, attr)
if hasattr(handler, "is_message_handler"):
handlers.append(cast(MessageHandler[Any, Any, Any], handler))
return handlers
@classmethod
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
# TODO handle deduplication
handlers = cls._discover_handlers()
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
types.extend(cls.internal_extra_handles_types)
for handler in handlers:
for t in handler.target_types:
# TODO: support different serializers
serializers = try_get_known_serializers_for_type(t)
if len(serializers) == 0:
raise ValueError(f"No serializers found for type {t}.")
types.append((t, try_get_known_serializers_for_type(t)))
return types