autogen_ext.tools.azure._ai_search 源代码

from __future__ import annotations

import asyncio
import logging
import time
from abc import ABC, abstractmethod
from contextvars import ContextVar
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    Union,
)

from autogen_core import CancellationToken, Component
from autogen_core.tools import BaseTool, ToolSchema
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.search.documents.aio import SearchClient
from pydantic import BaseModel, Field

from ._config import (
    DEFAULT_API_VERSION,
    AzureAISearchConfig,
)

# Aliases for the string-keyed dictionaries used in this module: a raw document
# yielded by the search client, and the metadata/content halves stored on a
# SearchResult.
SearchDocument = Dict[str, Any]
MetadataDict = Dict[str, Any]
ContentDict = Dict[str, Any]

if TYPE_CHECKING:
    # Imported for static type checkers only; this branch never runs at runtime.
    from azure.search.documents.aio import AsyncSearchItemPaged

    SearchResultsIterable = AsyncSearchItemPaged[SearchDocument]
else:
    # At runtime the alias is simply Any.
    SearchResultsIterable = Any

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Lets type checkers resolve the vector query models even though the
    # runtime import below is allowed to fail.
    from azure.search.documents.models import (
        VectorizableTextQuery,
        VectorizedQuery,
        VectorQuery,
    )

# Import the vector query models at runtime and record whether they are
# available; `has_azure_search` is checked when a tool instance is constructed.
try:
    from azure.search.documents.models import VectorizableTextQuery, VectorizedQuery, VectorQuery

    has_azure_search = True
except ImportError:
    has_azure_search = False
    logger.error(
        "The 'azure-search-documents' package is required for this tool but was not found. "
        "Please install it with: uv add azure-search-documents"
    )


if TYPE_CHECKING:
    # `Protocol` is already imported at the top of the module, so no local
    # re-import is needed here.

    class SearchClientProtocol(Protocol):
        """Structural type for the two async search-client methods: search and close."""

        async def search(self, **kwargs: Any) -> SearchResultsIterable: ...
        async def close(self) -> None: ...
else:
    # At runtime the protocol is not defined; the name is just an alias for Any.
    SearchClientProtocol = Any

# Public names exported by this module. The three vector query models are
# re-exported from azure.search.documents.models.
__all__ = [
    "AzureAISearchTool",
    "BaseAzureAISearchTool",
    "SearchQuery",
    "SearchResults",
    "SearchResult",
    "VectorizableTextQuery",
    "VectorizedQuery",
    "VectorQuery",
]


class SearchQuery(BaseModel):
    """Search query parameters.

    This simplified interface only requires a search query string.
    All other parameters (top, filters, vector fields, etc.) are specified when
    the tool is created rather than at query time, which makes it easier for
    language models to generate structured output.

    Args:
        query (str): The search query text.
    """

    query: str = Field(description="Search query text")
class SearchResult(BaseModel):
    """A single search result.

    Args:
        score (float): The search score.
        content (ContentDict): The document content.
        metadata (MetadataDict): Additional metadata about the document.
    """

    score: float = Field(description="The search score")
    content: ContentDict = Field(description="The document content")
    metadata: MetadataDict = Field(description="Additional metadata about the document")
class SearchResults(BaseModel):
    """Container for search results.

    Args:
        results (List[SearchResult]): List of search results.
    """

    results: List[SearchResult] = Field(description="List of search results")
