# Source code for autogen_agentchat.teams._group_chat._graph._digraph_group_chat

import asyncio
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set

from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import ChatAgent, OrTerminationCondition, Response, TerminationCondition
from autogen_agentchat.conditions import StopMessageTermination
from autogen_agentchat.messages import (
    BaseAgentEvent,
    BaseChatMessage,
    ChatMessage,
    MessageFactory,
    StopMessage,
    TextMessage,
)
from autogen_agentchat.state import BaseGroupChatManagerState
from autogen_agentchat.teams import BaseGroupChat

from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
from ..._group_chat._events import GroupChatTermination

# Name of the internal stop agent. GraphFlowManager.select_speaker returns this
# name when the ready queue is empty, and _StopAgent registers itself under it.
_DIGRAPH_STOP_AGENT_NAME = "DiGraphStopAgent"
# Content of the StopMessage that the stop agent replies with.
_DIGRAPH_STOP_AGENT_MESSAGE = "Digraph execution is complete"


class DiGraphEdge(BaseModel):
    """A directed edge in a :class:`DiGraph`, with an optional execution condition.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    # Name of the node this edge points to.
    target: str
    # Optional trigger string; ``None`` means the edge is unconditional.
    condition: str | None = None
    """(Experimental) Condition under which this edge is followed.

    If ``None``, the edge is unconditional. If a string, the edge is followed only
    when that string is contained in the last agent chat message.

    NOTE: This is an experimental feature and will change in future releases to
    allow a better specification of branching conditions, similar to the
    `TerminationCondition` class.
    """
class DiGraphNode(BaseModel):
    """A node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    # Name of the agent this node represents.
    name: str
    # Outgoing edges of this node.
    edges: List[DiGraphEdge] = []
    # "all": the node becomes ready once every parent has reported.
    # "any": the node becomes ready as soon as one parent has reported.
    activation: Literal["all", "any"] = "all"
