Source code for autogen_ext.models.openai._openai_client

import asyncio
import inspect
import json
import logging
import math
import os
import re
import warnings
from asyncio import Task
from dataclasses import dataclass
from importlib.metadata import PackageNotFoundError, version
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Type,
    Union,
    cast,
)

import tiktoken
from autogen_core import (
    EVENT_LOGGER_NAME,
    TRACE_LOGGER_NAME,
    CancellationToken,
    Component,
    FunctionCall,
    Image,
)
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    ChatCompletionTokenLogprob,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelFamily,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    TopLogprob,
    UserMessage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionContentPartParam,
    ChatCompletionMessageParam,
    ChatCompletionRole,
    ChatCompletionToolParam,
    ParsedChatCompletion,
    ParsedChoice,
    completion_create_params,
)
from openai.types.chat.chat_completion import Choice
from openai.types.shared_params import (
    FunctionDefinition,
    FunctionParameters,
    ResponseFormatJSONObject,
    ResponseFormatText,
)
from pydantic import BaseModel, SecretStr
from typing_extensions import Self, Unpack

from .._utils.normalize_stop_reason import normalize_stop_reason
from .._utils.parse_r1_content import parse_r1_content
from . import _model_info
from ._transformation import (
    get_transformer,
)
from ._utils import assert_valid_name
from .config import (
    AzureOpenAIClientConfiguration,
    AzureOpenAIClientConfigurationConfigModel,
    OpenAIClientConfiguration,
    OpenAIClientConfigurationConfigModel,
)

logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)

create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
    ("timeout", "stream")
)
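# Note: "timeout" and "stream" are not fields of CompletionCreateParamsBase, but they are
# accepted by the SDK's create() call, so they are added to the allowed set explicitly.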
# Only single choice allowed
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
required_create_args: Set[str] = set(["model"])

USER_AGENT_HEADER_NAME = "User-Agent"

try:
    version_info = version("autogen-ext")
except PackageNotFoundError:
    version_info = "dev"
AZURE_OPENAI_USER_AGENT = f"autogen-python/{version_info}"


def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpenAI:
    # Take a copy
    copied_config = dict(config).copy()
    # Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
    azure_config = {k: v for k, v in copied_config.items() if k in aopenai_init_kwargs}

    DEFAULT_HEADERS_KEY = "default_headers"
    if DEFAULT_HEADERS_KEY not in azure_config:
        azure_config[DEFAULT_HEADERS_KEY] = {}

    azure_config[DEFAULT_HEADERS_KEY][USER_AGENT_HEADER_NAME] = (
        f"{AZURE_OPENAI_USER_AGENT} {azure_config[DEFAULT_HEADERS_KEY][USER_AGENT_HEADER_NAME]}"
        if USER_AGENT_HEADER_NAME in azure_config[DEFAULT_HEADERS_KEY]
        else AZURE_OPENAI_USER_AGENT
    )

    return AsyncAzureOpenAI(**azure_config)


def _openai_client_from_config(config: Mapping[str, Any]) -> AsyncOpenAI:
    # Shave down the config to just the OpenAI kwargs
    openai_config = {k: v for k, v in config.items() if k in openai_init_kwargs}
    return AsyncOpenAI(**openai_config)


def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    create_args = {k: v for k, v in config.items() if k in create_kwargs}
    create_args_keys = set(create_args.keys())
    if not required_create_args.issubset(create_args_keys):
        raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
    if disallowed_create_args.intersection(create_args_keys):
        raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
    return create_args


# TODO check types
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)


def type_to_role(message: LLMMessage) -> ChatCompletionRole:
    if isinstance(message, SystemMessage):
        return "system"
    elif isinstance(message, UserMessage):
        return "user"
    elif isinstance(message, AssistantMessage):
        return "assistant"
    else:
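        # Any other message type (e.g., FunctionExecutionResultMessage) maps to the "tool" role.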
        return "tool"


def to_oai_type(
    message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
) -> Sequence[ChatCompletionMessageParam]:
    context = {
        "prepend_name": prepend_name,
    }
    transformers = get_transformer("openai", model, model_family)

    def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
        raise ValueError(f"Unknown message type: {type(message)}")

    transformer: Callable[[LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam]] = transformers.get(
        type(message), raise_value_error
    )
    result = transformer(message, context)
    return result


def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
    MAX_LONG_EDGE = 2048
    BASE_TOKEN_COUNT = 85
    TOKENS_PER_TILE = 170
    MAX_SHORT_EDGE = 768
    TILE_SIZE = 512
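    # These constants mirror OpenAI's published vision token accounting: a fixed base cost plus a
    # per-512px-tile cost after scaling the image to fit 2048px (long edge) and 768px (short edge).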

    if detail == "low":
        return BASE_TOKEN_COUNT

    width, height = image.image.size

    # Scale down to fit within a MAX_LONG_EDGE x MAX_LONG_EDGE square if necessary

    if width > MAX_LONG_EDGE or height > MAX_LONG_EDGE:
        aspect_ratio = width / height
        if aspect_ratio > 1:
            # Width is greater than height
            width = MAX_LONG_EDGE
            height = int(MAX_LONG_EDGE / aspect_ratio)
        else:
            # Height is greater than or equal to width
            height = MAX_LONG_EDGE
            width = int(MAX_LONG_EDGE * aspect_ratio)

    # Resize such that the shortest side is MAX_SHORT_EDGE if both dimensions exceed MAX_SHORT_EDGE
    aspect_ratio = width / height
    if width > MAX_SHORT_EDGE and height > MAX_SHORT_EDGE:
        if aspect_ratio > 1:
            # Width is greater than height
            height = MAX_SHORT_EDGE
            width = int(MAX_SHORT_EDGE * aspect_ratio)
        else:
            # Height is greater than or equal to width
            width = MAX_SHORT_EDGE
            height = int(MAX_SHORT_EDGE / aspect_ratio)

    # Calculate the number of tiles based on TILE_SIZE

    tiles_width = math.ceil(width / TILE_SIZE)
    tiles_height = math.ceil(height / TILE_SIZE)
    total_tiles = tiles_width * tiles_height
    # Calculate the total tokens based on the number of tiles and the base token count

    total_tokens = BASE_TOKEN_COUNT + TOKENS_PER_TILE * total_tiles

    return total_tokens


def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
    return RequestUsage(
        prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
        completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
    )


def convert_tools(
    tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionToolParam]:
    result: List[ChatCompletionToolParam] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema
        else:
            assert isinstance(tool, dict)
            tool_schema = tool

        result.append(
            ChatCompletionToolParam(
                type="function",
                function=FunctionDefinition(
                    name=tool_schema["name"],
                    description=(tool_schema["description"] if "description" in tool_schema else ""),
                    parameters=(
                        cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {}
                    ),
                    strict=(tool_schema["strict"] if "strict" in tool_schema else False),
                ),
            )
        )
    # Check if all tools have valid names.
    for tool_param in result:
        assert_valid_name(tool_param["function"]["name"])
    return result


def normalize_name(name: str) -> str:
    """
    LLMs sometimes ask for function calls while ignoring their own format requirements.
    This function replaces any invalid characters with "_".

    Prefer _assert_valid_name for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def count_tokens_openai(
    messages: Sequence[LLMMessage],
    model: str,
    *,
    add_name_prefixes: bool = False,
    tools: Sequence[Tool | ToolSchema] = [],
    model_family: str = ModelFamily.UNKNOWN,
) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens_per_message = 3
    tokens_per_name = 1
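    # The per-message and per-name overheads follow the chat token-counting heuristic published in
    # the OpenAI cookbook for gpt-3.5/gpt-4 style message formatting.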
    num_tokens = 0

    # Message tokens.
    for message in messages:
        num_tokens += tokens_per_message
        oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
        for oai_message_part in oai_message:
            for key, value in oai_message_part.items():
                if value is None:
                    continue

                if isinstance(message, UserMessage) and isinstance(value, list):
                    typed_message_value = cast(List[ChatCompletionContentPartParam], value)

                    assert len(typed_message_value) == len(
                        message.content
                    ), "Mismatch in message content and typed message value"

                    # We need image properties that are only in the original message
                    for part, content_part in zip(typed_message_value, message.content, strict=False):
                        if isinstance(content_part, Image):
                            # TODO: add detail parameter
                            num_tokens += calculate_vision_tokens(content_part)
                        elif isinstance(part, str):
                            num_tokens += len(encoding.encode(part))
                        else:
                            try:
                                serialized_part = json.dumps(part)
                                num_tokens += len(encoding.encode(serialized_part))
                            except TypeError:
                                trace_logger.warning(f"Could not convert {part} to string, skipping.")
                else:
                    if not isinstance(value, str):
                        try:
                            value = json.dumps(value)
                        except TypeError:
                            trace_logger.warning(f"Could not convert {value} to string, skipping.")
                            continue
                    num_tokens += len(encoding.encode(value))
                    if key == "name":
                        num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>

    # Tool tokens.
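    # The constant adjustments below approximate how the backend serializes tool schemas into the
    # prompt; they are heuristics rather than an exact count.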
    oai_tools = convert_tools(tools)
    for tool in oai_tools:
        function = tool["function"]
        tool_tokens = len(encoding.encode(function["name"]))
        if "description" in function:
            tool_tokens += len(encoding.encode(function["description"]))
        tool_tokens -= 2
        if "parameters" in function:
            parameters = function["parameters"]
            if "properties" in parameters:
                assert isinstance(parameters["properties"], dict)
                for propertiesKey in parameters["properties"]:  # pyright: ignore
                    assert isinstance(propertiesKey, str)
                    tool_tokens += len(encoding.encode(propertiesKey))
                    v = parameters["properties"][propertiesKey]  # pyright: ignore
                    for field in v:  # pyright: ignore
                        if field == "type":
                            tool_tokens += 2
                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
                        elif field == "description":
                            tool_tokens += 2
                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
                        elif field == "enum":
                            tool_tokens -= 3
                            for o in v["enum"]:  # pyright: ignore
                                tool_tokens += 3
                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
                        else:
                            trace_logger.warning(f"Not supported field {field}")
                tool_tokens += 11
                if len(parameters["properties"]) == 0:  # pyright: ignore
                    tool_tokens -= 2
        num_tokens += tool_tokens
    num_tokens += 12
    return num_tokens


@dataclass
class CreateParams:
    messages: List[ChatCompletionMessageParam]
    tools: List[ChatCompletionToolParam]
    response_format: Optional[Type[BaseModel]]
    create_args: Dict[str, Any]



class BaseOpenAIChatCompletionClient(ChatCompletionClient):
    def __init__(
        self,
        client: Union[AsyncOpenAI, AsyncAzureOpenAI],
        *,
        create_args: Dict[str, Any],
        model_capabilities: Optional[ModelCapabilities] = None,  # type: ignore
        model_info: Optional[ModelInfo] = None,
        add_name_prefixes: bool = False,
    ):
        self._client = client
        self._add_name_prefixes = add_name_prefixes
        if model_capabilities is None and model_info is None:
            try:
                self._model_info = _model_info.get_info(create_args["model"])
            except KeyError as err:
                raise ValueError("model_info is required when model name is not a valid OpenAI model") from err
        elif model_capabilities is not None and model_info is not None:
            raise ValueError("model_capabilities and model_info are mutually exclusive")
        elif model_capabilities is not None and model_info is None:
            warnings.warn(
                "model_capabilities is deprecated, use model_info instead",
                DeprecationWarning,
                stacklevel=2,
            )
            info = cast(ModelInfo, model_capabilities)
            info["family"] = ModelFamily.UNKNOWN
            self._model_info = info
        elif model_capabilities is None and model_info is not None:
            self._model_info = model_info

        # Validate model_info, check if all required fields are present
        validate_model_info(self._model_info)

        self._resolved_model: Optional[str] = None
        if "model" in create_args:
            self._resolved_model = _model_info.resolve_model(create_args["model"])

        if (
            not self._model_info["json_output"]
            and "response_format" in create_args
            and (
                isinstance(create_args["response_format"], dict)
                and create_args["response_format"]["type"] == "json_object"
            )
        ):
            raise ValueError("Model does not support JSON output.")

        self._create_args = create_args
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
    @classmethod
    def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
        return OpenAIChatCompletionClient(**config)
    def _rstrip_last_assistant_message(self, messages: Sequence[LLMMessage]) -> Sequence[LLMMessage]:
        """
        Strip trailing whitespace from the content of the last assistant message.
        """
        # When the last message is an AssistantMessage, Claude models cannot accept trailing whitespace.
        if isinstance(messages[-1], AssistantMessage):
            if isinstance(messages[-1].content, str):
                messages[-1].content = messages[-1].content.rstrip()

        return messages

    def _process_create_args(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool | type[BaseModel]],
        extra_create_args: Mapping[str, Any],
    ) -> CreateParams:
        # Make sure all extra_create_args are valid
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        # Copy the create args and overwrite anything in extra_create_args
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)

        # The response format value to use for the beta client.
        response_format_value: Optional[Type[BaseModel]] = None

        if "response_format" in create_args:
            # Legacy support for getting beta client mode from response_format.
            value = create_args["response_format"]
            if isinstance(value, type) and issubclass(value, BaseModel):
                if self.model_info["structured_output"] is False:
                    raise ValueError("Model does not support structured output.")
                warnings.warn(
                    "Using response_format to specify the BaseModel for structured output type will be deprecated. "
                    "Use json_output in create and create_stream instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                response_format_value = value
                # Remove response_format from create_args to prevent passing it twice.
                del create_args["response_format"]
            # In all other cases when response_format is set to something else, we will
            # use the regular client.

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output.")
            if json_output is True:
                # JSON mode.
                create_args["response_format"] = ResponseFormatJSONObject(type="json_object")
            elif json_output is False:
                # Text mode.
                create_args["response_format"] = ResponseFormatText(type="text")
            elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
                if self.model_info["structured_output"] is False:
                    raise ValueError("Model does not support structured output.")
                if response_format_value is not None:
                    raise ValueError(
                        "response_format and json_output cannot be set to a Pydantic model class at the same time."
                    )
                # Beta client mode with Pydantic model class.
                response_format_value = json_output
            else:
                raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")

        if response_format_value is not None and "response_format" in create_args:
            warnings.warn(
                "response_format is found in extra_create_args while json_output is set to a Pydantic model class. "
                "Skipping the response_format in extra_create_args in favor of the json_output. "
                "Structured output will be used.",
                UserWarning,
                stacklevel=2,
            )
            # If using beta client, remove response_format from create_args to prevent passing it twice
            del create_args["response_format"]

        # TODO: allow custom handling.
        # For now we raise an error if images are present and vision is not supported
        if self.model_info["vision"] is False:
            for message in messages:
                if isinstance(message, UserMessage):
                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                        raise ValueError("Model does not support vision and image was provided")

        if self.model_info["json_output"] is False and json_output is True:
            raise ValueError("Model does not support JSON output.")

        if not self.model_info.get("multiple_system_messages", False):
            # Some models accept only one system message (or only read the last one),
            # so merge consecutive system messages into one.
            system_message_content = ""
            _messages: List[LLMMessage] = []
            _first_system_message_idx = -1
            _last_system_message_idx = -1
            # Index of the first system message for adding the merged system message at the correct position
            for idx, message in enumerate(messages):
                if isinstance(message, SystemMessage):
                    if _first_system_message_idx == -1:
                        _first_system_message_idx = idx
                    elif _last_system_message_idx + 1 != idx:
                        # In this case the system messages are not consecutive.
                        # Only consecutive system messages can be merged.
                        raise ValueError(
                            "Multiple and not consecutive system messages are not supported if model_info['multiple_system_messages'] is False"
                        )
                    system_message_content += message.content + "\n"
                    _last_system_message_idx = idx
                else:
                    _messages.append(message)
            system_message_content = system_message_content.rstrip()
            if system_message_content != "":
                system_message = SystemMessage(content=system_message_content)
                _messages.insert(_first_system_message_idx, system_message)
            messages = _messages

        # In this ad-hoc case, we use startswith instead of model_family for code consistency.
        if create_args.get("model", "unknown").startswith("claude-"):
            # When the last message is an AssistantMessage, Claude models cannot accept trailing whitespace.
            messages = self._rstrip_last_assistant_message(messages)

        oai_messages_nested = [
            to_oai_type(
                m,
                prepend_name=self._add_name_prefixes,
                model=create_args.get("model", "unknown"),
                model_family=self._model_info["family"],
            )
            for m in messages
        ]

        oai_messages = [item for sublist in oai_messages_nested for item in sublist]

        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling")

        converted_tools = convert_tools(tools)
        return CreateParams(
            messages=oai_messages,
            tools=converted_tools,
            response_format=response_format_value,
            create_args=create_args,
        )
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )
        future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
        if create_params.response_format is not None:
            # Use beta client if response_format is not None
            future = asyncio.ensure_future(
                self._client.beta.chat.completions.parse(
                    messages=create_params.messages,
                    tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
                    response_format=create_params.response_format,
                    **create_params.create_args,
                )
            )
        else:
            # Use the regular client
            future = asyncio.ensure_future(
                self._client.chat.completions.create(
                    messages=create_params.messages,
                    stream=False,
                    tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
                    **create_params.create_args,
                )
            )

        if cancellation_token is not None:
            cancellation_token.link_future(future)
        result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
        if create_params.response_format is not None:
            result = cast(ParsedChatCompletion[Any], result)

        usage = RequestUsage(
            # TODO backup token counting
            prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
            completion_tokens=(result.usage.completion_tokens if result.usage is not None else 0),
        )

        logger.info(
            LLMCallEvent(
                messages=cast(List[Dict[str, Any]], create_params.messages),
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                tools=create_params.tools,
            )
        )

        if self._resolved_model is not None:
            if self._resolved_model != result.model:
                warnings.warn(
                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
                    "Model mapping in autogen_ext.models.openai may be incorrect. "
                    f"Set the model to {result.model} to enhance token/cost estimation and suppress this warning.",
                    stacklevel=2,
                )

        # Limited to a single choice currently.
        choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]

        # Detect whether it is a function call or not.
        # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
        content: Union[str, List[FunctionCall]]
        thought: str | None = None
        if choice.message.function_call is not None:
            raise ValueError("function_call is deprecated and is not supported by this model client.")
        elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
            if choice.finish_reason != "tool_calls":
                warnings.warn(
                    f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
                    "when tool_calls are present. Finish reason may not be accurate. "
                    "This may be due to the API used that is not returning the correct finish reason.",
                    stacklevel=2,
                )
            if choice.message.content is not None and choice.message.content != "":
                # Put the content in the thought field.
                thought = choice.message.content
            # NOTE: If OAI response type changes, this will need to be updated
            content = []
            for tool_call in choice.message.tool_calls:
                if not isinstance(tool_call.function.arguments, str):
                    warnings.warn(
                        f"Tool call function arguments field is not a string: {tool_call.function.arguments}. "
                        "This is unexpected and may be due to the API not returning the correct type. "
                        "Attempting to convert it to string.",
                        stacklevel=2,
                    )
                    if isinstance(tool_call.function.arguments, dict):
                        tool_call.function.arguments = json.dumps(tool_call.function.arguments)
                content.append(
                    FunctionCall(
                        id=tool_call.id,
                        arguments=tool_call.function.arguments,
                        name=normalize_name(tool_call.function.name),
                    )
                )
            finish_reason = "tool_calls"
        else:
            # If there are no tool calls, this is a text response and we populate the content and thought fields.
            finish_reason = choice.finish_reason
            content = choice.message.content or ""
            # If there is a reasoning_content field, then we populate the thought field.
            # This is for models such as R1 served directly from the DeepSeek API.
            if choice.message.model_extra is not None:
                reasoning_content = choice.message.model_extra.get("reasoning_content")
                if reasoning_content is not None:
                    thought = reasoning_content

        logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
        if choice.logprobs and choice.logprobs.content:
            logprobs = [
                ChatCompletionTokenLogprob(
                    token=x.token,
                    logprob=x.logprob,
                    top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
                    bytes=x.bytes,
                )
                for x in choice.logprobs.content
            ]

        # This is for local R1 models.
        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
            thought, content = parse_r1_content(content)

        response = CreateResult(
            finish_reason=normalize_stop_reason(finish_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
        self._actual_usage = _add_usage(self._actual_usage, usage)

        # TODO - why is this cast needed?
        return response
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        max_consecutive_empty_chunk_tolerance: int = 0,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Create a stream of string chunks from the model, ending with a :class:`~autogen_core.models.CreateResult`.

        Extends :meth:`autogen_core.models.ChatCompletionClient.create_stream` to support the OpenAI API.

        In streaming, the default behaviour is to not return token usage counts.
        See: `OpenAI API reference for possible args <https://platform.openai.com/docs/api-reference/chat/create>`_.

        You can set `extra_create_args={"stream_options": {"include_usage": True}}` (if supported by the accessed API)
        to return a final chunk whose usage is set to a :class:`~autogen_core.models.RequestUsage` object
        containing prompt and completion token counts; all preceding chunks will have usage as `None`.
        See: `OpenAI API reference for stream options <https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options>`_.

        Other examples of supported arguments that can be included in `extra_create_args`:

            - `temperature` (float): Controls the randomness of the output. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic.
            - `max_tokens` (int): The maximum number of tokens to generate in the completion.
            - `top_p` (float): An alternative to sampling with temperature, called nucleus sampling, where the model considers the tokens with top_p probability mass.
            - `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
            - `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
        """

        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )

        if max_consecutive_empty_chunk_tolerance != 0:
            warnings.warn(
                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in future releases. All empty chunks will be skipped with a warning.",
                DeprecationWarning,
                stacklevel=2,
            )

        if create_params.response_format is not None:
            chunks = self._create_stream_chunks_beta_client(
                tool_params=create_params.tools,
                oai_messages=create_params.messages,
                response_format=create_params.response_format,
                create_args_no_response_format=create_params.create_args,
                cancellation_token=cancellation_token,
            )
        else:
            chunks = self._create_stream_chunks(
                tool_params=create_params.tools,
                oai_messages=create_params.messages,
                create_args=create_params.create_args,
                cancellation_token=cancellation_token,
            )

        # Prepare data to process streaming chunks.
        chunk: ChatCompletionChunk | None = None
        stop_reason = None
        maybe_model = None
        content_deltas: List[str] = []
        thought_deltas: List[str] = []
        full_tool_calls: Dict[int, FunctionCall] = {}
        logprobs: Optional[List[ChatCompletionTokenLogprob]] = None

        empty_chunk_warning_has_been_issued: bool = False
        empty_chunk_warning_threshold: int = 10
        empty_chunk_count = 0
        first_chunk = True
        is_reasoning = False

        # Process the stream of chunks.
        async for chunk in chunks:
            if first_chunk:
                first_chunk = False
                # Emit the start event.
                logger.info(
                    LLMStreamStartEvent(
                        messages=cast(List[Dict[str, Any]], create_params.messages),
                    )
                )

            # Set the model from the latest chunk.
            maybe_model = chunk.model

            # Empty chunks have been observed when the endpoint is under heavy load.
            # https://github.com/microsoft/autogen/issues/4213
            if len(chunk.choices) == 0:
                empty_chunk_count += 1
                if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
                    empty_chunk_warning_has_been_issued = True
                    warnings.warn(
                        f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
                        stacklevel=2,
                    )
                continue
            else:
                empty_chunk_count = 0

            if len(chunk.choices) > 1:
                # This is a multi-choice chunk, we need to warn the user.
                warnings.warn(
                    f"Received a chunk with {len(chunk.choices)} choices. Only the first choice will be used.",
                    UserWarning,
                    stacklevel=2,
                )

            # Set the choice to the first choice in the chunk.
            choice = chunk.choices[0]

            # For the LiteLLM usage chunk, apply the following hack: keep the previous chunk's
            # stop_reason (if set), i.e., set the stop_reason for the usage chunk to the prior stop_reason.
            stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
            maybe_model = chunk.model

            reasoning_content: str | None = None
            if choice.delta.model_extra is not None and "reasoning_content" in choice.delta.model_extra:
                # If there is a reasoning_content field, then we populate the thought field. This is for models such as R1.
                reasoning_content = choice.delta.model_extra.get("reasoning_content")

            if isinstance(reasoning_content, str) and len(reasoning_content) > 0:
                if not is_reasoning:
                    # Enter reasoning mode.
                    reasoning_content = "<think>" + reasoning_content
                    is_reasoning = True
                thought_deltas.append(reasoning_content)
                yield reasoning_content
            elif is_reasoning:
                # Exit reasoning mode.
                reasoning_content = "</think>"
                thought_deltas.append(reasoning_content)
                is_reasoning = False
                yield reasoning_content

            # First try to get content.
            if choice.delta.content:
                content_deltas.append(choice.delta.content)
                if len(choice.delta.content) > 0:
                    yield choice.delta.content
                # NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
                # However, this may not be the case for other APIs -- we should expect this may need to be updated.
                continue

            # Otherwise, get tool calls.
            if choice.delta.tool_calls is not None:
                for tool_call_chunk in choice.delta.tool_calls:
                    idx = tool_call_chunk.index
                    if idx not in full_tool_calls:
                        # We ignore the type hint here because we want to fill in type when the delta provides it
                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")

                    if tool_call_chunk.id is not None:
                        full_tool_calls[idx].id += tool_call_chunk.id

                    if tool_call_chunk.function is not None:
                        if tool_call_chunk.function.name is not None:
                            full_tool_calls[idx].name += tool_call_chunk.function.name
                        if tool_call_chunk.function.arguments is not None:
                            full_tool_calls[idx].arguments += tool_call_chunk.function.arguments

            if choice.logprobs and choice.logprobs.content:
                logprobs = [
                    ChatCompletionTokenLogprob(
                        token=x.token,
                        logprob=x.logprob,
                        top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
                        bytes=x.bytes,
                    )
                    for x in choice.logprobs.content
                ]

        # Finalize the CreateResult.

        # TODO: can we remove this?
        if stop_reason == "function_call":
            raise ValueError("Function calls are not supported in this context")

        # We need to get the model from the last chunk, if available.
        model = maybe_model or create_params.create_args["model"]
        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API

        # Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
        if chunk and chunk.usage:
            prompt_tokens = chunk.usage.prompt_tokens
            completion_tokens = chunk.usage.completion_tokens
        else:
            prompt_tokens = 0
            completion_tokens = 0
        usage = RequestUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )

        # Detect whether it is a function call or just text.
        content: Union[str, List[FunctionCall]]
        thought: str | None = None

        # Determine the content and thought based on what was collected.
        if full_tool_calls:
            # This is a tool call response.
            content = list(full_tool_calls.values())
            if content_deltas:
                # Store any text alongside tool calls as thoughts.
                thought = "".join(content_deltas)
        else:
            # This is a text response (possibly with thoughts).
            if content_deltas:
                content = "".join(content_deltas)
            else:
                warnings.warn(
                    "No text content or tool calls are available. Model returned empty result.",
                    stacklevel=2,
                )
                content = ""

            # Set thoughts if we have any reasoning content.
            if thought_deltas:
                thought = "".join(thought_deltas).lstrip("<think>").rstrip("</think>")

            # This is for local R1 models whose reasoning content is within the content string.
            if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
                thought, content = parse_r1_content(content)

        # Create the result.
        result = CreateResult(
            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        # Log the end of the stream.
        logger.info(
            LLMStreamEndEvent(
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
            )
        )

        # Update the total usage.
        self._total_usage = _add_usage(self._total_usage, usage)
        self._actual_usage = _add_usage(self._actual_usage, usage)

        # Yield the CreateResult.
        yield result
    async def _create_stream_chunks(
        self,
        tool_params: List[ChatCompletionToolParam],
        oai_messages: List[ChatCompletionMessageParam],
        create_args: Dict[str, Any],
        cancellation_token: Optional[CancellationToken],
    ) -> AsyncGenerator[ChatCompletionChunk, None]:
        stream_future = asyncio.ensure_future(
            self._client.chat.completions.create(
                messages=oai_messages,
                stream=True,
                tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
                **create_args,
            )
        )
        if cancellation_token is not None:
            cancellation_token.link_future(stream_future)
        stream = await stream_future
        while True:
            try:
                chunk_future = asyncio.ensure_future(anext(stream))
                if cancellation_token is not None:
                    cancellation_token.link_future(chunk_future)
                chunk = await chunk_future
                yield chunk
            except StopAsyncIteration:
                break

    async def _create_stream_chunks_beta_client(
        self,
        tool_params: List[ChatCompletionToolParam],
        oai_messages: List[ChatCompletionMessageParam],
        create_args_no_response_format: Dict[str, Any],
        response_format: Optional[Type[BaseModel]],
        cancellation_token: Optional[CancellationToken],
    ) -> AsyncGenerator[ChatCompletionChunk, None]:
        async with self._client.beta.chat.completions.stream(
            messages=oai_messages,
            tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
            response_format=(response_format if response_format is not None else NOT_GIVEN),
            **create_args_no_response_format,
        ) as stream:
            while True:
                try:
                    event_future = asyncio.ensure_future(anext(stream))
                    if cancellation_token is not None:
                        cancellation_token.link_future(event_future)
                    event = await event_future

                    if event.type == "chunk":
                        chunk = event.chunk
                        yield chunk
                    # We don't handle other event types from the beta client stream,
                    # as the other event types are auxiliary to the chunk event.
                    # See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
                    # Once the beta client is stable, we can move all the logic to the beta client.
                    # Then we can consider handling other event types which may simplify the code overall.
                except StopAsyncIteration:
                    break
    async def close(self) -> None:
        await self._client.close()

    def actual_usage(self) -> RequestUsage:
        return self._actual_usage

    def total_usage(self) -> RequestUsage:
        return self._total_usage

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return count_tokens_openai(
            messages,
            self._create_args["model"],
            add_name_prefixes=self._add_name_prefixes,
            tools=tools,
            model_family=self._model_info["family"],
        )

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        token_limit = _model_info.get_token_limit(self._create_args["model"])
        return token_limit - self.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn(
            "capabilities is deprecated, use model_info instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._model_info

    @property
    def model_info(self) -> ModelInfo:
        return self._model_info

class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]):
    """Chat completion client for OpenAI hosted models.

    To use this client, you must install the `openai` extra:

    .. code-block:: bash

        pip install "autogen-ext[openai]"

    You can also use this client for OpenAI-compatible ChatCompletion endpoints.
    **Using this client for non-OpenAI models is not tested or guaranteed.**

    For non-OpenAI models, please first take a look at our `community extensions <https://microsoft.github.io/autogen/dev/user-guide/extensions-user-guide/index.html>`_
    for additional model clients.

    Args:
        model (str): Which OpenAI model to use.
        api_key (optional, str): The API key to use. **Required if 'OPENAI_API_KEY' is not found in the environment variables.**
        organization (optional, str): The organization ID to use.
        base_url (optional, str): The base URL to use. **Required if the model is not hosted on OpenAI.**
        timeout: (optional, float): The timeout for the request in seconds.
        max_retries (optional, int): The maximum number of retries to attempt.
        model_info (optional, ModelInfo): The capabilities of the model. **Required if the model name is not a valid OpenAI model.**
        frequency_penalty (optional, float):
        logit_bias: (optional, dict[str, int]):
        max_tokens (optional, int):
        n (optional, int):
        presence_penalty (optional, float):
        response_format (optional, Dict[str, Any]): The format of the response. Possible options are:

            .. code-block:: text

                # Text response, this is the default.
                {"type": "text"}

            .. code-block:: text

                # JSON response, make sure to instruct the model to return JSON.
                {"type": "json_object"}

            .. code-block:: text

                # Structured output response, with a pre-defined JSON schema.
                {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "name of the schema, must be an identifier.",
                        "description": "description for the model.",
                        # You can convert a Pydantic (v2) model to JSON schema
                        # using the `model_json_schema()` method.
                        "schema": "<the JSON schema itself>",
                        # Whether to enable strict schema adherence when
                        # generating the output. If set to true, the model will
                        # always follow the exact schema defined in the
                        # `schema` field. Only a subset of JSON Schema is
                        # supported. See
                        # https://platform.openai.com/docs/guides/structured-outputs
                        # for details.
                        "strict": False,  # or True
                    },
                }

            For structured output, it is recommended to use the `json_output` parameter of the
            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create` or
            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create_stream`
            methods instead of `response_format`. The `json_output` parameter is more flexible
            and lets you specify a Pydantic model class directly.

        seed (optional, int):
        stop (optional, str | List[str]):
        temperature (optional, float):
        top_p (optional, float):
        user (optional, str):
        default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.
        add_name_prefixes (optional, bool): Whether to prepend the `source` value to each
            :class:`~autogen_core.models.UserMessage` content. E.g., "this is content" becomes
            "Reviewer said: this is content." This can be useful for models that do not support
            the `name` field in messages. Defaults to False.
        stream_options (optional, dict): Additional options for streaming. Currently only `include_usage` is supported.

    Examples:

        The following code snippet shows how to use the client with an OpenAI model:

        .. code-block:: python

            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_core.models import UserMessage

            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-2024-08-06",
                # api_key="sk-...", # Optional if the OPENAI_API_KEY environment variable is set.
            )

            result = await openai_client.create([UserMessage(content="What is the capital of France?", source="user")])  # type: ignore
            print(result)

            # Close the client when done.
            # await openai_client.close()

        To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
        For example, to use Ollama, you can use the following code snippet:

        .. code-block:: python

            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_core.models import ModelFamily

            custom_model_client = OpenAIChatCompletionClient(
                model="deepseek-r1:1.5b",
                base_url="http://localhost:11434/v1",
                api_key="placeholder",
                model_info={
                    "vision": False,
                    "function_calling": False,
                    "json_output": False,
                    "family": ModelFamily.R1,
                    "structured_output": True,
                },
            )

            # Close the client when done.
            # await custom_model_client.close()

        To use streaming mode, you can use the following code snippet:

        .. code-block:: python

            import asyncio
            from autogen_core.models import UserMessage
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            async def main() -> None:
                # Similar usage applies to AzureOpenAIChatCompletionClient.
                model_client = OpenAIChatCompletionClient(model="gpt-4o")  # assuming OPENAI_API_KEY is set in the environment.

                messages = [UserMessage(content="Write a very short story about a dragon.", source="user")]

                # Create a stream.
                stream = model_client.create_stream(messages=messages)

                # Iterate over the stream and print the responses.
                print("Streamed responses:")
                async for response in stream:
                    if isinstance(response, str):
                        # A partial response is a string.
                        print(response, flush=True, end="")
                    else:
                        # The last response is a CreateResult object with the complete message.
                        print("\\n\\n------------\\n")
                        print("The complete response:", flush=True)
                        print(response.content, flush=True)

                # Close the client when done.
                await model_client.close()


            asyncio.run(main())

        To use structured output as well as function calling, you can use the following code snippet:

        .. code-block:: python

            import asyncio
            from typing import Literal

            from autogen_core.models import (
                AssistantMessage,
                FunctionExecutionResult,
                FunctionExecutionResultMessage,
                SystemMessage,
                UserMessage,
            )
            from autogen_core.tools import FunctionTool
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from pydantic import BaseModel


            # Define the structured output format.
            class AgentResponse(BaseModel):
                thoughts: str
                response: Literal["happy", "sad", "neutral"]


            # Define the function to be called as a tool.
            def sentiment_analysis(text: str) -> str:
                \"\"\"Given a text, return the sentiment.\"\"\"
                return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


            # Create a FunctionTool instance. `strict=True` is required for structured output mode.
            tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)


            async def main() -> None:
                # Create an OpenAIChatCompletionClient instance.
                model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")

                # Generate a response using the tool.
                response1 = await model_client.create(
                    messages=[
                        SystemMessage(content="Analyze input text sentiment using the tool provided."),
                        UserMessage(content="I am happy.", source="user"),
                    ],
                    tools=[tool],
                )
                print(response1.content)
                # Should be a list of tool calls.
                # [FunctionCall(name="sentiment_analysis", arguments={"text": "I am happy."}, ...)]

                assert isinstance(response1.content, list)
                response2 = await model_client.create(
                    messages=[
                        SystemMessage(content="Analyze input text sentiment using the tool provided."),
                        UserMessage(content="I am happy.", source="user"),
                        AssistantMessage(content=response1.content, source="assistant"),
                        FunctionExecutionResultMessage(
                            content=[
                                FunctionExecutionResult(
                                    content="happy",
                                    call_id=response1.content[0].id,
                                    is_error=False,
                                    name="sentiment_analysis",
                                )
                            ]
                        ),
                    ],
                    # Use the structured output format.
                    json_output=AgentResponse,
                )
                print(response2.content)
                # Should be a structured output.
                # {"thoughts": "The user is happy.", "response": "happy"}

                # Close the client when done.
                await model_client.close()


            asyncio.run(main())

        To load the client from a configuration, you can use the following code snippet:

        .. code-block:: python

            from autogen_core.models import ChatCompletionClient

            config = {
                "provider": "OpenAIChatCompletionClient",
                "config": {"model": "gpt-4o", "api_key": "REPLACE_WITH_YOUR_API_KEY"},
            }

            client = ChatCompletionClient.load_component(config)

        To view the full list of available configuration options, see the :py:class:`OpenAIClientConfigurationConfigModel` class.
    """

    component_type = "model"
    component_config_schema = OpenAIClientConfigurationConfigModel
    component_provider_override = "autogen_ext.models.openai.OpenAIChatCompletionClient"

    def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
        if "model" not in kwargs:
            raise ValueError("model is required for OpenAIChatCompletionClient")

        model_capabilities: Optional[ModelCapabilities] = None  # type: ignore
        self._raw_config: Dict[str, Any] = dict(kwargs).copy()
        copied_args = dict(kwargs).copy()
        if "model_capabilities" in kwargs:
            model_capabilities = kwargs["model_capabilities"]
            del copied_args["model_capabilities"]

        model_info: Optional[ModelInfo] = None
        if "model_info" in kwargs:
            model_info = kwargs["model_info"]
            del copied_args["model_info"]

        add_name_prefixes: bool = False
        if "add_name_prefixes" in kwargs:
            add_name_prefixes = kwargs["add_name_prefixes"]

        # Special handling for Gemini model.
        assert "model" in copied_args and isinstance(copied_args["model"], str)
        if copied_args["model"].startswith("gemini-"):
            if "base_url" not in copied_args:
                copied_args["base_url"] = _model_info.GEMINI_OPENAI_BASE_URL
            if "api_key" not in copied_args and "GEMINI_API_KEY" in os.environ:
                copied_args["api_key"] = os.environ["GEMINI_API_KEY"]
        if copied_args["model"].startswith("claude-"):
            if "base_url" not in copied_args:
                copied_args["base_url"] = _model_info.ANTHROPIC_OPENAI_BASE_URL
            if "api_key" not in copied_args and "ANTHROPIC_API_KEY" in os.environ:
                copied_args["api_key"] = os.environ["ANTHROPIC_API_KEY"]
        if copied_args["model"].startswith("Llama-"):
            if "base_url" not in copied_args:
                copied_args["base_url"] = _model_info.LLAMA_API_BASE_URL
            if "api_key" not in copied_args and "LLAMA_API_KEY" in os.environ:
                copied_args["api_key"] = os.environ["LLAMA_API_KEY"]

        client = _openai_client_from_config(copied_args)
        create_args = _create_args_from_config(copied_args)

        super().__init__(
            client=client,
            create_args=create_args,
            model_capabilities=model_capabilities,
            model_info=model_info,
            add_name_prefixes=add_name_prefixes,
        )

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._client = _openai_client_from_config(state["_raw_config"])
    def _to_config(self) -> OpenAIClientConfigurationConfigModel:
        copied_config = self._raw_config.copy()
        return OpenAIClientConfigurationConfigModel(**copied_config)

    @classmethod
    def _from_config(cls, config: OpenAIClientConfigurationConfigModel) -> Self:
        copied_config = config.model_copy().model_dump(exclude_none=True)

        # Handle api_key as SecretStr
        if "api_key" in copied_config and isinstance(config.api_key, SecretStr):
            copied_config["api_key"] = config.api_key.get_secret_value()

        return cls(**copied_config)

class AzureOpenAIChatCompletionClient(
    BaseOpenAIChatCompletionClient, Component[AzureOpenAIClientConfigurationConfigModel]
):
    """Chat completion client for Azure OpenAI hosted models.

    To use this client, you must install the `azure` and `openai` extras:

    .. code-block:: bash

        pip install "autogen-ext[openai,azure]"

    Args:
        model (str): Which OpenAI model to use.
        azure_endpoint (str): The endpoint for the Azure model. **Required for Azure models.**
        azure_deployment (str): Deployment name for the Azure model. **Required for Azure models.**
        api_version (str): The API version to use. **Required for Azure models.**
        azure_ad_token (str): The Azure AD token to use. Provide this or `azure_ad_token_provider` for token-based authentication.
        azure_ad_token_provider (optional, Callable[[], Awaitable[str]] | AzureTokenProvider): The Azure AD token provider to use. Provide this or `azure_ad_token` for token-based authentication.
        api_key (optional, str): The API key to use when using key-based authentication. Optional if you are using Azure AD token-based authentication or the `AZURE_OPENAI_API_KEY` environment variable is set.
        timeout: (optional, float): The timeout for the request in seconds.
        max_retries (optional, int): The maximum number of retries to attempt.
        model_info (optional, ModelInfo): The capabilities of the model. **Required if the model name is not a valid OpenAI model.**
        frequency_penalty (optional, float):
        logit_bias: (optional, dict[str, int]):
        max_tokens (optional, int):
        n (optional, int):
        presence_penalty (optional, float):
        response_format (optional, Dict[str, Any]): The format of the response. Possible options are:

            .. code-block:: text

                # Text response, this is the default.
                {"type": "text"}

            .. code-block:: text

                # JSON response, make sure to instruct the model to return JSON.
                {"type": "json_object"}

            .. code-block:: text

                # Structured output response, with a pre-defined JSON schema.
                {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "name of the schema, must be an identifier.",
                        "description": "description for the model.",
                        # You can convert a Pydantic (v2) model to JSON schema
                        # using the `model_json_schema()` method.
                        "schema": "<the JSON schema itself>",
                        # Whether to enable strict schema adherence when
                        # generating the output. If set to true, the model will
                        # always follow the exact schema defined in the
                        # `schema` field. Only a subset of JSON Schema is
                        # supported. See
                        # https://platform.openai.com/docs/guides/structured-outputs
                        # for details.
                        "strict": False,  # or True
                    },
                }

            For structured output, it is recommended to use the `json_output` parameter of the
            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create` or
            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create_stream`
            methods instead of `response_format`. The `json_output` parameter is more flexible
            and lets you specify a Pydantic model class directly.

        seed (optional, int):
        stop (optional, str | List[str]):
        temperature (optional, float):
        top_p (optional, float):
        user (optional, str):
        default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.

    To use this client, you must provide your deployment name, Azure Cognitive Services endpoint, and API version.
    For authentication, you can either provide an API key or an Azure Active Directory (AAD) token credential.

    The following code snippet shows how to use AAD authentication.
    The identity used must be assigned the `Cognitive Services OpenAI User <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/role-based-access-control#cognitive-services-openai-user>`_ role.

    .. code-block:: python

        from autogen_ext.auth.azure import AzureTokenProvider
        from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
        from azure.identity import DefaultAzureCredential

        # Create the token provider
        token_provider = AzureTokenProvider(
            DefaultAzureCredential(),
            "https://cognitiveservices.azure.com/.default",
        )

        az_model_client = AzureOpenAIChatCompletionClient(
            azure_deployment="{your-azure-deployment}",
            model="{model-name, such as gpt-4o}",
            api_version="2024-06-01",
            azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
            azure_ad_token_provider=token_provider,  # Optional if you choose key-based authentication.
            # api_key="sk-...", # For key-based authentication.
        )

    See the :class:`OpenAIChatCompletionClient` class for other usage examples.

    To load a client that uses identity-based authentication from a configuration, you can use the following code snippet:

    .. code-block:: python

        from autogen_core.models import ChatCompletionClient

        config = {
            "provider": "AzureOpenAIChatCompletionClient",
            "config": {
                "model": "gpt-4o-2024-05-13",
                "azure_endpoint": "https://{your-custom-endpoint}.openai.azure.com/",
                "azure_deployment": "{your-azure-deployment}",
                "api_version": "2024-06-01",
                "azure_ad_token_provider": {
                    "provider": "autogen_ext.auth.azure.AzureTokenProvider",
                    "config": {
                        "provider_kind": "DefaultAzureCredential",
                        "scopes": ["https://cognitiveservices.azure.com/.default"],
                    },
                },
            },
        }

        client = ChatCompletionClient.load_component(config)

    To view the full list of available configuration options, see the :py:class:`AzureOpenAIClientConfigurationConfigModel` class.

    .. note::

        Right now only `DefaultAzureCredential` is supported, and no additional arguments can be passed to it.

    .. note::

        The Azure OpenAI client by default sets the User-Agent header to `autogen-python/{version}`.
        To override this, set the `autogen_ext.models.openai.AZURE_OPENAI_USER_AGENT` variable to an empty string.

    To use the Azure client directly or for more information, see `here <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/managed-identity#chat-completions>`_.
    """

    component_type = "model"
    component_config_schema = AzureOpenAIClientConfigurationConfigModel
    component_provider_override = "autogen_ext.models.openai.AzureOpenAIChatCompletionClient"

    def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
        model_capabilities: Optional[ModelCapabilities] = None  # type: ignore
        copied_args = dict(kwargs).copy()
        if "model_capabilities" in kwargs:
            model_capabilities = kwargs["model_capabilities"]
            del copied_args["model_capabilities"]

        model_info: Optional[ModelInfo] = None
        if "model_info" in kwargs:
            model_info = kwargs["model_info"]
            del copied_args["model_info"]

        add_name_prefixes: bool = False
        if "add_name_prefixes" in kwargs:
            add_name_prefixes = kwargs["add_name_prefixes"]

        client = _azure_openai_client_from_config(copied_args)
        create_args = _create_args_from_config(copied_args)
        self._raw_config: Dict[str, Any] = copied_args
        super().__init__(
            client=client,
            create_args=create_args,
            model_capabilities=model_capabilities,
            model_info=model_info,
            add_name_prefixes=add_name_prefixes,
        )

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._client = _azure_openai_client_from_config(state["_raw_config"])
    def _to_config(self) -> AzureOpenAIClientConfigurationConfigModel:
        from ...auth.azure import AzureTokenProvider

        copied_config = self._raw_config.copy()
        if "azure_ad_token_provider" in copied_config:
            if not isinstance(copied_config["azure_ad_token_provider"], AzureTokenProvider):
                raise ValueError("azure_ad_token_provider must be a AzureTokenProvider to be component serialized")

            copied_config["azure_ad_token_provider"] = (
                copied_config["azure_ad_token_provider"].dump_component().model_dump(exclude_none=True)
            )

        return AzureOpenAIClientConfigurationConfigModel(**copied_config)

    @classmethod
    def _from_config(cls, config: AzureOpenAIClientConfigurationConfigModel) -> Self:
        from ...auth.azure import AzureTokenProvider

        copied_config = config.model_copy().model_dump(exclude_none=True)

        # Handle api_key as SecretStr
        if "api_key" in copied_config and isinstance(config.api_key, SecretStr):
            copied_config["api_key"] = config.api_key.get_secret_value()

        if "azure_ad_token_provider" in copied_config:
            copied_config["azure_ad_token_provider"] = AzureTokenProvider.load_component(
                copied_config["azure_ad_token_provider"]
            )

        return cls(**copied_config)