import asyncio
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import ChatAgent, OrTerminationCondition, Response, TerminationCondition
from autogen_agentchat.conditions import StopMessageTermination
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
ChatMessage,
MessageFactory,
StopMessage,
TextMessage,
)
from autogen_agentchat.state import BaseGroupChatManagerState
from autogen_agentchat.teams import BaseGroupChat
from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
from ..._group_chat._events import GroupChatTermination
# Name of the internal stop agent. _StopAgent registers under this name, and
# GraphFlowManager.select_speaker returns it when the ready queue is empty.
_DIGRAPH_STOP_AGENT_NAME = "DiGraphStopAgent"
# Content of the StopMessage that _StopAgent.on_messages returns.
_DIGRAPH_STOP_AGENT_MESSAGE = "Digraph execution is complete"
# [docs]
class DiGraphEdge(BaseModel):
    """Represents a directed edge in a :class:`DiGraph`, with an optional execution condition.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    target: str  # Target node name
    condition: str | None = None  # Optional execution condition (trigger-based)
    """(Experimental) Condition under which this edge is followed.

    If None, the edge is unconditional. If a string, the edge is followed only when that
    string is contained in the last agent chat message.

    NOTE: This is an experimental feature and will change in future releases to allow a
    better specification of branching conditions, similar to the `TerminationCondition` class.
    """
# [docs]
class DiGraphNode(BaseModel):
    """Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    name: str  # Agent's name
    edges: List[DiGraphEdge] = []  # Outgoing edges
    # "all": the node becomes ready once every parent has signalled it;
    # "any": the node becomes ready as soon as one parent has signalled it
    # (see GraphFlowManager.update_message_thread).
    activation: Literal["all", "any"] = "all"
# [docs]
class DiGraph(BaseModel):
    """Defines a directed graph structure with nodes and edges.

    :class:`GraphFlow` uses this structure to determine execution order and conditions.

    .. warning::

        This is an experimental feature, and the API will change in future releases.
    """

    nodes: Dict[str, DiGraphNode]  # Node name → DiGraphNode mapping
    default_start_node: str | None = None  # Default start node name
    _has_cycles: bool | None = None  # Cyclic graph flag (None until computed)

    def get_parents(self) -> Dict[str, List[str]]:
        """Compute a mapping from each node name to the names of its parent nodes."""
        parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
        for node in self.nodes.values():
            for edge in node.edges:
                parents[edge.target].append(node.name)
        return parents

    def get_start_nodes(self) -> Set[str]:
        """Return the entry points of the graph.

        If ``default_start_node`` is set it is the only start node; otherwise the
        start nodes are all nodes that have no incoming edges.
        """
        if self.default_start_node:
            return {self.default_start_node}
        parents = self.get_parents()
        return {node_name for node_name, parent_list in parents.items() if not parent_list}

    def get_leaf_nodes(self) -> Set[str]:
        """Return the nodes that have no outgoing edges (final output nodes)."""
        return {name for name, node in self.nodes.items() if not node.edges}

    def has_cycles_with_exit(self) -> bool:
        """Check whether the graph has cycles and that every cycle has an exit condition.

        A cycle is considered to have an exit condition when at least one of its
        nodes has a conditional outgoing edge.

        Returns:
            bool: True if the graph contains at least one cycle and every cycle has
            an exit condition. False if the graph contains no cycle.

        Raises:
            ValueError: If there is a cycle in which no node has a conditional edge.
        """

        def find_cycle(allowed: Set[str] | None) -> List[str] | None:
            """Return the nodes of one cycle, or None if there is none.

            When ``allowed`` is given, only nodes in that set are traversed, so the
            returned cycle lies entirely inside it.
            """
            visited: Set[str] = set()
            on_path: Set[str] = set()
            path: List[str] = []

            def dfs(node_name: str) -> List[str] | None:
                visited.add(node_name)
                on_path.add(node_name)
                path.append(node_name)
                for edge in self.nodes[node_name].edges:
                    target = edge.target
                    if allowed is not None and target not in allowed:
                        continue
                    if target in on_path:
                        # Back edge: the cycle is the tail of the current path.
                        return path[path.index(target) :]
                    if target not in visited:
                        cycle = dfs(target)
                        if cycle is not None:
                            return cycle
                # Always unwind, so a node that is finished is never mistaken
                # for one that is still on the current DFS path.
                on_path.remove(node_name)
                path.pop()
                return None

            for name in self.nodes:
                if name in visited or (allowed is not None and name not in allowed):
                    continue
                cycle = dfs(name)
                if cycle is not None:
                    return cycle
            return None

        # Pass 1: is there any cycle at all?
        if find_cycle(None) is None:
            return False

        # Pass 2: a cycle made only of nodes without any conditional edge can
        # never be left, so look for a cycle inside that sub-graph. Searching the
        # sub-graph directly (instead of inspecting only the first cycle a DFS
        # happens to meet) guarantees that no such cycle is overlooked.
        unconditional = {
            name for name, node in self.nodes.items() if not any(edge.condition for edge in node.edges)
        }
        stuck = find_cycle(unconditional)
        if stuck is not None:
            raise ValueError(f"Cycle detected without exit condition: {' -> '.join(stuck + stuck[:1])}")
        return True

    def get_has_cycles(self) -> bool:
        """Indicate whether the graph has at least one cycle (with valid exit conditions).

        The result is computed on first use and cached in ``_has_cycles``.
        """
        if self._has_cycles is None:
            self._has_cycles = self.has_cycles_with_exit()
        return self._has_cycles

    def graph_validate(self) -> None:
        """Validate the graph structure and execution rules.

        Raises:
            ValueError: If the graph has no nodes, no start node, no leaf node, a
                node mixing conditional and unconditional outgoing edges, or a
                cycle without an exit condition.
        """
        if not self.nodes:
            raise ValueError("Graph has no nodes.")
        if not self.get_start_nodes():
            raise ValueError("Graph must have at least one start node")
        if not self.get_leaf_nodes():
            raise ValueError("Graph must have at least one leaf node")

        # Outgoing edge condition validation (per node)
        for node in self.nodes.values():
            # If a node has an outgoing conditional edge, then all of its outgoing edges must be conditional.
            has_condition = any(edge.condition for edge in node.edges)
            has_unconditioned = any(edge.condition is None for edge in node.edges)
            if has_condition and has_unconditioned:
                raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")

        # Refresh the cached flag; this also raises for cycles without an exit.
        self._has_cycles = self.has_cycles_with_exit()
