import asyncio
from typing import Any, Callable, List, Mapping, Sequence
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from ...base import ChatAgent, TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
from ...state import SwarmManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
class SwarmGroupChatManager(BaseGroupChatManager):
    """A group chat manager that selects the next speaker solely based on handoff messages.

    The first participant in ``participant_names`` is the initial speaker. The
    speaker only changes when a :class:`HandoffMessage` appears in the message
    thread; its ``target`` becomes the next speaker.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        emit_team_events: bool,
    ) -> None:
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        # The first participant always starts the conversation.
        self._current_speaker = self._participant_names[0]

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """Validate the start messages (and, if needed, the existing thread) for the group chat.

        ``select_speaker`` routes to the target of the latest handoff message, and it
        asserts that the target is a participant. This method rejects invalid targets
        up front so that the assert never fires for user-supplied input.

        Args:
            messages: The new start messages for this run, or ``None``.

        Raises:
            ValueError: If any handoff message in ``messages`` targets a name that is
                not a participant. If ``messages`` contains no handoff message, a
                ValueError is also raised when the latest handoff message already in
                the thread targets a non-participant.
        """
        # Check every handoff among the new start messages, not just the first one:
        # the LAST handoff is the one select_speaker acts on, so stopping at the first
        # would let an invalid trailing handoff through.
        new_handoffs = [message for message in messages or [] if isinstance(message, HandoffMessage)]
        for message in new_handoffs:
            if message.target not in self._participant_names:
                raise ValueError(
                    f"The target {message.target} is not one of the participants {self._participant_names}. "
                    "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
                )
        if new_handoffs:
            # The new handoffs override whatever is in the existing thread.
            return

        # Check if there is a handoff message in the thread that is not targeting a valid participant.
        for existing_message in reversed(self._message_thread):
            if isinstance(existing_message, HandoffMessage):
                if existing_message.target not in self._participant_names:
                    raise ValueError(
                        f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
                        "If you are resuming Swarm with a new task make sure to include in your task "
                        "a HandoffMessage with a valid participant as the target. For example, if you are "
                        "resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
                        "with a valid participant as the target."
                    )
                # The latest handoff message should always target a valid participant.
                # Do not look past the latest handoff message.
                return

    async def reset(self) -> None:
        """Reset the manager to its initial state.

        Clears the turn counter and message thread, resets the termination condition
        (if any), and makes the first participant the current speaker again.
        """
        self._current_turn = 0
        self._message_thread.clear()
        if self._termination_condition is not None:
            await self._termination_condition.reset()
        self._current_speaker = self._participant_names[0]

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Select a speaker from the participants based on handoff messages.

        Looks for the last handoff message in the thread to determine the next
        speaker. If there is none (including an empty thread), the current speaker
        keeps the turn.

        .. note::

            This method always returns a single speaker, wrapped in a one-element list.
        """
        for message in reversed(thread):
            if isinstance(message, HandoffMessage):
                self._current_speaker = message.target
                # The latest handoff message should always target a valid participant;
                # validate_group_state guarantees this for start messages.
                assert self._current_speaker in self._participant_names, (
                    f"Handoff target {self._current_speaker} is not one of the participants {self._participant_names}."
                )
                break
        # Return a list on every path; the original code returned a bare str on the
        # "no handoff found" path.
        return [self._current_speaker]

    async def save_state(self) -> Mapping[str, Any]:
        """Serialize the message thread, turn counter and current speaker as a ``SwarmManagerState`` dict."""
        state = SwarmManagerState(
            message_thread=[msg.dump() for msg in self._message_thread],
            current_turn=self._current_turn,
            current_speaker=self._current_speaker,
        )
        return state.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore state produced by :meth:`save_state`.

        Messages are rebuilt through the message factory so that custom message
        types are reconstructed correctly.
        """
        swarm_state = SwarmManagerState.model_validate(state)
        self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread]
        self._current_turn = swarm_state.current_turn
        self._current_speaker = swarm_state.current_speaker
class SwarmConfig(BaseModel):
    """Declarative configuration for :class:`Swarm`.

    Produced by ``Swarm._to_config`` and consumed by ``Swarm._from_config``.
    """

    # Serialized component models of the participating agents, in the order they
    # were passed to Swarm (the first entry is the initial speaker).
    participants: List[ComponentModel]
    # Serialized termination condition, or None when the team has none.
    termination_condition: ComponentModel | None = None
    # Maximum number of turns before the team stops; None means no limit.
    max_turns: int | None = None
    # Whether the team emits team events through run_stream.
    emit_team_events: bool = False
class Swarm(BaseGroupChat, Component[SwarmConfig]):
    """A group chat team that selects the next speaker based only on handoff messages.

    The first participant in the participant list is the initial speaker.
    The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage`
    sent by the current speaker. If no handoff message is sent, the current speaker
    continues to speak.

    Args:
        participants (List[ChatAgent]): The agents participating in the group chat. The first
            agent in the list is the initial speaker, and it must be able to produce
            :class:`~autogen_agentchat.messages.HandoffMessage` messages.
        termination_condition (TerminationCondition, optional): The termination condition for
            the group chat. Defaults to None. Without a termination condition, the group chat
            will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before
            stopping. Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): The agent runtime to use; passed through unchanged
            to :class:`BaseGroupChat`. Defaults to None.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list
            of custom message types used in the group chat. If you use custom message types,
            or your agents produce custom message types, you need to specify them here. Make
            sure your custom message types are subclasses of
            :class:`~autogen_agentchat.messages.BaseAgentEvent` or
            :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through
            :meth:`BaseGroupChat.run_stream`. Defaults to False.

    Raises:
        ValueError: If the first participant cannot produce
            :class:`~autogen_agentchat.messages.HandoffMessage` messages.

    Basic example:

    .. code-block:: python

        import asyncio

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.agents import AssistantAgent
        from autogen_agentchat.teams import Swarm
        from autogen_agentchat.conditions import MaxMessageTermination


        async def main() -> None:
            model_client = OpenAIChatCompletionClient(model="gpt-4o")

            agent1 = AssistantAgent(
                "Alice",
                model_client=model_client,
                handoffs=["Bob"],
                system_message="You are Alice and you only answer questions about yourself.",
            )
            agent2 = AssistantAgent(
                "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
            )

            termination = MaxMessageTermination(3)
            team = Swarm([agent1, agent2], termination_condition=termination)

            stream = team.run_stream(task="What is bob's birthday?")
            async for message in stream:
                print(message)


        asyncio.run(main())

    Using :class:`~autogen_agentchat.conditions.HandoffTermination` for a human-in-the-loop handoff:

    .. code-block:: python

        import asyncio

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.agents import AssistantAgent
        from autogen_agentchat.teams import Swarm
        from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination
        from autogen_agentchat.ui import Console
        from autogen_agentchat.messages import HandoffMessage


        async def main() -> None:
            model_client = OpenAIChatCompletionClient(model="gpt-4o")

            agent = AssistantAgent(
                "Alice",
                model_client=model_client,
                handoffs=["user"],
                system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.",
            )
            termination = HandoffTermination(target="user") | MaxMessageTermination(3)
            team = Swarm([agent], termination_condition=termination)

            # Start the conversation.
            await Console(team.run_stream(task="What is bob's birthday?"))

            # Resume with the user's feedback.
            await Console(
                team.run_stream(
                    task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.")
                )
            )


        asyncio.run(main())
    """

    component_config_schema = SwarmConfig
    component_provider_override = "autogen_agentchat.teams.Swarm"

    # TODO: Add * to the constructor to separate the positional parameters from the kwargs.
    # This may be a breaking change so let's wait until a good time to do it.
    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ) -> None:
        super().__init__(
            participants,
            group_chat_manager_name="SwarmGroupChatManager",
            group_chat_manager_class=SwarmGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )
        # The first participant is the initial speaker, so it must be able to produce
        # handoff messages; otherwise control could never pass to anyone else.
        first_participant = self._participants[0]
        if HandoffMessage not in first_participant.produced_message_types:
            raise ValueError("The first participant must be able to produce a handoff messages.")

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], SwarmGroupChatManager]:
        """Return a zero-argument factory that builds a :class:`SwarmGroupChatManager`.

        All arguments are captured by the closure and forwarded positionally, together
        with this team's ``emit_team_events`` setting.
        """

        def _factory() -> SwarmGroupChatManager:
            return SwarmGroupChatManager(
                name,
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_names,
                participant_descriptions,
                output_message_queue,
                termination_condition,
                max_turns,
                message_factory,
                self._emit_team_events,
            )

        return _factory

    def _to_config(self) -> SwarmConfig:
        """Dump the participants and termination condition as component models into a :class:`SwarmConfig`."""
        participants = [participant.dump_component() for participant in self._participants]
        termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
        return SwarmConfig(
            participants=participants,
            termination_condition=termination_condition,
            max_turns=self._max_turns,
            emit_team_events=self._emit_team_events,
        )

    @classmethod
    def _from_config(cls, config: SwarmConfig) -> "Swarm":
        """Rebuild a :class:`Swarm` by loading each component model stored in ``config``."""
        participants = [ChatAgent.load_component(participant) for participant in config.participants]
        termination_condition = (
            TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
        )
        return cls(
            participants,
            termination_condition=termination_condition,
            max_turns=config.max_turns,
            emit_team_events=config.emit_team_events,
        )