Reflection

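Reflection is a design pattern in which an LLM generation is followed by a reflection step, which is itself another LLM generation conditioned on the output of the first. For example, given a task to write code, the first LLM can generate a code snippet, and the second LLM can generate a critique of that snippet.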

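In the context of AutoGen and agents, reflection can be implemented as a pair of agents, where the first agent generates a message and the second agent generates a response to that message. The two agents keep interacting until a stopping condition is met, such as reaching a maximum number of iterations or obtaining approval from the second agent.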
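
The loop and its two stopping conditions can be sketched without any LLM at all. The following is a minimal illustration, not AutoGen code: `reflect`, `toy_generate`, and `toy_critique` are hypothetical names, and the two toy functions are deterministic stand-ins for the generating and critiquing model calls.

```python
from typing import Callable, Tuple


def reflect(
    generate: Callable[[str, str], str],
    critique: Callable[[str], Tuple[bool, str]],
    task: str,
    max_rounds: int = 3,
) -> Tuple[str, bool, int]:
    """Alternate generation and critique until approval or max_rounds."""
    feedback = ""
    draft = ""
    for round_number in range(1, max_rounds + 1):
        # First generation: produce (or revise) a draft using prior feedback.
        draft = generate(task, feedback)
        # Second generation: reflect on the draft.
        approved, feedback = critique(draft)
        if approved:
            return draft, True, round_number
    # Stopping condition: iteration budget exhausted without approval.
    return draft, False, max_rounds


# Deterministic stand-ins for the two LLM calls (hypothetical).
def toy_generate(task: str, feedback: str) -> str:
    return f"{task} [revised]" if feedback else task


def toy_critique(draft: str) -> Tuple[bool, str]:
    if draft.endswith("[revised]"):
        return True, "Looks good."
    return False, "Please revise."


final_draft, approved, rounds = reflect(toy_generate, toy_critique, "draft")
print(final_draft, approved, rounds)
```

With these stand-ins the first draft is rejected and the revised draft is approved in the second round. Lowering `max_rounds` to 1 would end the loop on the iteration limit instead.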

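Let's implement a simple reflection design pattern using AutoGen agents. There will be two agents: a coder agent and a reviewer agent. The coder agent will generate a code snippet, and the reviewer agent will generate a critique of that snippet.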

Message Protocol

Before we define the agents, we need to define the message protocol between them.

from dataclasses import dataclass


@dataclass
class CodeWritingTask:
    task: str


@dataclass
class CodeWritingResult:
    task: str
    code: str
    review: str


@dataclass
class CodeReviewTask:
    session_id: str
    code_writing_task: str
    code_writing_scratchpad: str
    code: str


@dataclass
class CodeReviewResult:
    review: str
    session_id: str
    approved: bool

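The above set of messages defines the protocol for our example reflection design pattern:

  • The application sends a CodeWritingTask message to the coder agent.

  • The coder agent generates a CodeReviewTask message, which is sent to the reviewer agent.

  • The reviewer agent generates a CodeReviewResult message, which is sent back to the coder agent.

  • Depending on the CodeReviewResult message: if the code is approved, the coder agent sends a CodeWritingResult message back to the application; otherwise, the coder agent sends another CodeReviewTask message to the reviewer agent, and the process continues.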
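
The last step of the protocol is a branch on the `approved` flag. The sketch below isolates that decision as a pure function. The three dataclasses repeat the definitions above so the snippet runs on its own; `coder_reply` and its `revised_code` parameter are hypothetical and only illustrate the branch, not how the agent produces a revision.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class CodeWritingResult:
    task: str
    code: str
    review: str


@dataclass
class CodeReviewTask:
    session_id: str
    code_writing_task: str
    code_writing_scratchpad: str
    code: str


@dataclass
class CodeReviewResult:
    review: str
    session_id: str
    approved: bool


def coder_reply(
    last_request: CodeReviewTask, result: CodeReviewResult, revised_code: str
) -> Union[CodeWritingResult, CodeReviewTask]:
    """Pick the coder's next outgoing message (hypothetical helper)."""
    if result.approved:
        # Approved: report the reviewed code back to the application.
        return CodeWritingResult(
            task=last_request.code_writing_task,
            code=last_request.code,
            review=result.review,
        )
    # Not approved: send a revision back to the reviewer in the same session.
    return CodeReviewTask(
        session_id=result.session_id,
        code_writing_task=last_request.code_writing_task,
        code_writing_scratchpad=revised_code,
        code=revised_code,
    )


request = CodeReviewTask("s1", "add two numbers", "draft", "def add(a, b): return a - b")
rejected = CodeReviewResult(review="Wrong operator.", session_id="s1", approved=False)
accepted = CodeReviewResult(review="Fine.", session_id="s1", approved=True)

print(type(coder_reply(request, rejected, "def add(a, b): return a + b")).__name__)
print(type(coder_reply(request, accepted, "")).__name__)
```

Keeping `session_id` on both review messages is what lets each side match a result to the task it belongs to.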

We can visualize the message protocol with a data flow diagram:

[Figure: coder-reviewer data flow]

Agents

Now, let's define the agents for the reflection design pattern.

import json
import re
import uuid
from typing import Dict, List, Union

from autogen_core import MessageContext, RoutedAgent, TopicId, default_subscription, message_handler
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    LLMMessage,
    SystemMessage,
    UserMessage,
)

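We will use the broadcast API to implement the design pattern. The agents implement the publish/subscribe model. The coder agent subscribes to the CodeWritingTask and CodeReviewResult messages, and publishes the CodeReviewTask and CodeWritingResult messages.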
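
To make the publish/subscribe idea concrete before looking at the real agents, here is a toy in-memory broadcaster. It is an illustrative assumption, not the AutoGen runtime: `ToyBus`, `Task`, and `Draft` are hypothetical names. Every published message reaches every subscriber, and a subscriber with no handler for that message type simply records it as unhandled, which seems to match the `Unhandled message` lines in the log output later on this page.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type


@dataclass
class Task:
    text: str


@dataclass
class Draft:
    code: str


class ToyBus:
    """Broadcasts each message to all subscribers (hypothetical, not AutoGen)."""

    def __init__(self) -> None:
        self._subscribers: Dict[str, Dict[Type[Any], Callable[[Any], None]]] = {}
        self.log: List[str] = []

    def subscribe(self, name: str, handlers: Dict[Type[Any], Callable[[Any], None]]) -> None:
        self._subscribers[name] = handlers

    def publish(self, message: Any) -> None:
        # Deliver to every subscriber; dispatch on the message's type.
        for name, handlers in self._subscribers.items():
            handler = handlers.get(type(message))
            if handler is None:
                self.log.append(f"{name}: unhandled {type(message).__name__}")
            else:
                handler(message)


bus = ToyBus()
received: List[str] = []
bus.subscribe("coder", {Task: lambda m: received.append(f"coder got task: {m.text}")})
bus.subscribe("reviewer", {Draft: lambda m: received.append(f"reviewer got draft: {m.code}")})

bus.publish(Task(text="sum evens"))
print(received)
print(bus.log)
```

Type-based dispatch is what lets both agents share one topic: each one reacts only to the message types it declares handlers for.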

@default_subscription
class CoderAgent(RoutedAgent):
    """An agent that performs code writing tasks."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A code writing agent.")
        self._system_messages: List[LLMMessage] = [
            SystemMessage(
                content="""You are a proficient coder. You write code to solve problems.
Work with the reviewer to improve your code.
Always put all finished code in a single Markdown code block.
For example:
```python
def hello_world():
    print("Hello, World!")
```

Respond using the following format:

Thoughts: <Your comments>
Code: <Your code>
""",
            )
        ]
        self._model_client = model_client
        self._session_memory: Dict[str, List[CodeWritingTask | CodeReviewTask | CodeReviewResult]] = {}

    @message_handler
    async def handle_code_writing_task(self, message: CodeWritingTask, ctx: MessageContext) -> None:
        # Store the messages in a temporary memory for this request only.
        session_id = str(uuid.uuid4())
        self._session_memory.setdefault(session_id, []).append(message)
        # Generate a response using the chat completion API.
        response = await self._model_client.create(
            self._system_messages + [UserMessage(content=message.task, source=self.metadata["type"])],
            cancellation_token=ctx.cancellation_token,
        )
        assert isinstance(response.content, str)
        # Extract the code block from the response.
        code_block = self._extract_code_block(response.content)
        if code_block is None:
            raise ValueError("Code block not found.")
        # Create a code review task.
        code_review_task = CodeReviewTask(
            session_id=session_id,
            code_writing_task=message.task,
            code_writing_scratchpad=response.content,
            code=code_block,
        )
        # Store the code review task in the session memory.
        self._session_memory[session_id].append(code_review_task)
        # Publish the code review task.
        await self.publish_message(code_review_task, topic_id=TopicId("default", self.id.key))

    @message_handler
    async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:
        # Store the review result in the session memory.
        self._session_memory[message.session_id].append(message)
        # Obtain the most recent review request from previous messages.
        review_request = next(
            m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewTask)
        )
        assert review_request is not None
        # Check if the code is approved.
        if message.approved:
            # Publish the code writing result.
            await self.publish_message(
                CodeWritingResult(
                    code=review_request.code,
                    task=review_request.code_writing_task,
                    review=message.review,
                ),
                topic_id=TopicId("default", self.id.key),
            )
            print("Code Writing Result:")
            print("-" * 80)
            print(f"Task:\n{review_request.code_writing_task}")
            print("-" * 80)
            print(f"Code:\n{review_request.code}")
            print("-" * 80)
            print(f"Review:\n{message.review}")
            print("-" * 80)
        else:
            # Create a list of LLM messages to send to the model.
            messages: List[LLMMessage] = [*self._system_messages]
            for m in self._session_memory[message.session_id]:
                if isinstance(m, CodeReviewResult):
                    messages.append(UserMessage(content=m.review, source="Reviewer"))
                elif isinstance(m, CodeReviewTask):
                    messages.append(AssistantMessage(content=m.code_writing_scratchpad, source="Coder"))
                elif isinstance(m, CodeWritingTask):
                    messages.append(UserMessage(content=m.task, source="User"))
                else:
                    raise ValueError(f"Unexpected message type: {m}")
            # Generate a revision using the chat completion API.
            response = await self._model_client.create(messages, cancellation_token=ctx.cancellation_token)
            assert isinstance(response.content, str)
            # Extract the code block from the response.
            code_block = self._extract_code_block(response.content)
            if code_block is None:
                raise ValueError("Code block not found.")
            # Create a new code review task.
            code_review_task = CodeReviewTask(
                session_id=message.session_id,
                code_writing_task=review_request.code_writing_task,
                code_writing_scratchpad=response.content,
                code=code_block,
            )
            # Store the code review task in the session memory.
            self._session_memory[message.session_id].append(code_review_task)
            # Publish the new code review task.
            await self.publish_message(code_review_task, topic_id=TopicId("default", self.id.key))

    def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
        pattern = r"```(\w+)\n(.*?)\n```"
        # Search for the pattern in the Markdown text.
        match = re.search(pattern, markdown_text, re.DOTALL)
        # If a match is found, return the code (group 2); the language tag (group 1) is discarded.
        if match:
            return match.group(2)
        return None

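A few things to note about CoderAgent:

  • It uses chain-of-thought prompting in its system message.

  • It stores the message histories of different CodeWritingTask messages in a dictionary, so each task has its own history.

  • When making an LLM inference request using its model client, it transforms the message history into a list of autogen_core.models.LLMMessage objects to pass to the model client.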
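
One detail of `_extract_code_block` worth seeing in isolation is that its pattern requires a language tag after the opening fence. The sketch below reuses the same regular expression in a standalone function (`extract_code_block` is a hypothetical free-function copy of the method). The fence string is built with `"`" * 3` only to keep literal triple backticks out of this example.

```python
import re
from typing import Optional

# Three backticks, built programmatically so this example stays easy to embed.
FENCE = "`" * 3
# Same pattern as in CoderAgent: fence, language tag, newline, body, newline, fence.
PATTERN = FENCE + r"(\w+)\n(.*?)\n" + FENCE


def extract_code_block(markdown_text: str) -> Optional[str]:
    """Return the body of the first language-tagged code block, if any."""
    match = re.search(PATTERN, markdown_text, re.DOTALL)
    return match.group(2) if match else None


tagged = f"Thoughts: simple.\nCode:\n{FENCE}python\nprint('hi')\n{FENCE}"
untagged = f"Code:\n{FENCE}\nprint('hi')\n{FENCE}"

print(extract_code_block(tagged))
print(extract_code_block(untagged))
```

The tagged reply yields the code, while the untagged one yields `None`, which in CoderAgent would raise `ValueError("Code block not found.")`. The system message's example block uses a `python` tag, which nudges the model toward the form the pattern expects.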

The reviewer agent subscribes to the CodeReviewTask message and publishes the CodeReviewResult message.

@default_subscription
class ReviewerAgent(RoutedAgent):
    """An agent that performs code review tasks."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A code reviewer agent.")
        self._system_messages: List[LLMMessage] = [
            SystemMessage(
                content="""You are a code reviewer. You focus on correctness, efficiency and safety of the code.
Respond using the following JSON format:
{
    "correctness": "<Your comments>",
    "efficiency": "<Your comments>",
    "safety": "<Your comments>",
    "approval": "<APPROVE or REVISE>",
    "suggested_changes": "<Your comments>"
}
""",
            )
        ]
        self._session_memory: Dict[str, List[CodeReviewTask | CodeReviewResult]] = {}
        self._model_client = model_client

    @message_handler
    async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None:
        # Format the prompt for the code review.
        # Gather the previous feedback if available.
        previous_feedback = ""
        if message.session_id in self._session_memory:
            previous_review = next(
                (m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewResult)),
                None,
            )
            if previous_review is not None:
                previous_feedback = previous_review.review
        # Store the messages in a temporary memory for this request only.
        self._session_memory.setdefault(message.session_id, []).append(message)
        prompt = f"""The problem statement is: {message.code_writing_task}
The code is:
```
{message.code}
```

Previous feedback:
{previous_feedback}

Please review the code. If previous feedback was provided, see if it was addressed.
"""
        # Generate a response using the chat completion API.
        response = await self._model_client.create(
            self._system_messages + [UserMessage(content=prompt, source=self.metadata["type"])],
            cancellation_token=ctx.cancellation_token,
            json_output=True,
        )
        assert isinstance(response.content, str)
        # TODO: use a structured generation library (e.g. guidance) to ensure the response is in the expected format.
        # Parse the response JSON.
        review = json.loads(response.content)
        # Construct the review text.
        review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
        approved = review["approval"].lower().strip() == "approve"
        result = CodeReviewResult(
            review=review_text,
            session_id=message.session_id,
            approved=approved,
        )
        # Store the review result in the session memory.
        self._session_memory[message.session_id].append(result)
        # Publish the review result.
        await self.publish_message(result, topic_id=TopicId("default", self.id.key))

The ReviewerAgent uses JSON mode when making an LLM inference request, and it also uses chain-of-thought prompting in its system message.
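
The step that turns the model's JSON reply into a review text and an approval flag can be tried on its own. This is a minimal sketch using only the standard library; `parse_review` is a hypothetical helper, and its use of `.get` with an empty default is a defensive assumption that differs from the agent code above, which indexes `review["approval"]` directly.

```python
import json
from typing import Tuple


def parse_review(raw: str) -> Tuple[str, bool]:
    """Turn a JSON review string into (review_text, approved)."""
    review = json.loads(raw)
    # One "key: value" line per field, in the order the model returned them.
    review_text = "Code review:\n" + "\n".join(f"{k}: {v}" for k, v in review.items())
    # Normalize case and whitespace; a missing key counts as not approved.
    approved = str(review.get("approval", "")).lower().strip() == "approve"
    return review_text, approved


text, approved = parse_review('{"correctness": "ok", "approval": " APPROVE "}')
print(approved)

_, approved_without_key = parse_review('{"correctness": "ok"}')
print(approved_without_key)
```

Because the comparison is an exact match after normalization, any other value, including `REVISE`, is treated as a request for another round.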

Logging

Turn on logging to see the messages exchanged between the agents.

import logging

logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
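
The configuration above relies on Python's logger hierarchy: a level set on `autogen_core` is inherited by child loggers such as `autogen_core.events`, while unrelated loggers fall back to the root level. The following sketch repeats the same two configuration lines and inspects the effective levels; `some_other_library` is a hypothetical logger name used only for contrast.

```python
import logging

# Same configuration as above.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)

# Child loggers without their own level inherit from the nearest configured ancestor.
events_logger = logging.getLogger("autogen_core.events")
other_logger = logging.getLogger("some_other_library")

print(logging.getLevelName(events_logger.getEffectiveLevel()))
print(logging.getLevelName(other_logger.getEffectiveLevel()))
```

In a fresh interpreter the first logger reports DEBUG and the second inherits the root level of WARNING, so only the AutoGen loggers become verbose.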

Running the Design Pattern

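Let's test the design pattern with a coding task. Since all the agents are decorated with the default_subscription() class decorator, they automatically subscribe to the default topic when created. We publish a CodeWritingTask message to the default topic to start the reflection process.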

from autogen_core import DefaultTopicId, SingleThreadedAgentRuntime
from autogen_ext.models.openai import OpenAIChatCompletionClient

runtime = SingleThreadedAgentRuntime()
model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")
await ReviewerAgent.register(runtime, "ReviewerAgent", lambda: ReviewerAgent(model_client=model_client))
await CoderAgent.register(runtime, "CoderAgent", lambda: CoderAgent(model_client=model_client))
runtime.start()
await runtime.publish_message(
    message=CodeWritingTask(task="Write a function to find the sum of all even numbers in a list."),
    topic_id=DefaultTopicId(),
)

# Keep processing messages until idle.
await runtime.stop_when_idle()
# Close the model client.
await model_client.close()
INFO:autogen_core:Publishing message of type CodeWritingTask to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingTask published by Unknown
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeWritingTask published by Unknown
INFO:autogen_core:Unhandled message: CodeWritingTask(task='Write a function to find the sum of all even numbers in a list.')
INFO:autogen_core.events:{"prompt_tokens": 101, "completion_tokens": 88, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': 'Thoughts: To find the sum of all even numbers in a list, we can use a list comprehension to filter out the even numbers and then use the `sum()` function to calculate their total. The implementation should handle edge cases like an empty list or a list with no even numbers.\n\nCode:\n```python\ndef sum_of_even_numbers(numbers):\n    return sum(num for num in numbers if num % 2 == 0)\n```', 'code': 'def sum_of_even_numbers(numbers):\n    return sum(num for num in numbers if num % 2 == 0)'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default
INFO:autogen_core.events:{"prompt_tokens": 163, "completion_tokens": 235, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': "Code review:\ncorrectness: The function correctly identifies and sums all even numbers in the provided list. The use of a generator expression ensures that only even numbers are processed, which is correct.\nefficiency: The function is efficient as it utilizes a generator expression that avoids creating an intermediate list, therefore using less memory. The time complexity is O(n) where n is the number of elements in the input list, which is optimal for this task.\nsafety: The function does not include checks for input types. If a non-iterable or a list containing non-integer types is passed, it could lead to unexpected behavior or errors. It’s advisable to handle such cases.\napproval: REVISE\nsuggested_changes: Consider adding input validation to ensure that 'numbers' is a list and contains only integers. You could raise a ValueError if the input is invalid. Example: 'if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers): raise ValueError('Input must be a list of integers')'. This will make the function more robust.", 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': False}
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default
INFO:autogen_core.events:{"prompt_tokens": 421, "completion_tokens": 119, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': "Thoughts: I appreciate the reviewer's feedback on input validation. Adding type checks ensures that the function can handle unexpected inputs gracefully. I will implement the suggested changes and include checks for both the input type and the elements within the list to confirm that they are integers.\n\nCode:\n```python\ndef sum_of_even_numbers(numbers):\n    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n        raise ValueError('Input must be a list of integers')\n    \n    return sum(num for num in numbers if num % 2 == 0)\n```", 'code': "def sum_of_even_numbers(numbers):\n    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n        raise ValueError('Input must be a list of integers')\n    \n    return sum(num for num in numbers if num % 2 == 0)"}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default
INFO:autogen_core.events:{"prompt_tokens": 420, "completion_tokens": 153, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': 'Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.', 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': True}
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default
INFO:autogen_core:Publishing message of type CodeWritingResult to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.', 'code': "def sum_of_even_numbers(numbers):\n    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n        raise ValueError('Input must be a list of integers')\n    \n    return sum(num for num in numbers if num % 2 == 0)", 'review': 'Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingResult published by CoderAgent:default
INFO:autogen_core:Unhandled message: CodeWritingResult(task='Write a function to find the sum of all even numbers in a list.', code="def sum_of_even_numbers(numbers):\n    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n        raise ValueError('Input must be a list of integers')\n    \n    return sum(num for num in numbers if num % 2 == 0)", review='Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.')
Code Writing Result:
--------------------------------------------------------------------------------
Task:
Write a function to find the sum of all even numbers in a list.
--------------------------------------------------------------------------------
Code:
def sum_of_even_numbers(numbers):
    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):
        raise ValueError('Input must be a list of integers')
    
    return sum(num for num in numbers if num % 2 == 0)
--------------------------------------------------------------------------------
Review:
Code review:
correctness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.
efficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.
safety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.
approval: APPROVE
suggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.
--------------------------------------------------------------------------------

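The log messages show the interaction between the coder and reviewer agents. The final output shows the code snippet generated by the coder agent and the critique generated by the reviewer agent.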