class GraphFlowManagerState(BaseGroupChatManagerState):
    """Tracks the active execution state for DAG-based execution."""

    # NOTE(review): GraphFlowManager.save_state in this file builds a plain dict and
    # does not use this model — confirm whether it is still needed.
    active_nodes: List[str] = []  # Currently executing nodes
    type: str = "GraphManagerState"
class GraphFlowManager(BaseGroupChatManager):
    """Manages execution of agents using a directed graph execution model."""

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        graph: DiGraph,
    ) -> None:
        """Initialize the graph-based execution manager.

        All arguments except ``graph`` are forwarded unchanged to the base class.
        ``graph`` is validated and then used to build the lookup tables and the
        initial mutable execution state.

        Raises:
            ValueError: If ``graph`` fails validation, or if it contains cycles while
                neither a termination condition nor a maximum turn limit is set.
        """
        super().__init__(
            name=name,
            group_topic_type=group_topic_type,
            output_topic_type=output_topic_type,
            participant_topic_types=participant_topic_types,
            participant_names=participant_names,
            participant_descriptions=participant_descriptions,
            output_message_queue=output_message_queue,
            termination_condition=termination_condition,
            max_turns=max_turns,
            message_factory=message_factory,
        )
        graph.graph_validate()
        if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
            raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
        self._graph = graph
        # Lookup table for incoming edges for each node.
        self._parents = graph.get_parents()
        # Lookup table for outgoing edges for each node.
        self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
        # Activation lookup table for each node.
        self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()}

        # === Mutable states for the graph execution ===
        # Count the number of remaining parents to activate each node.
        self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()})
        # Lookup table for nodes that have been enqueued through an any activation.
        # This is used to prevent re-adding the same node multiple times.
        self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes}
        # Ready queue for nodes that are ready to execute, starting with the start nodes.
        self._ready: Deque[str] = deque(graph.get_start_nodes())

    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
        """Record ``messages`` and propagate the last speaker's completion to its children.

        Only the last message is inspected. If its source is a graph node, each
        outgoing edge of that node whose condition (if any) is contained in the
        message text updates the activation bookkeeping of the edge's target.
        """
        await super().update_message_thread(messages)
        if not messages:
            # Nothing was produced in this turn, so there is no node to propagate from.
            return
        # Find the node that ran in the current turn.
        message = messages[-1]
        if message.source not in self._graph.nodes:
            # Ignore messages from sources outside of the graph.
            return
        assert isinstance(message, BaseChatMessage)
        source = message.source
        content = message.to_model_text()

        # Propagate the update to the children of the node.
        for edge in self._edges[source]:
            if edge.condition and edge.condition not in content:
                continue
            if self._activation[edge.target] == "all":
                self._remaining[edge.target] -= 1
                if self._remaining[edge.target] == 0:
                    # If all parents are done, add to the ready queue.
                    self._ready.append(edge.target)
            else:
                # If activation is any, add to the ready queue if not already enqueued.
                if not self._enqueued_any[edge.target]:
                    self._ready.append(edge.target)
                    self._enqueued_any[edge.target] = True

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
        """Return every node currently in the ready queue as the next speakers.

        If the queue is empty, the stop agent's name is returned instead.
        """
        # Drain the ready queue for the next set of speakers.
        speakers: List[str] = []
        while self._ready:
            speaker = self._ready.popleft()
            speakers.append(speaker)
            # Reset the bookkeeping for the selected node so that it can be
            # activated again later.
            if self._activation[speaker] == "any":
                self._enqueued_any[speaker] = False
            else:
                self._remaining[speaker] = len(self._parents[speaker])

        # If there are no speakers, trigger the stop agent.
        if not speakers:
            speakers = [_DIGRAPH_STOP_AGENT_NAME]
        return speakers

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """No-op: this manager performs no validation here."""
        pass

    async def save_state(self) -> Mapping[str, Any]:
        """Save the execution state as a mapping of copies of the mutable bookkeeping."""
        state = {
            "message_thread": [message.dump() for message in self._message_thread],
            "current_turn": self._current_turn,
            "remaining": dict(self._remaining),
            "enqueued_any": dict(self._enqueued_any),
            "ready": list(self._ready),
        }
        return state

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the execution state from a mapping produced by :meth:`save_state`."""
        self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
        self._current_turn = state["current_turn"]
        self._remaining = Counter(state["remaining"])
        # Copy, so later bookkeeping updates never mutate the caller's saved state.
        self._enqueued_any = dict(state["enqueued_any"])
        self._ready = deque(state["ready"])

    async def reset(self) -> None:
        """Reset the execution state to the start of the graph."""
        self._current_turn = 0
        self._message_thread.clear()
        if self._termination_condition:
            await self._termination_condition.reset()
        self._remaining = Counter({n: len(p) for n, p in self._parents.items()})
        self._enqueued_any = {n: False for n in self._graph.nodes}
        self._ready = deque(self._graph.get_start_nodes())
class _StopAgent(BaseChatAgent):
    """Internal participant that answers every request with a fixed StopMessage."""

    def __init__(self) -> None:
        super().__init__(
            _DIGRAPH_STOP_AGENT_NAME,
            "Agent that terminates the GraphFlow.",
        )

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        """Message types this agent declares it may produce."""
        return TextMessage, StopMessage

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Ignore the incoming messages and reply with the stop message."""
        stop_message = StopMessage(content=_DIGRAPH_STOP_AGENT_MESSAGE, source=self.name)
        return Response(chat_message=stop_message)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Nothing to reset: the agent keeps no state of its own."""
        return None
class GraphFlowConfig(BaseModel):
    """The declarative configuration for GraphFlow."""

    participants: List[ComponentModel]  # Dumped components of the participating agents
    termination_condition: ComponentModel | None = None  # Dumped termination condition, if any
    max_turns: int | None = None  # Maximum number of turns, if any
    graph: DiGraph  # The execution graph for agents
# [docs]
class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
    """A team that runs a group chat following a directed graph execution pattern.

    .. warning::

        This is an experimental feature, and the API will change in future releases.

    This group chat executes agents based on a directed graph (:class:`DiGraph`) structure,
    allowing complex workflows such as sequential execution, parallel fan-out,
    conditional branching, join patterns, and loops with explicit exit conditions.

    The execution order is determined by the edges defined in the `DiGraph`. Each node in the
    graph corresponds to an agent, and edges define the flow of messages between agents.
    Nodes can be configured to activate when:

        - **All** parent nodes have completed (activation="all") → default
        - **Any** parent node completes (activation="any")

    Conditional branching is supported using edge conditions, where the next agent(s) are
    selected based on content in the chat history. Loops are permitted as long as there is a
    condition that eventually exits the loop.

    .. note::

        Use the :class:`DiGraphBuilder` class to create a :class:`DiGraph` easily. It provides
        a fluent API for adding nodes and edges, setting entry points, and validating the
        graph structure.
        See the :class:`DiGraphBuilder` documentation for more details.
        The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder`
        for creating complex workflows.

    Args:
        participants (List[ChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): Termination condition for the chat.
        max_turns (int, optional): Maximum number of turns before forcing termination.
        graph (DiGraph): Directed execution graph defining node flow and conditions.

    Raises:
        ValueError: If participant names are not unique, or if graph validation fails
            (e.g., a cycle without an exit condition).

    Examples:

        **Sequential Flow: A → B → C**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to English.")

                # Create a directed graph with sequential flow A -> B -> C.
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short story about a cat."):
                    print(event)


            asyncio.run(main())

        **Parallel Fan-out: A → (B, C)**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")

                # Create a directed graph with fan-out flow A -> (B, C).
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short story about a cat."):
                    print(event)


            asyncio.run(main())

        **Conditional Branching: A → B (if 'yes') or C (if 'no')**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
                agent_a = AssistantAgent(
                    "A",
                    model_client=model_client,
                    system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
                )
                agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
                agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")

                # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b, condition="yes")
                builder.add_edge(agent_a, agent_c, condition="no")
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(5),
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
                    print(event)


            asyncio.run(main())

        **Loop with exit condition: A → B → C (if 'APPROVE') or A (if 'REJECT')**

        .. code-block:: python

            import asyncio

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import MaxMessageTermination
            from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main():
                # Initialize agents with OpenAI model clients.
                model_client = OpenAIChatCompletionClient(model="gpt-4.1")
                agent_a = AssistantAgent(
                    "A",
                    model_client=model_client,
                    system_message="You are a helpful assistant.",
                )
                agent_b = AssistantAgent(
                    "B",
                    model_client=model_client,
                    system_message="Provide feedback on the input, if your feedback has been addressed, "
                    "say 'APPROVE', else say 'REJECT' and provide a reason.",
                )
                agent_c = AssistantAgent(
                    "C", model_client=model_client, system_message="Translate the final product to Korean."
                )

                # Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A ("REJECT").
                builder = DiGraphBuilder()
                builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
                builder.add_edge(agent_a, agent_b)
                builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
                builder.set_entry_point(agent_a)
                graph = builder.build()

                # Create a GraphFlow team with the directed graph.
                team = GraphFlow(
                    participants=[agent_a, agent_b, agent_c],
                    graph=graph,
                    termination_condition=MaxMessageTermination(20),  # Max 20 messages to avoid infinite loop.
                )

                # Run the team and print the events.
                async for event in team.run_stream(task="Write a short poem about AI Agents."):
                    print(event)


            asyncio.run(main())
    """

    component_config_schema = GraphFlowConfig
    component_provider_override = "autogen_agentchat.teams.GraphFlow"

    def __init__(
        self,
        participants: List[ChatAgent],
        graph: DiGraph,
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
    ) -> None:
        # Keep the caller's original inputs; _to_config dumps these rather than
        # the augmented values handed to the base class below.
        self._input_participants = participants
        self._input_termination_condition = termination_condition

        # The run always stops on a StopMessage; a user-supplied condition is
        # OR-ed with that built-in one.
        stop_condition = StopMessageTermination()
        effective_condition: TerminationCondition
        if termination_condition:
            effective_condition = OrTerminationCondition(stop_condition, termination_condition)
        else:
            effective_condition = stop_condition

        # The internal stop agent is placed ahead of the user's participants.
        all_participants = [_StopAgent(), *participants]

        super().__init__(
            all_participants,
            group_chat_manager_name="GraphManager",
            group_chat_manager_class=GraphFlowManager,
            termination_condition=effective_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
        )
        self._graph = graph

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], GraphFlowManager]:
        """Return a zero-argument factory that builds the DiGraph-based chat manager.

        The manager is constructed (and ``self._graph`` is read) only when the
        returned factory is called.
        """
        return lambda: GraphFlowManager(
            name=name,
            group_topic_type=group_topic_type,
            output_topic_type=output_topic_type,
            participant_topic_types=participant_topic_types,
            participant_names=participant_names,
            participant_descriptions=participant_descriptions,
            output_message_queue=output_message_queue,
            termination_condition=termination_condition,
            max_turns=max_turns,
            message_factory=message_factory,
            graph=self._graph,
        )

    def _to_config(self) -> GraphFlowConfig:
        """Convert this instance to a configuration object."""
        dumped_condition = None
        if self._input_termination_condition:
            dumped_condition = self._input_termination_condition.dump_component()

        dumped_participants = [agent.dump_component() for agent in self._input_participants]
        return GraphFlowConfig(
            participants=dumped_participants,
            termination_condition=dumped_condition,
            max_turns=self._max_turns,
            graph=self._graph,
        )

    @classmethod
    def _from_config(cls, config: GraphFlowConfig) -> Self:
        """Rebuild an instance from a configuration object."""
        loaded_participants = [ChatAgent.load_component(item) for item in config.participants]

        loaded_condition = None
        if config.termination_condition:
            loaded_condition = TerminationCondition.load_component(config.termination_condition)

        return cls(
            loaded_participants,
            graph=config.graph,
            termination_condition=loaded_condition,
            max_turns=config.max_turns,
        )