class DiGraph(BaseModel):
    """A directed graph of nodes and edges.

    :class:`GraphFlow` uses this structure to determine execution order and conditions.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    nodes: Dict[str, DiGraphNode]  # Node name → DiGraphNode mapping
    default_start_node: str | None = None  # Default start node name
    _has_cycles: bool | None = None  # Cached result of the cycle check (None = not computed yet)

    def get_parents(self) -> Dict[str, List[str]]:
        """Compute a mapping from each node name to the names of its parent nodes.

        Raises:
            KeyError: If an edge targets a node that is not in ``nodes``
                (``graph_validate`` reports this case as a ``ValueError`` up front).
        """
        parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
        for node in self.nodes.values():
            for edge in node.edges:
                parents[edge.target].append(node.name)
        return parents

    def get_start_nodes(self) -> Set[str]:
        """Return the entry points of the graph.

        If ``default_start_node`` is set, it is the only start node. Otherwise the
        start nodes are all nodes without incoming edges.
        """
        if self.default_start_node:
            return {self.default_start_node}

        parents = self.get_parents()
        return {node_name for node_name, parent_list in parents.items() if not parent_list}

    def get_leaf_nodes(self) -> Set[str]:
        """Return the nodes that have no outgoing edges (final output nodes)."""
        return {name for name, node in self.nodes.items() if not node.edges}

    def has_cycles_with_exit(self) -> bool:
        """Check whether the graph has cycles and that each detected cycle has an exit condition.

        A depth-first search is run from every unvisited node. Whenever an edge points
        back to a node on the current DFS path, the nodes from that target up to the
        current node form a cycle. The cycle is accepted if any outgoing edge of any
        node on it carries a condition (all outgoing edges of the cycle's nodes are
        considered, not only the edges on the cycle itself).

        The traversal does not stop at the first cycle: every back edge reached by the
        DFS is checked, and the DFS path is always unwound, so one valid cycle can
        neither hide an invalid one nor leave stale entries that look like a cycle to
        a later DFS root.

        Returns:
            bool: True if at least one cycle exists and every detected cycle has an
            exit condition. False if there are no cycles.

        Raises:
            ValueError: If a cycle without any conditional edge is detected.
        """
        visited: Set[str] = set()
        rec_stack: Set[str] = set()
        path: List[str] = []

        def dfs(node_name: str) -> bool:
            visited.add(node_name)
            rec_stack.add(node_name)
            path.append(node_name)

            found_cycle = False
            for edge in self.nodes[node_name].edges:
                target = edge.target
                if target not in visited:
                    # Always recurse (no short-circuit) so that every branch is validated.
                    if dfs(target):
                        found_cycle = True
                elif target in rec_stack:
                    # Back edge: the cycle is the tail of the current path starting at target.
                    cycle_nodes = path[path.index(target) :]
                    has_exit = any(
                        cycle_edge.condition for name in cycle_nodes for cycle_edge in self.nodes[name].edges
                    )
                    if not has_exit:
                        raise ValueError(
                            f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
                        )
                    found_cycle = True

            # Unwind unconditionally so the path only ever holds the active DFS branch.
            rec_stack.remove(node_name)
            path.pop()
            return found_cycle

        has_cycle = False
        for node in self.nodes:
            if node not in visited and dfs(node):
                has_cycle = True

        return has_cycle

    def get_has_cycles(self) -> bool:
        """Return whether the graph has at least one cycle (with a valid exit condition).

        The result is computed on first use and cached in ``_has_cycles``.
        """
        if self._has_cycles is None:
            self._has_cycles = self.has_cycles_with_exit()

        return self._has_cycles

    def graph_validate(self) -> None:
        """Validate the graph structure and execution rules.

        Raises:
            ValueError: If the graph has no nodes, an edge or ``default_start_node``
                refers to an unknown node, there is no start node or no leaf node,
                a node mixes conditional and unconditional outgoing edges, or a
                cycle without an exit condition is detected.
        """
        if not self.nodes:
            raise ValueError("Graph has no nodes.")

        # Reject dangling references early; otherwise they surface as a bare KeyError
        # in get_parents()/has_cycles_with_exit() or as a failure at execution time.
        for node in self.nodes.values():
            for edge in node.edges:
                if edge.target not in self.nodes:
                    raise ValueError(f"Node '{node.name}' has an edge to unknown node '{edge.target}'.")
        if self.default_start_node and self.default_start_node not in self.nodes:
            raise ValueError(f"Default start node '{self.default_start_node}' is not a node in the graph.")

        if not self.get_start_nodes():
            raise ValueError("Graph must have at least one start node")

        if not self.get_leaf_nodes():
            raise ValueError("Graph must have at least one leaf node")

        # Outgoing edge condition validation (per node)
        for node in self.nodes.values():
            # If a node has a conditional outgoing edge, all of its outgoing edges must be conditional.
            has_condition = any(edge.condition for edge in node.edges)
            has_unconditioned = any(edge.condition is None for edge in node.edges)
            if has_condition and has_unconditioned:
                raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")

        # Runs the cycle check (raising on invalid cycles) and caches the result.
        self._has_cycles = self.has_cycles_with_exit()
class GraphFlowManagerState(BaseGroupChatManagerState):
    """Tracks active execution state for DAG-based execution."""

    active_nodes: List[str] = []  # Currently executing nodes
    type: str = "GraphManagerState"


class GraphFlowManager(BaseGroupChatManager):
    """Manages execution of agents using a directed graph execution model."""

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        graph: DiGraph,
    ) -> None:
        """Initialize the graph-based execution manager.

        All arguments except ``graph`` are forwarded unchanged to the base manager.

        Raises:
            ValueError: If graph validation fails, or if the graph is cyclic and
                neither a termination condition nor a maximum turn limit is set.
        """
        super().__init__(
            name=name,
            group_topic_type=group_topic_type,
            output_topic_type=output_topic_type,
            participant_topic_types=participant_topic_types,
            participant_names=participant_names,
            participant_descriptions=participant_descriptions,
            output_message_queue=output_message_queue,
            termination_condition=termination_condition,
            max_turns=max_turns,
            message_factory=message_factory,
        )
        graph.graph_validate()
        if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
            raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
        self._graph = graph
        # Lookup table for incoming edges for each node.
        self._parents = graph.get_parents()
        # Lookup table for outgoing edges for each node.
        self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
        # Activation lookup table for each node.
        self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()}

        # === Mutable states for the graph execution ===
        # Count the number of remaining parents to activate each node.
        self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()})
        # Lookup table for nodes that have been enqueued through an any activation.
        # This is used to prevent re-adding the same node multiple times.
        self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes}
        # Ready queue for nodes that are ready to execute, starting with the start nodes.
        self._ready: Deque[str] = deque(graph.get_start_nodes())

    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
        """Record new messages and propagate the completed node's result to its children.

        Only the last message is inspected. If its source is a graph node, each
        outgoing edge whose condition (if any) is contained in the message text
        updates the bookkeeping of its target and may enqueue that target.
        """
        await super().update_message_thread(messages)

        # Nothing to propagate for an empty update; also avoids indexing an empty sequence.
        if not messages:
            return

        # Find the node that ran in the current turn.
        message = messages[-1]
        if message.source not in self._graph.nodes:
            # Ignore messages from sources outside of the graph.
            return
        assert isinstance(message, BaseChatMessage)
        source = message.source
        content = message.to_model_text()

        # Propagate the update to the children of the node.
        for edge in self._edges[source]:
            if edge.condition and edge.condition not in content:
                continue
            if self._activation[edge.target] == "all":
                self._remaining[edge.target] -= 1
                if self._remaining[edge.target] == 0:
                    # If all parents are done, add to the ready queue.
                    self._ready.append(edge.target)
            else:
                # If activation is any, add to the ready queue if not already enqueued.
                if not self._enqueued_any[edge.target]:
                    self._ready.append(edge.target)
                    self._enqueued_any[edge.target] = True

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
        """Return every node currently in the ready queue, or the stop agent if none is ready."""
        # Drain the ready queue for the next set of speakers.
        speakers: List[str] = []
        while self._ready:
            speaker = self._ready.popleft()
            speakers.append(speaker)

            # Reset the bookkeeping for the nodes that were selected.
            if self._activation[speaker] == "any":
                self._enqueued_any[speaker] = False
            else:
                self._remaining[speaker] = len(self._parents[speaker])

        # If there are no speakers, trigger the stop agent.
        if not speakers:
            speakers = [_DIGRAPH_STOP_AGENT_NAME]

        return speakers

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """No additional group state validation is performed for graph execution."""
        pass

    async def save_state(self) -> Mapping[str, Any]:
        """Save the execution state as a plain mapping."""
        state = {
            "message_thread": [message.dump() for message in self._message_thread],
            "current_turn": self._current_turn,
            "remaining": dict(self._remaining),
            "enqueued_any": dict(self._enqueued_any),
            "ready": list(self._ready),
        }
        return state

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the execution state from a mapping produced by :meth:`save_state`."""
        self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
        self._current_turn = state["current_turn"]
        self._remaining = Counter(state["remaining"])
        # Copy instead of aliasing: later bookkeeping must not mutate the caller's saved state.
        self._enqueued_any = dict(state["enqueued_any"])
        self._ready = deque(state["ready"])

    async def reset(self) -> None:
        """Reset the execution state to the start of the graph."""
        self._current_turn = 0
        self._message_thread.clear()
        if self._termination_condition:
            await self._termination_condition.reset()
        self._remaining = Counter({n: len(p) for n, p in self._parents.items()})
        self._enqueued_any = {n: False for n in self._graph.nodes}
        self._ready = deque(self._graph.get_start_nodes())


