import asyncio
from typing import Any, Callable, List, Mapping, Sequence
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self
from ...base import ChatAgent, TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
class RoundRobinGroupChatManager(BaseGroupChatManager):
    """Group chat manager that hands the floor to participants in a fixed rotation.

    The rotation position is kept in ``_next_speaker_index`` and is part of the
    saved/loaded state, so a restored manager continues the rotation where the
    saved one left off.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        emit_team_events: bool,
    ) -> None:
        # All arguments are forwarded unchanged, positionally and in the same order,
        # to the base manager.
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        # Index into ``self._participant_names`` of the participant who speaks next.
        self._next_speaker_index = 0

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """Accept any starting state; round-robin selection imposes no constraints."""
        return None

    async def reset(self) -> None:
        """Return the manager to its initial state.

        Empties the message thread, zeroes the turn counter, resets the
        termination condition (when one is configured) and restarts the
        rotation from index 0.
        """
        # ``_message_thread`` / ``_current_turn`` / ``_termination_condition`` are
        # presumably initialised by BaseGroupChatManager.__init__ — not visible here.
        self._message_thread.clear()
        self._current_turn = 0
        termination = self._termination_condition
        if termination is not None:
            await termination.reset()
        self._next_speaker_index = 0

    async def save_state(self) -> Mapping[str, Any]:
        """Snapshot the message thread, turn counter and rotation index as a mapping."""
        dumped_thread = [entry.dump() for entry in self._message_thread]
        snapshot = RoundRobinManagerState(
            message_thread=dumped_thread,
            current_turn=self._current_turn,
            next_speaker_index=self._next_speaker_index,
        )
        return snapshot.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore a mapping produced by :meth:`save_state`.

        Messages are rebuilt through the configured message factory.

        NOTE(review): the restored ``next_speaker_index`` is not range-checked
        against the current participant list; a state saved from a larger team
        would make :meth:`select_speaker` raise ``IndexError``.
        """
        restored = RoundRobinManagerState.model_validate(state)
        self._message_thread = list(map(self._message_factory.create, restored.message_thread))
        self._current_turn = restored.current_turn
        self._next_speaker_index = restored.next_speaker_index

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Pick the next participant in round-robin order.

        .. note::

            This method always returns a single speaker.
        """
        chosen = self._next_speaker_index
        participant_count = len(self._participant_names)
        # Advance the rotation before the lookup, wrapping back to 0 after the last participant.
        self._next_speaker_index = (chosen + 1) % participant_count
        return self._participant_names[chosen]
class RoundRobinGroupChatConfig(BaseModel):
    """Declarative (serializable) configuration for a round-robin group chat team."""

    # Serialized component models of the team's participants.
    participants: list[ComponentModel]
    # Serialized termination condition, or None when the team has none.
    termination_condition: ComponentModel | None = None
    # Maximum number of turns before the chat stops; None means no limit.
    max_turns: int | None = None
    # Whether the team emits team events while running.
    emit_team_events: bool = False
class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
    """A team that runs a group chat in which participants take turns, in a
    round-robin fashion, publishing a message to everyone.

    If the team has only a single participant, that participant is the only speaker.

    Args:
        participants (List[ChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
            Without a termination condition, the group chat will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): The agent runtime, passed through unchanged to :class:`BaseGroupChat`. Defaults to None.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types used in the group chat.
            If you use custom message types, or your agents produce custom message types, you need to specify them here.
            Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.

    Raises:
        ValueError: If no participants are provided or if participant names are not unique.

    Examples:

    A team with one participant that uses tools:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                async def get_weather(location: str) -> str:
                    return f"The weather in {location} is sunny."

                assistant = AssistantAgent(
                    "Assistant",
                    model_client=model_client,
                    tools=[get_weather],
                )
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([assistant], termination_condition=termination)
                await Console(team.run_stream(task="What's the weather in New York?"))


            asyncio.run(main())

    A team with multiple participants:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
                await Console(team.run_stream(task="Tell me some jokes."))


            asyncio.run(main())
    """

    component_config_schema = RoundRobinGroupChatConfig
    component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"

    # TODO: Add * to the constructor to separate the positional parameters from the kwargs.
    # This may be a breaking change so let's wait until a good time to do it.
    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ) -> None:
        # Everything is delegated to BaseGroupChat; this class only selects the
        # manager type (RoundRobinGroupChatManager) and its name.
        super().__init__(
            participants,
            group_chat_manager_name="RoundRobinGroupChatManager",
            group_chat_manager_class=RoundRobinGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], RoundRobinGroupChatManager]:
        """Return a zero-argument callable that builds a :class:`RoundRobinGroupChatManager`.

        All arguments are captured by the closure and forwarded to the manager.
        ``self._emit_team_events`` is read when the factory is *called*, not when
        it is created. That attribute is presumably set by BaseGroupChat from the
        ``emit_team_events`` constructor argument — not visible here.
        """

        def _factory() -> RoundRobinGroupChatManager:
            return RoundRobinGroupChatManager(
                name,
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_names,
                participant_descriptions,
                output_message_queue,
                termination_condition,
                max_turns,
                message_factory,
                self._emit_team_events,
            )

        return _factory

    def _to_config(self) -> RoundRobinGroupChatConfig:
        """Serialize this team into a :class:`RoundRobinGroupChatConfig`.

        Participants and the termination condition are dumped as component models.
        ``runtime`` and ``custom_message_types`` are not part of the config, so they
        are not restored by :meth:`_from_config`.
        """
        participants = [participant.dump_component() for participant in self._participants]
        termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
        return RoundRobinGroupChatConfig(
            participants=participants,
            termination_condition=termination_condition,
            max_turns=self._max_turns,
            emit_team_events=self._emit_team_events,
        )

    @classmethod
    def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
        """Build a team from a :class:`RoundRobinGroupChatConfig`.

        Participants and the termination condition are re-loaded from their
        component models. ``runtime`` and ``custom_message_types`` fall back to
        the constructor defaults.
        """
        participants = [ChatAgent.load_component(participant) for participant in config.participants]
        termination_condition = (
            TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
        )
        return cls(
            participants,
            termination_condition=termination_condition,
            max_turns=config.max_turns,
            emit_team_events=config.emit_team_events,
        )