# Source code for autogen_agentchat.teams._group_chat._round_robin_group_chat

import asyncio
from typing import Any, Callable, List, Mapping, Sequence

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ...base import ChatAgent, TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination


class RoundRobinGroupChatManager(BaseGroupChatManager):
    """Group chat manager that picks the next speaker in round-robin order.

    On top of whatever the base manager keeps, this class tracks
    ``_next_speaker_index``: the position in ``_participant_names`` of the
    participant that will be returned by the next :meth:`select_speaker` call.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        emit_team_events: bool,
    ) -> None:
        """Pass every argument through to the base manager and start the rotation at index 0."""
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        # The rotation begins with the first participant.
        self._next_speaker_index = 0

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """Do nothing: this manager performs no checks on ``messages``."""
        return None

    async def reset(self) -> None:
        """Return the manager to its initial state.

        Zeroes the turn counter, empties the message thread, resets the
        termination condition when one is set, and rewinds the rotation to the
        first participant.
        """
        self._current_turn = 0
        self._message_thread.clear()
        condition = self._termination_condition
        if condition is not None:
            await condition.reset()
        # Rewound last, i.e. only after the termination condition reset succeeded.
        self._next_speaker_index = 0

    async def save_state(self) -> Mapping[str, Any]:
        """Dump the message thread, the turn counter and the rotation index.

        Returns:
            The ``model_dump()`` of a :class:`RoundRobinManagerState` built from
            the current values.
        """
        dumped_thread = [entry.dump() for entry in self._message_thread]
        snapshot = RoundRobinManagerState(
            message_thread=dumped_thread,
            current_turn=self._current_turn,
            next_speaker_index=self._next_speaker_index,
        )
        return snapshot.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the manager from a mapping produced by :meth:`save_state`.

        Args:
            state: Mapping validated as a :class:`RoundRobinManagerState`.
        """
        restored = RoundRobinManagerState.model_validate(state)
        # Each stored entry is turned back into a message object by the message factory.
        self._message_thread = [self._message_factory.create(raw) for raw in restored.message_thread]
        self._current_turn = restored.current_turn
        self._next_speaker_index = restored.next_speaker_index

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Select the next speaker from the participants in round-robin order.

        ``thread`` is not consulted; the choice depends only on the rotation index.

        .. note::

            This method always returns a single speaker.
        """
        names = self._participant_names
        chosen_index = self._next_speaker_index
        # Advance the rotation (wrapping at the end) before handing out the current name.
        self._next_speaker_index = (chosen_index + 1) % len(names)
        return names[chosen_index]


class RoundRobinGroupChatConfig(BaseModel):
    """Declarative configuration for a round-robin group chat."""

    # Component models of the participants (required; the only field without a default).
    participants: List[ComponentModel]
    # Component model of the termination condition; None when no condition is configured.
    termination_condition: ComponentModel | None = None
    # Turn limit for the chat; defaults to None.
    max_turns: int | None = None
    # Flag controlling emission of team events; off by default.
    emit_team_events: bool = False


class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
    """A team that runs a group chat with participants taking turns in round-robin order.

    Each participant, on its turn, publishes a message to all participants.
    If the team contains a single participant, that participant is the only speaker.

    Args:
        participants (List[BaseChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): The termination condition for the group chat.
            Defaults to None. Without a termination condition, the group chat runs indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before stopping.
            Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): Passed through unchanged to the base class constructor.
            Defaults to None.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message
            types used in the group chat. If you use custom message types, or your agents produce custom
            message types, you need to specify them here. Make sure your custom message types are subclasses of
            :class:`~autogen_agentchat.messages.BaseAgentEvent` or
            :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through
            :meth:`BaseGroupChat.run_stream`. Defaults to False.

    Raises:
        ValueError: If no participants are provided or if participant names are not unique.

    Examples:

    A team with one participant that has tools:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                async def get_weather(location: str) -> str:
                    return f"The weather in {location} is sunny."

                assistant = AssistantAgent(
                    "Assistant",
                    model_client=model_client,
                    tools=[get_weather],
                )
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([assistant], termination_condition=termination)
                await Console(team.run_stream(task="What's the weather in New York?"))


            asyncio.run(main())

    A team with multiple participants:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
                await Console(team.run_stream(task="Tell me some jokes."))


            asyncio.run(main())
    """

    component_config_schema = RoundRobinGroupChatConfig
    component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"

    # TODO: Add * to the constructor to separate the positional parameters from the kwargs.
    # This may be a breaking change so let's wait until a good time to do it.
    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ) -> None:
        """Initialise the base group chat with the round-robin manager class and name."""
        super().__init__(
            participants,
            group_chat_manager_name="RoundRobinGroupChatManager",
            group_chat_manager_class=RoundRobinGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], RoundRobinGroupChatManager]:
        """Return a zero-argument factory that builds a :class:`RoundRobinGroupChatManager`.

        The factory closes over the arguments given here; ``emit_team_events`` is
        read from ``self._emit_team_events`` at the time the factory is called.
        """

        def _factory() -> RoundRobinGroupChatManager:
            return RoundRobinGroupChatManager(
                name,
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_names,
                participant_descriptions,
                output_message_queue,
                termination_condition,
                max_turns,
                message_factory,
                self._emit_team_events,
            )

        return _factory

    def _to_config(self) -> RoundRobinGroupChatConfig:
        """Dump this team into its declarative :class:`RoundRobinGroupChatConfig`."""
        participants = [participant.dump_component() for participant in self._participants]
        # Explicit None check: the attribute is optional, and its truthiness is not the question.
        termination_condition = (
            self._termination_condition.dump_component() if self._termination_condition is not None else None
        )
        return RoundRobinGroupChatConfig(
            participants=participants,
            termination_condition=termination_condition,
            max_turns=self._max_turns,
            emit_team_events=self._emit_team_events,
        )

    @classmethod
    def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
        """Build a team from a declarative :class:`RoundRobinGroupChatConfig`.

        ``runtime`` and ``custom_message_types`` are not part of the config, so
        the constructor defaults are used for them.
        """
        participants = [ChatAgent.load_component(participant) for participant in config.participants]
        # Explicit None check, mirroring _to_config.
        termination_condition = (
            TerminationCondition.load_component(config.termination_condition)
            if config.termination_condition is not None
            else None
        )
        return cls(
            participants,
            termination_condition=termination_condition,
            max_turns=config.max_turns,
            emit_team_events=config.emit_team_events,
        )