Source code for autogen_ext.models.cache._chat_completion_cache

import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelInfo,
    RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from typing_extensions import Self

CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
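# CHAT_CACHE_VALUE_TYPE is the value type stored in the cache: a single CreateResult
# for create() calls, or the list of streamed chunks (str or CreateResult) recorded
# by create_stream() calls.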


class ChatCompletionCacheConfig(BaseModel):
    """ """

    client: ComponentModel
    store: Optional[ComponentModel] = None
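    # ``store`` is ``None`` when the wrapped cache is the default ``InMemoryStore``,
    # which is not serialized (see ``ChatCompletionCache._to_config`` / ``_from_config`` below).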


class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheConfig]):
    """A wrapper around a :class:`~autogen_core.models.ChatCompletionClient` that caches
    creation results from the underlying client.
    Cache hits do not count towards token usage of the original client.

    Typical Usage:

        Let's use disk caching with the `openai` client as an example.
        First install `autogen-ext` with the required packages:

        .. code-block:: bash

            pip install -U "autogen-ext[openai, diskcache]"

        And use it as follows:

        .. code-block:: python

            import asyncio
            import tempfile

            from autogen_core.models import UserMessage
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE
            from autogen_ext.cache_store.diskcache import DiskCacheStore
            from diskcache import Cache


            async def main():
                with tempfile.TemporaryDirectory() as tmpdirname:
                    # Initialize the original client
                    openai_model_client = OpenAIChatCompletionClient(model="gpt-4o")

                    # Then initialize the CacheStore, in this case with diskcache.Cache.
                    # You can also use redis, for example:
                    # from autogen_ext.cache_store.redis import RedisStore
                    # import redis
                    # redis_instance = redis.Redis()
                    # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
                    cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))
                    cache_client = ChatCompletionCache(openai_model_client, cache_store)

                    response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                    print(response)  # Should print response from OpenAI
                    response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
                    print(response)  # Should print cached response


            asyncio.run(main())

    You can now use `cache_client` as you would the original client, but with caching enabled.

    Args:
        client (ChatCompletionClient): The original ChatCompletionClient to wrap.
        store (CacheStore): A store object that implements get and set methods.
            The user is responsible for managing the store's lifetime, as well as cleaning it up (if needed).
            Defaults to an in-memory cache.
    """

    component_type = "chat_completion_cache"
    component_provider_override = "autogen_ext.models.cache.ChatCompletionCache"
    component_config_schema = ChatCompletionCacheConfig

    def __init__(
        self,
        client: ChatCompletionClient,
        store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = None,
    ):
        self.client = client
        self.store = store or InMemoryStore[CHAT_CACHE_VALUE_TYPE]()

    def _check_cache(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool | type[BaseModel]],
        extra_create_args: Mapping[str, Any],
    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
        """
        Helper function to check the cache for a result.
        Returns a tuple of (cached_result, cache_key).
        """
        json_output_data: str | bool | None = None
        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
            json_output_data = json.dumps(json_output.model_json_schema())
        elif isinstance(json_output, bool):
            json_output_data = json_output

        data = {
            "messages": [message.model_dump() for message in messages],
            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
            "json_output": json_output_data,
            "extra_create_args": extra_create_args,
        }
        serialized_data = json.dumps(data, sort_keys=True)
        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

        cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
        if cached_result is not None:
            return cached_result, cache_key
        return None, cache_key

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """
        Cached version of ChatCompletionClient.create.
        If the result of a create call is cached, it will be returned immediately
        without invoking the underlying client.

        NOTE: cancellation_token is ignored for cached results.
        """
        cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
        if cached_result:
            assert isinstance(cached_result, CreateResult)
            cached_result.cached = True
            return cached_result

        result = await self.client.create(
            messages,
            tools=tools,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        self.store.set(cache_key, result)
        return result

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Cached version of ChatCompletionClient.create_stream.
        If the result of a create_stream call is cached, it will be returned
        without streaming from the underlying client.

        NOTE: cancellation_token is ignored for cached results.
        """

        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
            cached_result, cache_key = self._check_cache(
                messages,
                tools,
                json_output,
                extra_create_args,
            )
            if cached_result:
                assert isinstance(cached_result, list)
                for result in cached_result:
                    if isinstance(result, CreateResult):
                        result.cached = True
                    yield result
                return

            result_stream = self.client.create_stream(
                messages,
                tools=tools,
                json_output=json_output,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )
            output_results: List[Union[str, CreateResult]] = []
            self.store.set(cache_key, output_results)

            async for result in result_stream:
                output_results.append(result)
                yield result

        return _generator()

    async def close(self) -> None:
        await self.client.close()

    def actual_usage(self) -> RequestUsage:
        return self.client.actual_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        return self.client.model_info

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.remaining_tokens(messages, tools=tools)

    def total_usage(self) -> RequestUsage:
        return self.client.total_usage()

    def _to_config(self) -> ChatCompletionCacheConfig:
        return ChatCompletionCacheConfig(
            client=self.client.dump_component(),
            store=self.store.dump_component() if not isinstance(self.store, InMemoryStore) else None,
        )

    @classmethod
    def _from_config(cls, config: ChatCompletionCacheConfig) -> Self:
        client = ChatCompletionClient.load_component(config.client)
        store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = (
            CacheStore.load_component(config.store) if config.store else InMemoryStore()
        )
        return cls(client=client, store=store)
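

The sketch below is illustrative and not part of the module: it shows the streaming path with the default in-memory store, assuming an OpenAI API key is available in the environment (the model name ``gpt-4o`` mirrors the docstring example; any ``ChatCompletionClient`` would work). The first ``create_stream`` call streams from the underlying client and records the chunks; a second identical call replays them from the cache, with ``cached=True`` on the final ``CreateResult``.

import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    # Default store is InMemoryStore, so no setup is needed beyond the OpenAI client.
    cache_client = ChatCompletionCache(OpenAIChatCompletionClient(model="gpt-4o"))

    messages = [UserMessage(content="Write a haiku about caching.", source="user")]

    # First call: streams from the underlying client and records chunks under the cache key.
    async for chunk in cache_client.create_stream(messages):
        print(chunk, end="", flush=True)
    print()

    # Second identical call: replays the recorded chunks; the final CreateResult has cached=True.
    async for chunk in cache_client.create_stream(messages):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())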