class EmbeddingProvider(Protocol):
    """Protocol defining the interface for embedding generation."""

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text."""
        ...


class EmbeddingProviderMixin:
    """Mixin that provides client-side embedding generation.

    The host class must expose a ``search_config`` attribute. The provider,
    model and OpenAI connection settings are read from it with ``getattr``.
    """

    search_config: AzureAISearchConfig

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query`` as returned by the configured provider.

        Raises:
            ValueError: If the host has no ``search_config``, client-side
                embedding is not configured, the Azure OpenAI endpoint is
                missing, the provider is unsupported, or the embedding request
                fails.
            ImportError: If the SDK needed by the selected provider is not
                installed.
        """
        if not hasattr(self, "search_config"):
            raise ValueError("Host class must have a search_config attribute")

        search_config = self.search_config
        embedding_provider = getattr(search_config, "embedding_provider", None)
        embedding_model = getattr(search_config, "embedding_model", None)

        if not embedding_provider or not embedding_model:
            raise ValueError(
                "Client-side embedding is not configured. `embedding_provider` and `embedding_model` must be set."
            ) from None

        provider = embedding_provider.lower()

        if provider == "azure_openai":
            try:
                from azure.identity import DefaultAzureCredential
                from openai import AsyncAzureOpenAI
            except ImportError:
                raise ImportError(
                    "Azure OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai azure-identity"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            api_version = getattr(search_config, "openai_api_version", "2023-11-01")
            endpoint = getattr(search_config, "openai_endpoint", None)

            if not endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint (`openai_endpoint`) must be provided for client-side Azure OpenAI embeddings."
                ) from None

            if api_key:
                azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
            else:
                # No API key: authenticate with a token obtained through
                # DefaultAzureCredential instead.
                def get_token() -> str:
                    credential = DefaultAzureCredential()
                    token = credential.get_token("https://cognitiveservices.azure.com/.default")
                    if not token or not token.token:
                        raise ValueError("Failed to acquire token using DefaultAzureCredential for Azure OpenAI.")
                    return token.token

                azure_client = AsyncAzureOpenAI(
                    azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
                )

            try:
                # The client is created for this single request, so close it
                # afterwards instead of leaving its HTTP connections open.
                async with azure_client:
                    response = await azure_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with Azure OpenAI: {str(e)}") from e

        elif provider == "openai":
            try:
                from openai import AsyncOpenAI
            except ImportError:
                raise ImportError(
                    "OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            openai_client = AsyncOpenAI(api_key=api_key)

            try:
                # Same per-call client lifetime as the Azure branch above.
                async with openai_client:
                    response = await openai_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with OpenAI: {str(e)}") from e
        else:
            raise ValueError(
                f"Unsupported client-side embedding provider: {embedding_provider}. "
                "Currently supported providers are 'azure_openai' and 'openai'."
            )
class BaseAzureAISearchTool(
    BaseTool[SearchQuery, SearchResults], Component[AzureAISearchConfig], EmbeddingProvider, ABC
):
    """Abstract base class for Azure AI Search tools.

    This class defines the common interface and functionality for all Azure AI
    Search tools. It handles configuration management, client initialization,
    and the abstract methods that subclasses must implement.

    Attributes:
        search_config: Configuration parameters for the search service.

    Note:
        This is an abstract base class and should not be instantiated directly.
        Use a concrete implementation or the factory methods in AzureAISearchTool.
    """

    component_config_schema = AzureAISearchConfig
    component_provider_override = "autogen_ext.tools.azure.BaseAzureAISearchTool"

    def __init__(
        self,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        description: Optional[str] = None,
        api_version: str = DEFAULT_API_VERSION,
        query_type: Literal["simple", "full", "semantic", "vector"] = "simple",
        search_fields: Optional[List[str]] = None,
        select_fields: Optional[List[str]] = None,
        vector_fields: Optional[List[str]] = None,
        top: Optional[int] = None,
        filter: Optional[str] = None,
        semantic_config_name: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
        embedding_provider: Optional[str] = None,
        embedding_model: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_api_version: Optional[str] = None,
        openai_endpoint: Optional[str] = None,
    ):
        """Initialize the Azure AI Search tool.

        Args:
            name (str): The name of this tool instance
            endpoint (str): The full URL of your Azure AI Search service
            index_name (str): Name of the search index to query
            credential (Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]):
                Azure credential for authentication; a dict must contain an 'api_key' key
            description (Optional[str]): Optional description explaining the tool's purpose
            api_version (str): Azure AI Search API version to use
            query_type (Literal["simple", "full", "semantic", "vector"]): Type of search to perform
            search_fields (Optional[List[str]]): Fields to search within documents
            select_fields (Optional[List[str]]): Fields to return in search results
            vector_fields (Optional[List[str]]): Fields to use for vector search
            top (Optional[int]): Maximum number of results to return
            filter (Optional[str]): OData filter expression to refine search results
            semantic_config_name (Optional[str]): Semantic configuration name for enhanced results
            enable_caching (bool): Whether to cache search results
            cache_ttl_seconds (int): How long to cache results, in seconds
            embedding_provider (Optional[str]): Name of the embedding provider for client-side embeddings
            embedding_model (Optional[str]): Model name for client-side embeddings
            openai_api_key (Optional[str]): API key for OpenAI/Azure OpenAI embeddings
            openai_api_version (Optional[str]): API version for Azure OpenAI embeddings
            openai_endpoint (Optional[str]): Endpoint URL for Azure OpenAI embeddings

        Raises:
            ImportError: If the Azure Search SDK models could not be imported.
            ValueError: If ``credential`` is a dict without an 'api_key' key.
            TypeError: If ``credential`` is not a supported type.
        """
        if not has_azure_search:
            raise ImportError(
                "Azure Search SDK is required but not installed. "
                "Please install it with: pip install azure-search-documents>=11.4.0"
            )

        if description is None:
            description = (
                f"Search for information in the {index_name} index using Azure AI Search. "
                f"Supports full-text search with optional filters and semantic capabilities."
            )

        super().__init__(
            args_type=SearchQuery,
            return_type=SearchResults,
            name=name,
            description=description,
        )

        processed_credential = self._process_credential(credential)

        self.search_config: AzureAISearchConfig = AzureAISearchConfig(
            name=name,
            description=description,
            endpoint=endpoint,
            index_name=index_name,
            credential=processed_credential,
            api_version=api_version,
            query_type=query_type,
            search_fields=search_fields,
            select_fields=select_fields,
            vector_fields=vector_fields,
            top=top,
            filter=filter,
            semantic_config_name=semantic_config_name,
            enable_caching=enable_caching,
            cache_ttl_seconds=cache_ttl_seconds,
            embedding_provider=embedding_provider,
            embedding_model=embedding_model,
            openai_api_key=openai_api_key,
            openai_api_version=openai_api_version,
            openai_endpoint=openai_endpoint,
        )

        self._endpoint = endpoint
        self._index_name = index_name
        self._credential = processed_credential
        self._api_version = api_version
        # The search client is created lazily by _get_client().
        self._client: Optional[SearchClient] = None
        # Maps a cache key to {"results": [...], "timestamp": <time.time()>}.
        self._cache: Dict[str, Dict[str, Any]] = {}

        if self.search_config.api_version == "2023-11-01" and self.search_config.vector_fields:
            warning_message = (
                f"When explicitly setting api_version='{self.search_config.api_version}' for vector search: "
                f"If client-side embedding is NOT configured (e.g., `embedding_model` is not set), "
                f"this tool defaults to service-side vectorization (VectorizableTextQuery), which may fail or have limitations with this API version. "
                f"If client-side embedding IS configured, the tool will use VectorizedQuery, which is generally compatible. "
                f"For robust vector search, consider omitting api_version (recommended to use SDK default) or use a newer API version."
            )
            logger.warning(warning_message)

    async def close(self) -> None:
        """Explicitly close the Azure SearchClient, if one was created.

        Closing is best-effort: an error raised by the client is logged at
        debug level rather than propagated, and the cached client reference is
        cleared either way.
        """
        if self._client is None:
            return
        try:
            await self._client.close()
        except Exception:
            logger.debug("Ignoring error while closing the Azure SearchClient", exc_info=True)
        finally:
            self._client = None

    def _process_credential(
        self, credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]
    ) -> Union[AzureKeyCredential, AsyncTokenCredential]:
        """Process the credential to ensure it is the right type for the async SearchClient.

        A dict credential containing 'api_key' is converted to an
        AzureKeyCredential object.

        Args:
            credential: The credential, either as an object or as a dict

        Returns:
            A credential object of a supported type

        Raises:
            ValueError: If the credential dict does not contain 'api_key'
            TypeError: If the credential is not a supported type
        """
        if isinstance(credential, dict):
            if "api_key" in credential:
                return AzureKeyCredential(credential["api_key"])
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")

        if isinstance(credential, (AzureKeyCredential, AsyncTokenCredential)):
            return credential

        raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")

    async def _get_client(self) -> SearchClient:
        """Return the search client for the configured index, creating it on first use.

        Returns:
            SearchClient: The initialized search client

        Raises:
            ValueError: If creating the client fails (index not found,
                authentication/permission error, or any other error)
        """
        if self._client is not None:
            return self._client

        try:
            self._client = SearchClient(
                endpoint=self.search_config.endpoint,
                index_name=self.search_config.index_name,
                credential=self.search_config.credential,
                api_version=self.search_config.api_version,
            )
            return self._client
        except ResourceNotFoundError as e:
            raise ValueError(f"Index '{self.search_config.index_name}' not found in Azure AI Search service.") from e
        except HttpResponseError as e:
            if e.status_code == 401:
                raise ValueError("Authentication failed. Please check your credentials.") from e
            elif e.status_code == 403:
                raise ValueError("Permission denied to access this index.") from e
            else:
                raise ValueError(f"Error connecting to Azure AI Search: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Unexpected error initializing search client: {str(e)}") from e

    async def run(
        self, args: Union[str, Dict[str, Any], SearchQuery], cancellation_token: Optional[CancellationToken] = None
    ) -> SearchResults:
        """Execute a search against the Azure AI Search index.

        Args:
            args: Search query text, a dict with a 'query' key, or a SearchQuery object
            cancellation_token: Optional token for cancelling the operation

        Returns:
            SearchResults: Container with the search results

        Raises:
            ValueError: If the search query is empty or has an invalid format
            ValueError: If the search fails; any non-cancellation error raised
                while searching is re-raised as ValueError
            asyncio.CancelledError: If the operation is cancelled
        """
        if isinstance(args, str):
            if not args.strip():
                raise ValueError("Search query cannot be empty")
            search_query = SearchQuery(query=args)
        elif isinstance(args, dict) and "query" in args:
            search_query = SearchQuery(query=args["query"])
        elif isinstance(args, SearchQuery):
            search_query = args
        else:
            raise ValueError("Invalid search query format. Expected string, dict with 'query', or SearchQuery")

        if cancellation_token is not None and cancellation_token.is_cancelled():
            raise asyncio.CancelledError("Operation cancelled")

        config = self.search_config

        cache_key = ""
        if config.enable_caching:
            # Keep every component in its position, including empty ones, so
            # that two different configurations cannot collapse to one key.
            cache_key = repr(
                [
                    search_query.query,
                    str(config.top),
                    config.query_type,
                    ",".join(sorted(config.search_fields or [])),
                    ",".join(sorted(config.select_fields or [])),
                    ",".join(sorted(config.vector_fields or [])),
                    str(config.filter or ""),
                    str(config.semantic_config_name or ""),
                ]
            )
            cache_entry = self._cache.get(cache_key)
            if cache_entry is not None:
                cache_age = time.time() - cache_entry["timestamp"]
                if cache_age < config.cache_ttl_seconds:
                    logger.debug("Using cached results for query: %s", search_query.query)
                    return SearchResults(
                        results=[
                            SearchResult(score=r.score, content=r.content, metadata=r.metadata)
                            for r in cache_entry["results"]
                        ]
                    )

        try:
            search_kwargs: Dict[str, Any] = {}

            # Text-search parameters are only sent for non-vector query types.
            if config.query_type != "vector":
                search_kwargs["search_text"] = search_query.query
                search_kwargs["query_type"] = config.query_type

                if config.search_fields:
                    search_kwargs["search_fields"] = config.search_fields  # type: ignore[assignment]

                if config.query_type == "semantic" and config.semantic_config_name:
                    search_kwargs["semantic_configuration_name"] = config.semantic_config_name

            if config.select_fields:
                search_kwargs["select"] = config.select_fields  # type: ignore[assignment]

            if config.filter:
                search_kwargs["filter"] = str(config.filter)

            if config.top is not None:
                search_kwargs["top"] = config.top  # type: ignore[assignment]

            if config.vector_fields:
                if not search_query.query:
                    raise ValueError("Query text cannot be empty for vector search operations")

                use_client_side_embeddings = bool(config.embedding_model and config.embedding_provider)

                vector_queries: List[Union[VectorizedQuery, VectorizableTextQuery]] = []
                if use_client_side_embeddings:
                    from azure.search.documents.models import VectorizedQuery

                    # One embedding is computed and reused for every vector field.
                    embedding_vector: List[float] = await self._get_embedding(search_query.query)
                    for field_spec in config.vector_fields:
                        fields = field_spec if isinstance(field_spec, str) else ",".join(field_spec)
                        vector_queries.append(
                            VectorizedQuery(
                                vector=embedding_vector,
                                k_nearest_neighbors=config.top or 5,
                                fields=fields,
                                kind="vector",
                            )
                        )
                else:
                    from azure.search.documents.models import VectorizableTextQuery

                    # No client-side embedding configured: send the raw text
                    # in a VectorizableTextQuery instead of a vector.
                    for field in config.vector_fields:
                        fields = field if isinstance(field, str) else ",".join(field)
                        vector_queries.append(
                            VectorizableTextQuery(  # type: ignore
                                text=search_query.query,
                                k_nearest_neighbors=config.top or 5,
                                fields=fields,
                                kind="vectorizable",
                            )
                        )

                search_kwargs["vector_queries"] = vector_queries  # type: ignore[assignment]

            client = await self._get_client()
            search_results: SearchResultsIterable = await client.search(**search_kwargs)  # type: ignore[arg-type]

            results: List[SearchResult] = []
            async for doc in search_results:
                # Check the token directly between documents; no helper task
                # is created, so nothing is left pending after the call.
                if cancellation_token is not None and cancellation_token.is_cancelled():
                    raise asyncio.CancelledError("Operation was cancelled")

                try:
                    metadata: Dict[str, Any] = {}
                    content: Dict[str, Any] = {}

                    # Keys starting with "@" or "_" go to metadata; everything
                    # else is document content.
                    for key, value in doc.items():
                        if isinstance(key, str) and key.startswith(("@", "_")):
                            metadata[key] = value
                        else:
                            content[str(key)] = value

                    score = float(metadata.get("@search.score", 0.0))
                    results.append(SearchResult(score=score, content=content, metadata=metadata))
                except Exception as e:
                    # A malformed document is skipped rather than failing the search.
                    logger.warning(f"Error processing search document: {e}")
                    continue

            if config.enable_caching:
                self._cache[cache_key] = {"results": results, "timestamp": time.time()}

            return SearchResults(results=results)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            error_msg = str(e)
            if isinstance(e, HttpResponseError):
                if hasattr(e, "message") and e.message:
                    error_msg = e.message

            if "not found" in error_msg.lower():
                raise ValueError(f"Index '{self.search_config.index_name}' not found.") from e
            elif "unauthorized" in error_msg.lower() or "401" in error_msg:
                raise ValueError(f"Authentication failed: {error_msg}") from e
            else:
                raise ValueError(f"Error from Azure AI Search: {error_msg}") from e

    def _to_config(self) -> AzureAISearchConfig:
        """Convert the current instance to a configuration object."""
        return self.search_config

    @property
    def schema(self) -> ToolSchema:
        """Return the schema for this tool: a single required string parameter, ``query``."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "Search query text"}},
                "required": ["query"],
                "additionalProperties": False,
            },
            "strict": True,
        }

    def return_value_as_string(self, value: SearchResults) -> str:
        """Convert the search results to a string representation.

        Each result becomes one line of the form
        ``Result <i> (Score: <score>): <key>: <value>, ...``; metadata is not
        included. An empty result list yields ``"No results found."``.
        """
        if not value.results:
            return "No results found."

        result_strings: List[str] = []
        for i, result in enumerate(value.results, 1):
            # str(None) is already "None", so no special case is needed.
            content_str = ", ".join(f"{k}: {v}" for k, v in result.content.items())
            result_strings.append(f"Result {i} (Score: {result.score:.2f}): {content_str}")

        return "\n".join(result_strings)

    @classmethod
    def _validate_config(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> None:
        """Validate the configuration for a specific search type.

        Args:
            config_dict: Keyword arguments that would be used to build the tool
            search_type: The search type whose required fields are checked

        Raises:
            TypeError: If the credential is a plain string
            ValueError: If the credential dict lacks 'api_key', the
                configuration cannot be built, or fields required by
                ``search_type`` are missing
        """
        credential = config_dict.get("credential")

        if isinstance(credential, str):
            raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")
        if isinstance(credential, dict) and "api_key" not in credential:
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")

        # Building the config object runs its own validation.
        try:
            _ = AzureAISearchConfig(**config_dict)
        except Exception as e:
            raise ValueError(f"Invalid configuration: {str(e)}") from e

        if search_type == "vector":
            if not config_dict.get("vector_fields"):
                raise ValueError("vector_fields must contain at least one field name for vector search")
        elif search_type == "hybrid":
            vector_fields = config_dict.get("vector_fields")
            search_fields = config_dict.get("search_fields")

            if not vector_fields:
                raise ValueError("vector_fields must contain at least one field name for hybrid search")
            if not search_fields:
                raise ValueError("search_fields must contain at least one field name for hybrid search")

    @classmethod
    @abstractmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "BaseAzureAISearchTool":
        """Create a tool instance from a configuration object.

        This is an abstract method that must be implemented by subclasses; this
        base implementation always raises NotImplementedError.
        """
        if cls is BaseAzureAISearchTool:
            raise NotImplementedError(
                "BaseAzureAISearchTool is an abstract base class and cannot be instantiated directly. "
                "Use a concrete implementation like AzureAISearchTool."
            )
        raise NotImplementedError("Subclasses must implement _from_config")

    @abstractmethod
    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text."""
        raise NotImplementedError("Subclasses must implement _get_embedding")
