# Source code for autogen_agentchat.teams._group_chat._swarm_group_chat

import asyncio
from typing import Any, Callable, List, Mapping, Sequence

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel

from ...base import ChatAgent, TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
from ...state import SwarmManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination


class SwarmGroupChatManager(BaseGroupChatManager):
    """A group chat manager that selects the next speaker solely based on handoff messages.

    The first participant is the initial speaker. Whenever a
    :class:`~autogen_agentchat.messages.HandoffMessage` appears in the thread, its
    ``target`` becomes the current speaker; otherwise the current speaker keeps the turn.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        emit_team_events: bool,
    ) -> None:
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        # The first participant speaks until a handoff message says otherwise.
        self._current_speaker = self._participant_names[0]

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """Validate the start messages and the existing thread before a run.

        Only the most recent handoff decides the next speaker (see
        :meth:`select_speaker`), so only that handoff is checked. It is taken from
        the new start messages if they contain one; otherwise from the existing
        message thread.

        Args:
            messages: The start messages for this run, or ``None``.

        Raises:
            ValueError: If the relevant handoff targets a name that is not a participant.
        """
        # If the start messages contain a handoff, the *latest* one decides the next
        # speaker, so validate that one (checking the first could let an invalid
        # later handoff slip through to select_speaker).
        if messages:
            for message in reversed(messages):
                if isinstance(message, HandoffMessage):
                    if message.target not in self._participant_names:
                        raise ValueError(
                            f"The target {message.target} is not one of the participants {self._participant_names}. "
                            "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
                        )
                    return

        # Otherwise check the latest handoff already in the thread.
        for existing_message in reversed(self._message_thread):
            if isinstance(existing_message, HandoffMessage):
                if existing_message.target not in self._participant_names:
                    raise ValueError(
                        f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
                        "If you are resuming Swarm with a new task make sure to include in your task "
                        "a HandoffMessage with a valid participant as the target. For example, if you are "
                        "resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
                        "with a valid participant as the target."
                    )
                # The latest handoff message should always target a valid participant.
                # Do not look past the latest handoff message.
                return

    async def reset(self) -> None:
        """Reset the manager to its initial state.

        Zeroes the turn counter, clears the message thread, resets the termination
        condition (if any), and makes the first participant the current speaker again.
        """
        self._current_turn = 0
        self._message_thread.clear()
        if self._termination_condition is not None:
            await self._termination_condition.reset()
        self._current_speaker = self._participant_names[0]

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Select the next speaker from the participants based on handoff messages.

        Scans ``thread`` from newest to oldest for the latest handoff message; its
        target becomes the current speaker. If the thread is empty or contains no
        handoff, the current speaker keeps the turn.

        .. note::

            This method always returns a single speaker, wrapped in a list.

        Args:
            thread: The message thread to inspect.

        Returns:
            A one-element list with the name of the next speaker.

        Raises:
            ValueError: If the latest handoff targets a name that is not a participant.
        """
        for message in reversed(thread):
            if isinstance(message, HandoffMessage):
                # Validate before mutating state, and with an explicit check rather
                # than `assert`, which is stripped when Python runs with -O.
                if message.target not in self._participant_names:
                    raise ValueError(
                        f"The handoff target {message.target} is not one of the participants {self._participant_names}."
                    )
                self._current_speaker = message.target
                break
        # Return the same shape on every path.
        return [self._current_speaker]

    async def save_state(self) -> Mapping[str, Any]:
        """Return the serialized message thread, current turn and current speaker as a dict."""
        state = SwarmManagerState(
            message_thread=[msg.dump() for msg in self._message_thread],
            current_turn=self._current_turn,
            current_speaker=self._current_speaker,
        )
        return state.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the message thread, current turn and current speaker from a saved state dict."""
        swarm_state = SwarmManagerState.model_validate(state)
        # Rebuild concrete message objects from their dumped form.
        self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread]
        self._current_turn = swarm_state.current_turn
        self._current_speaker = swarm_state.current_speaker


class SwarmConfig(BaseModel):
    """Declarative configuration for :class:`Swarm`."""

    # Serialized participant components, in order; the first one is the initial speaker.
    participants: List[ComponentModel]
    # Serialized termination condition component, or None for no termination condition.
    termination_condition: ComponentModel | None = None
    # Maximum number of turns before the group chat stops; None means no limit.
    max_turns: int | None = None
    # Whether team events are emitted through run_stream.
    emit_team_events: bool = False


class Swarm(BaseGroupChat, Component[SwarmConfig]):
    """A group chat team that selects the next speaker based on handoff messages only.

    The first participant in the list of participants is the initial speaker.
    The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage`
    message sent by the current speaker. If no handoff message is sent, the current
    speaker continues to be the speaker.

    Args:
        participants (List[ChatAgent]): The agents participating in the group chat.
            The first agent in the list is the initial speaker.
        termination_condition (TerminationCondition, optional): The termination condition
            for the group chat. Defaults to None. Without a termination condition, the
            group chat will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before
            stopping. Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): The agent runtime, passed through to
            :class:`BaseGroupChat`. Defaults to None.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list
            of custom message types that will be used in the group chat. If you are using
            custom message types or your agents produce custom message types, you need to
            specify them here. Make sure your custom message types are subclasses of
            :class:`~autogen_agentchat.messages.BaseAgentEvent` or
            :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through
            :meth:`BaseGroupChat.run_stream`. Defaults to False.

    Raises:
        ValueError: If the first participant cannot produce
            :class:`~autogen_agentchat.messages.HandoffMessage` messages.

    Basic example:

    .. code-block:: python

        import asyncio

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.agents import AssistantAgent
        from autogen_agentchat.teams import Swarm
        from autogen_agentchat.conditions import MaxMessageTermination


        async def main() -> None:
            model_client = OpenAIChatCompletionClient(model="gpt-4o")

            agent1 = AssistantAgent(
                "Alice",
                model_client=model_client,
                handoffs=["Bob"],
                system_message="You are Alice and you only answer questions about yourself.",
            )
            agent2 = AssistantAgent(
                "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
            )

            termination = MaxMessageTermination(3)
            team = Swarm([agent1, agent2], termination_condition=termination)

            stream = team.run_stream(task="What is bob's birthday?")
            async for message in stream:
                print(message)


        asyncio.run(main())

    Using the :class:`~autogen_agentchat.conditions.HandoffTermination` for human-in-the-loop handoff:

    .. code-block:: python

        import asyncio

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_agentchat.agents import AssistantAgent
        from autogen_agentchat.teams import Swarm
        from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination
        from autogen_agentchat.ui import Console
        from autogen_agentchat.messages import HandoffMessage


        async def main() -> None:
            model_client = OpenAIChatCompletionClient(model="gpt-4o")

            agent = AssistantAgent(
                "Alice",
                model_client=model_client,
                handoffs=["user"],
                system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.",
            )
            termination = HandoffTermination(target="user") | MaxMessageTermination(3)
            team = Swarm([agent], termination_condition=termination)

            # Start the conversation.
            await Console(team.run_stream(task="What is bob's birthday?"))

            # Resume with user feedback.
            await Console(
                team.run_stream(
                    task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.")
                )
            )


        asyncio.run(main())
    """

    component_config_schema = SwarmConfig
    component_provider_override = "autogen_agentchat.teams.Swarm"

    # TODO: Add * to the constructor to separate the positional parameters from the kwargs.
    # This may be a breaking change so let's wait until a good time to do it.
    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ) -> None:
        super().__init__(
            participants,
            group_chat_manager_name="SwarmGroupChatManager",
            group_chat_manager_class=SwarmGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )
        # The first participant must be able to produce handoff messages.
        first_participant = self._participants[0]
        if HandoffMessage not in first_participant.produced_message_types:
            raise ValueError("The first participant must be able to produce a handoff messages.")

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], SwarmGroupChatManager]:
        """Return a zero-argument factory that builds a :class:`SwarmGroupChatManager`.

        The factory captures the given arguments plus this team's ``emit_team_events`` setting.
        """

        def _factory() -> SwarmGroupChatManager:
            return SwarmGroupChatManager(
                name,
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_names,
                participant_descriptions,
                output_message_queue,
                termination_condition,
                max_turns,
                message_factory,
                self._emit_team_events,
            )

        return _factory

    def _to_config(self) -> SwarmConfig:
        """Serialize this team into a :class:`SwarmConfig`.

        ``runtime`` and ``custom_message_types`` are not part of the config.
        """
        participants = [participant.dump_component() for participant in self._participants]
        # Compare against None explicitly rather than relying on object truthiness.
        termination_condition = (
            self._termination_condition.dump_component() if self._termination_condition is not None else None
        )
        return SwarmConfig(
            participants=participants,
            termination_condition=termination_condition,
            max_turns=self._max_turns,
            emit_team_events=self._emit_team_events,
        )

    @classmethod
    def _from_config(cls, config: SwarmConfig) -> "Swarm":
        """Build a :class:`Swarm` from a :class:`SwarmConfig` produced by :meth:`_to_config`."""
        participants = [ChatAgent.load_component(participant) for participant in config.participants]
        termination_condition = (
            TerminationCondition.load_component(config.termination_condition)
            if config.termination_condition is not None
            else None
        )
        return cls(
            participants,
            termination_condition=termination_condition,
            max_turns=config.max_turns,
            emit_team_events=config.emit_team_events,
        )