class _StopAgent(BaseChatAgent):
    """Internal agent that replies with a StopMessage to end the GraphFlow."""

    def __init__(self) -> None:
        super().__init__(_DIGRAPH_STOP_AGENT_NAME, "Agent that terminates the GraphFlow.")

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage, StopMessage)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        # Always answer with the fixed stop message, regardless of the input.
        return Response(chat_message=StopMessage(content=_DIGRAPH_STOP_AGENT_MESSAGE, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass


class GraphFlowConfig(BaseModel):
    """The declarative configuration for GraphFlow."""

    participants: List[ComponentModel]
    termination_condition: ComponentModel | None = None
    max_turns: int | None = None
    graph: DiGraph  # The execution graph for agents
class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
    """A team that runs a group chat following a directed graph execution pattern.

    .. warning::

        This is an experimental feature, and the API will change in future releases.

    This group chat executes agents based on a directed graph (:class:`DiGraph`)
    structure, which allows complex workflows such as sequential execution, parallel
    fan-out, conditional branching, join patterns, and loops with explicit exit
    conditions.

    The execution order is determined by the edges defined in the `DiGraph`. Each node
    in the graph corresponds to an agent, and edges define the flow of messages between
    agents. Nodes can be configured to activate when:

        - **All** parent nodes have completed (activation="all") → default
        - **Any** parent node has completed (activation="any")

    Conditional branching is supported through edge conditions, where the next agent is
    selected based on content in the chat history. Loops are permitted as long as there
    is a condition that eventually exits the loop.

    .. note::

        Use the :class:`DiGraphBuilder` class to create a :class:`DiGraph` easily. It
        provides a fluent API for adding nodes and edges, setting entry points, and
        validating the graph structure. See the :class:`DiGraphBuilder` documentation
        for more details. The :class:`GraphFlow` class is designed to be used together
        with :class:`DiGraphBuilder` to create complex workflows.

    Args:
        participants (List[ChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): Termination condition for the chat.
        max_turns (int, optional): Maximum number of turns before forcing termination.
        graph (DiGraph): Directed execution graph defining node flow and conditions.

    Raises:
        ValueError: If participant names are not unique, or if graph validation fails
            (e.g., cycles without an exit condition).

    Examples:

        **Sequential Flow: A → B → C**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to English.")

                # Create a directed graph with sequential flow A -> B -> C.
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short story about a cat."):
                    print(event)


            asyncio.run(main())

        **Parallel Fan-out: A → (B, C)**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")

                # Create a directed graph with fan-out flow A -> (B, C).
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short story about a cat."):
                    print(event)


            asyncio.run(main())

        **Conditional Branching: A → B (if 'yes') or C (if 'no')**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent(
                    "A",
                    model_client=model_client,
                    system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
                )
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")

                # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b, condition="yes")
                builder.add_edge(agent_a, agent_c, condition="no")
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
                    print(event)


            asyncio.run(main())

        **Loop with exit condition: A → B → C (if 'APPROVE') or A (if 'REJECT')**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1")
                agent_a = AssistantAgent(
                    "A",
                    model_client=model_client,
                    system_message="You are a helpful assistant.",
                )
                agent_b = AssistantAgent(
                    "B",
                    model_client=model_client,
                    system_message="Provide feedback on the input, if your feedback has been addressed, "
                    "say 'APPROVE', else say 'REJECT' and provide a reason.",
                )
                agent_c = AssistantAgent(
                    "C", model_client=model_client, system_message="Translate the final product to Korean."
                )

                # Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A ("REJECT").
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b)
                builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
                builder.set_entry_point(agent_a)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(20),  # Max 20 messages to avoid infinite loop.
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short poem about AI Agents."):
                    print(event)


            asyncio.run(main())
    """

    component_config_schema = GraphFlowConfig
    component_provider_override = "autogen_agentchat.teams.GraphFlow"

    def __init__(
        self,
        participants: List[ChatAgent],
        graph: DiGraph,
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
    ) -> None:
        # Keep the caller's original inputs; _to_config serializes these rather than
        # the internally augmented participant list and termination condition.
        self._input_participants = participants
        self._input_termination_condition = termination_condition

        # The internal stop agent is selected when no graph node is ready; its
        # StopMessage ends the run through the stop-message termination below.
        stop_agent = _StopAgent()
        stop_on_graph_end = StopMessageTermination()
        if termination_condition:
            effective_termination = OrTerminationCondition(stop_on_graph_end, termination_condition)
        else:
            effective_termination = stop_on_graph_end
        all_participants = [stop_agent] + participants

        super().__init__(
            all_participants,
            group_chat_manager_name="GraphManager",
            group_chat_manager_class=GraphFlowManager,
            termination_condition=effective_termination,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
        )
        self._graph = graph

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], GraphFlowManager]:
        """Create the factory that builds the DiGraph-based chat manager."""
        # self._graph is read when the factory is invoked, not when it is created.
        return lambda: GraphFlowManager(
            name=name,
            group_topic_type=group_topic_type,
            output_topic_type=output_topic_type,
            participant_topic_types=participant_topic_types,
            participant_names=participant_names,
            participant_descriptions=participant_descriptions,
            output_message_queue=output_message_queue,
            termination_condition=termination_condition,
            max_turns=max_turns,
            message_factory=message_factory,
            graph=self._graph,
        )

    def _to_config(self) -> GraphFlowConfig:
        """Convert this instance into its declarative configuration."""
        dumped_termination = None
        if self._input_termination_condition:
            dumped_termination = self._input_termination_condition.dump_component()

        return GraphFlowConfig(
            participants=[agent.dump_component() for agent in self._input_participants],
            termination_condition=dumped_termination,
            max_turns=self._max_turns,
            graph=self._graph,
        )

    @classmethod
    def _from_config(cls, config: GraphFlowConfig) -> Self:
        """Rebuild an instance from its declarative configuration."""
        loaded_participants = [ChatAgent.load_component(component) for component in config.participants]

        loaded_termination = None
        if config.termination_condition:
            loaded_termination = TerminationCondition.load_component(config.termination_condition)

        return cls(
            loaded_participants,
            graph=config.graph,
            termination_condition=loaded_termination,
            max_turns=config.max_turns,
        )