# Context-local flag, False by default. AzureAISearchTool._from_config and
# _create_from_params set it to True while they call the constructor and reset
# it afterwards. NOTE(review): the code that reads this flag is not in this
# part of the file; confirm where it is checked.
_allow_private_constructor: ContextVar[bool] = ContextVar("_allow_private_constructor", default=False)
class AzureAISearchTool(EmbeddingProviderMixin, BaseAzureAISearchTool):
    """Azure AI Search tool for querying Azure search indexes.

    This tool provides a simplified interface for querying Azure AI Search
    indexes using various search methods. It is recommended to use the factory
    methods to create instances tailored to a specific search type:

    1.  **Full-text search**: For traditional keyword-based searches, Lucene
        queries, or semantically re-ranked results.
        - Use `AzureAISearchTool.create_full_text_search()`
        - Supported `query_type` values: "simple" (keyword), "full" (Lucene),
          "semantic".

    2.  **Vector search**: For pure similarity searches based on vector embeddings.
        - Use `AzureAISearchTool.create_vector_search()`

    3.  **Hybrid search**: Combines vector search with full-text or semantic
        search to get the benefits of both.
        - Use `AzureAISearchTool.create_hybrid_search()`
        - The text component can be "simple", "full", or "semantic" via the
          `query_type` parameter.

    Each factory method configures appropriate defaults and validation for the
    chosen search strategy.

    .. warning::
        If you set `query_type="semantic"`, you must also provide a valid
        `semantic_config_name`. This configuration must be set up in your
        Azure AI Search index beforehand.
    """

    component_provider_override = "autogen_ext.tools.azure.AzureAISearchTool"

    @classmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "AzureAISearchTool":
        """Create a tool instance from a configuration object.

        Args:
            config: The configuration object with the tool settings

        Returns:
            AzureAISearchTool: An initialized tool instance
        """
        # Mark the constructor call as internal for its duration; the previous
        # value is restored even if construction fails.
        token = _allow_private_constructor.set(True)
        try:
            instance = cls(
                name=config.name,
                description=config.description or "",
                endpoint=config.endpoint,
                index_name=config.index_name,
                credential=config.credential,
                api_version=config.api_version,
                query_type=config.query_type,
                search_fields=config.search_fields,
                select_fields=config.select_fields,
                vector_fields=config.vector_fields,
                top=config.top,
                filter=config.filter,
                semantic_config_name=config.semantic_config_name,
                enable_caching=config.enable_caching,
                cache_ttl_seconds=config.cache_ttl_seconds,
                embedding_provider=config.embedding_provider,
                embedding_model=config.embedding_model,
                openai_api_key=config.openai_api_key,
                openai_api_version=config.openai_api_version,
                openai_endpoint=config.openai_endpoint,
            )
            return instance
        finally:
            _allow_private_constructor.reset(token)

    @classmethod
    def _create_from_params(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> "AzureAISearchTool":
        """Private helper that creates an instance from parameters after validation.

        Args:
            config_dict: Dictionary with the configuration parameters
            search_type: Type of search, used for validation

        Returns:
            A configured AzureAISearchTool instance
        """
        cls._validate_config(config_dict, search_type)

        # Same set/reset pattern as _from_config.
        token = _allow_private_constructor.set(True)
        try:
            return cls(**config_dict)
        finally:
            _allow_private_constructor.reset(token)