from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from typing import Literal, Mapping, Optional, Sequence, TypeAlias
from pydantic import BaseModel
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated
from .. import CancellationToken
from .._component_config import ComponentBase
from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage
[文档]
class ModelFamily:
"""模型家族(model family)是指从能力角度具有相似特征的一组模型。这与具体的支持特性(如视觉识别、函数调用和JSON输出等)是不同的。
此命名空间类包含AutoGen所识别的模型家族常量。当然还存在其他模型家族,可以用字符串表示,但AutoGen会将其视为未知类型。"""
GPT_41 = "gpt-41"
GPT_45 = "gpt-45"
GPT_4O = "gpt-4o"
O1 = "o1"
O3 = "o3"
O4 = "o4"
GPT_4 = "gpt-4"
GPT_35 = "gpt-35"
R1 = "r1"
GEMINI_1_5_FLASH = "gemini-1.5-flash"
GEMINI_1_5_PRO = "gemini-1.5-pro"
GEMINI_2_0_FLASH = "gemini-2.0-flash"
GEMINI_2_5_PRO = "gemini-2.5-pro"
GEMINI_2_5_FLASH = "gemini-2.5-flash"
CLAUDE_3_HAIKU = "claude-3-haiku"
CLAUDE_3_SONNET = "claude-3-sonnet"
CLAUDE_3_OPUS = "claude-3-opus"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
LLAMA_3_3_8B = "llama-3.3-8b"
LLAMA_3_3_70B = "llama-3.3-70b"
LLAMA_4_SCOUT = "llama-4-scout"
LLAMA_4_MAVERICK = "llama-4-maverick"
CODESRAL = "codestral"
OPEN_CODESRAL_MAMBA = "open-codestral-mamba"
MISTRAL = "mistral"
MINISTRAL = "ministral"
PIXTRAL = "pixtral"
UNKNOWN = "unknown"
ANY: TypeAlias = Literal[
# openai_models
"gpt-41",
"gpt-45",
"gpt-4o",
"o1",
"o3",
"o4",
"gpt-4",
"gpt-35",
"r1",
# google_models
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.5-pro",
"gemini-2.5-flash"
# anthropic_models
"claude-3-haiku",
"claude-3-sonnet",
"claude-3-opus",
"claude-3-5-haiku",
"claude-3-5-sonnet",
"claude-3-7-sonnet",
# llama_models
"llama-3.3-8b",
"llama-3.3-70b",
"llama-4-scout",
"llama-4-maverick",
# mistral_models
"codestral",
"open-codestral-mamba",
"mistral",
"ministral",
"pixtral",
# unknown
"unknown",
]
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
[文档]
@staticmethod
def is_claude(family: str) -> bool:
return family in (
ModelFamily.CLAUDE_3_HAIKU,
ModelFamily.CLAUDE_3_SONNET,
ModelFamily.CLAUDE_3_OPUS,
ModelFamily.CLAUDE_3_5_HAIKU,
ModelFamily.CLAUDE_3_5_SONNET,
ModelFamily.CLAUDE_3_7_SONNET,
)
[文档]
@staticmethod
def is_gemini(family: str) -> bool:
return family in (
ModelFamily.GEMINI_1_5_FLASH,
ModelFamily.GEMINI_1_5_PRO,
ModelFamily.GEMINI_2_0_FLASH,
ModelFamily.GEMINI_2_5_PRO,
ModelFamily.GEMINI_2_5_FLASH,
)
[文档]
@staticmethod
def is_openai(family: str) -> bool:
return family in (
ModelFamily.GPT_45,
ModelFamily.GPT_41,
ModelFamily.GPT_4O,
ModelFamily.O1,
ModelFamily.O3,
ModelFamily.O4,
ModelFamily.GPT_4,
ModelFamily.GPT_35,
)
[文档]
@staticmethod
def is_llama(family: str) -> bool:
return family in (
ModelFamily.LLAMA_3_3_8B,
ModelFamily.LLAMA_3_3_70B,
ModelFamily.LLAMA_4_SCOUT,
ModelFamily.LLAMA_4_MAVERICK,
)
[文档]
@staticmethod
def is_mistral(family: str) -> bool:
return family in (
ModelFamily.CODESRAL,
ModelFamily.OPEN_CODESRAL_MAMBA,
ModelFamily.MISTRAL,
ModelFamily.MINISTRAL,
ModelFamily.PIXTRAL,
)
@deprecated("Use the ModelInfo class instead of ModelCapabilities.")
class ModelCapabilities(TypedDict, total=False):
    """Deprecated set of model capability flags.

    Superseded by :py:class:`ModelInfo`, which carries the same three flags plus
    ``family``, ``structured_output`` and ``multiple_system_messages``.
    """

    # True if the model supports vision (image input), otherwise False.
    vision: Required[bool]
    # True if the model supports function calling, otherwise False.
    function_calling: Required[bool]
    # True if the model supports JSON output, otherwise False.
    json_output: Required[bool]
[文档]
class ModelInfo(TypedDict, total=False):
"""ModelInfo是一个包含模型属性信息的字典。
预期用于模型客户端的model_info属性中。
随着我们添加更多功能,预计这个结构会不断扩展。
"""
vision: Required[bool]
"""如果模型支持视觉功能(即图像输入)则为True,否则为False。"""
function_calling: Required[bool]
"""如果模型支持函数调用则为 True,否则为 False。"""
json_output: Required[bool]
"""如果模型支持 json 输出则为 True,否则为 False。注意:这与结构化 json 不同。"""
family: Required[ModelFamily.ANY | str]
"""模型家族应为 :py:class:`ModelFamily` 中的常量之一,或表示未知模型家族的字符串。"""
structured_output: Required[bool]
"""如果模型支持结构化输出则为 True,否则为 False。这与 json_output 不同。"""
multiple_system_messages: Optional[bool]
"""如果模型支持多个非连续的系统消息则为 True,否则为 False。"""
def validate_model_info(model_info: ModelInfo) -> None:
    """Check that a model info dictionary carries the expected fields.

    Args:
        model_info: The dictionary to validate.

    Raises:
        ValueError: If any of ``vision``, ``function_calling``, ``json_output`` or
            ``family`` is absent (only the first missing one is reported).

    A missing ``structured_output`` key does not raise; it emits a ``UserWarning``
    attributed to the caller instead.
    """
    first_missing = next(
        (key for key in ("vision", "function_calling", "json_output", "family") if key not in model_info),
        None,
    )
    if first_missing is not None:
        raise ValueError(
            f"Missing required field '{first_missing}' in ModelInfo. "
            "Starting in v0.4.7, the required fields are enforced."
        )

    # Fields that will become mandatory later: for now only warn about them.
    for key in ("structured_output",):
        if key in model_info:
            continue
        warnings.warn(
            f"Missing required field '{key}' in ModelInfo. "
            "This field will be required in a future version of AutoGen.",
            UserWarning,
            stacklevel=2,
        )
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
    """Abstract interface implemented by chat-completion model clients.

    Caching has to be handled internally by implementations, as it can depend on
    the create args that were stored in the constructor.
    """

    @abstractmethod
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """Create a single response from the model.

        Args:
            messages (Sequence[LLMMessage]): The messages to send to the model.
            tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
                Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
                it will be used as the output type for structured output.
                If set to a boolean, it will be used to determine whether to use JSON mode or not.
                If set to `True`, make sure to instruct the model to produce JSON output in the instruction or prompt.
            extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
            cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.

        Returns:
            CreateResult: The result of the model call.
        """
        ...

    @abstractmethod
    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Create a stream of string chunks from the model, ending with a CreateResult.

        Args:
            messages (Sequence[LLMMessage]): The messages to send to the model.
            tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
                Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
                it will be used as the output type for structured output.
                If set to a boolean, it will be used to determine whether to use JSON mode or not.
                If set to `True`, make sure to instruct the model to produce JSON output in the instruction or prompt.
            extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
            cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.

        Returns:
            AsyncGenerator[Union[str, CreateResult], None]: A generator that yields string chunks and ends with a :py:class:`CreateResult`.
        """
        ...

    @abstractmethod
    async def close(self) -> None:
        """Close the client. What gets released is implementation-defined."""
        ...

    @abstractmethod
    def actual_usage(self) -> RequestUsage:
        """Return a :py:class:`RequestUsage`; the exact accounting is implementation-defined (TODO: confirm how it differs from ``total_usage``)."""
        ...

    @abstractmethod
    def total_usage(self) -> RequestUsage:
        """Return a :py:class:`RequestUsage`; presumably cumulative usage over the client's lifetime — confirm per implementation."""
        ...

    @abstractmethod
    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return a token count for ``messages`` (and ``tools``) as computed by the implementation."""
        ...

    @abstractmethod
    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return an int — presumably the tokens left in the model's context after ``messages``/``tools``; confirm per implementation."""
        ...

    # Deprecated
    @property
    @abstractmethod
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        """Deprecated: implement :py:attr:`model_info` instead."""
        ...

    @property
    @abstractmethod
    def model_info(self) -> ModelInfo:
        """Properties of the model served by this client, as a :py:class:`ModelInfo`.

        Subclasses must override this property. The base body is a fallback for
        clients that only implement the deprecated ``capabilities`` property: it
        warns, then returns a copy of ``capabilities`` with ``family`` set to
        ``ModelFamily.UNKNOWN``.
        """
        warnings.warn(
            "Model client in use does not implement model_info property. Falling back to capabilities property. The capabilities property is deprecated and will be removed soon, please implement model_info instead in the model client class.",
            stacklevel=2,
        )
        # Copy instead of mutating in place: ``capabilities`` may hand back a dict
        # the client keeps (e.g. a stored or shared constant), which must not
        # silently gain a "family" key.
        base_info: ModelInfo = dict(self.capabilities)  # type: ignore
        base_info["family"] = ModelFamily.UNKNOWN
        return base_info