# Source code for autogen_ext.tools.azure._ai_search
from __future__ import annotations
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
)
from autogen_core import CancellationToken, Component
from autogen_core.tools import BaseTool, ToolSchema
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.search.documents.aio import SearchClient
from pydantic import BaseModel, Field
from ._config import (
DEFAULT_API_VERSION,
AzureAISearchConfig,
)
# Type aliases for the loosely-typed dictionaries exchanged with the search service.
SearchDocument = Dict[str, Any]
MetadataDict = Dict[str, Any]
ContentDict = Dict[str, Any]

# The paged-result type is only needed by static type checkers; at runtime the
# alias degrades to ``Any`` so the SDK type is not imported just for annotations.
if TYPE_CHECKING:
    from azure.search.documents.aio import AsyncSearchItemPaged

    SearchResultsIterable = AsyncSearchItemPaged[SearchDocument]
else:
    SearchResultsIterable = Any

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from azure.search.documents.models import (
        VectorizableTextQuery,
        VectorizedQuery,
        VectorQuery,
    )

# Probe for the vector query models at import time. ``has_azure_search`` records
# the outcome so the tool constructor can fail with a clear message later on.
try:
    from azure.search.documents.models import VectorizableTextQuery, VectorizedQuery, VectorQuery

    has_azure_search = True
except ImportError:
    has_azure_search = False
    logger.error(
        "The 'azure-search-documents' package is required for this tool but was not found. "
        "Please install it with: uv add azure-search-documents"
    )
# Structural type for the async search client, visible to type checkers only.
# At runtime the name is a plain ``Any`` alias so no protocol class is created.
if TYPE_CHECKING:

    class SearchClientProtocol(Protocol):
        """Minimal async client surface: a ``search`` call and a ``close`` call."""

        async def search(self, **kwargs: Any) -> SearchResultsIterable: ...

        async def close(self) -> None: ...

else:
    SearchClientProtocol = Any
# Names exported by ``from <module> import *``. The vector query models are
# re-exported from the Azure SDK alongside this module's own classes.
# The module-level ``logger`` is already bound earlier in the file, so it is
# not re-created here.
__all__ = [
    "AzureAISearchTool",
    "BaseAzureAISearchTool",
    "SearchQuery",
    "SearchResults",
    "SearchResult",
    "VectorizableTextQuery",
    "VectorizedQuery",
    "VectorQuery",
]
# [docs]
class SearchQuery(BaseModel):
    """Search query parameters.

    This simplified interface requires only a search query string. Every other
    parameter (top, filters, vector fields, etc.) is specified when the tool is
    created rather than at query time, which makes it easier for language
    models to produce structured output.

    Args:
        query (str): The search query text.
    """

    # ``...`` marks the field as required, same as omitting a default.
    query: str = Field(..., description="Search query text")
# [docs]
class SearchResult(BaseModel):
    """A single search result.

    Args:
        score (float): The search score.
        content (ContentDict): The document content.
        metadata (MetadataDict): Additional metadata about the document.
    """

    # ``...`` marks each field as required, same as omitting a default.
    score: float = Field(..., description="The search score")
    content: ContentDict = Field(..., description="The document content")
    metadata: MetadataDict = Field(..., description="Additional metadata about the document")
# [docs]
class SearchResults(BaseModel):
    """Container for search results.

    Args:
        results (List[SearchResult]): The list of search results.
    """

    # ``...`` marks the field as required, same as omitting a default.
    results: List[SearchResult] = Field(..., description="List of search results")
class EmbeddingProvider(Protocol):
    """Protocol describing the embedding-generation interface."""

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text.

        This is a protocol stub: the docstring is the whole body, so the
        coroutine itself yields nothing. Implementations supply the real logic.
        """
class EmbeddingProviderMixin:
    """Mixin that provides client-side embedding generation.

    The host class must expose a ``search_config`` attribute. The method reads
    ``embedding_provider`` and ``embedding_model`` from it and, depending on the
    provider, ``openai_api_key``, ``openai_api_version`` and ``openai_endpoint``.
    """

    search_config: AzureAISearchConfig

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query`` as returned by the configured provider.

        Raises:
            ValueError: If the host has no ``search_config``, if client-side
                embedding is not configured, if the provider is unsupported, if
                the Azure OpenAI endpoint is missing, or if the provider call fails.
            ImportError: If the SDK required by the selected provider is not installed.
        """
        if not hasattr(self, "search_config"):
            raise ValueError("Host class must have a search_config attribute")

        search_config = self.search_config
        embedding_provider = getattr(search_config, "embedding_provider", None)
        embedding_model = getattr(search_config, "embedding_model", None)

        if not embedding_provider or not embedding_model:
            raise ValueError(
                "Client-side embedding is not configured. `embedding_provider` and `embedding_model` must be set."
            )

        provider = embedding_provider.lower()

        if provider == "azure_openai":
            try:
                from azure.identity import DefaultAzureCredential
                from openai import AsyncAzureOpenAI
            except ImportError:
                raise ImportError(
                    "Azure OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai azure-identity"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            # The attribute normally exists but holds None when the caller did not
            # set it, so a getattr default alone would never apply; fall back with `or`.
            api_version = getattr(search_config, "openai_api_version", None) or "2023-11-01"
            endpoint = getattr(search_config, "openai_endpoint", None)

            if not endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint (`openai_endpoint`) must be provided for client-side Azure OpenAI embeddings."
                )

            if api_key:
                azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
            else:
                # No API key: authenticate with an Azure AD token instead.
                def get_token() -> str:
                    credential = DefaultAzureCredential()
                    token = credential.get_token("https://cognitiveservices.azure.com/.default")
                    if not token or not token.token:
                        raise ValueError("Failed to acquire token using DefaultAzureCredential for Azure OpenAI.")
                    return token.token

                azure_client = AsyncAzureOpenAI(
                    azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
                )

            try:
                response = await azure_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with Azure OpenAI: {str(e)}") from e

        if provider == "openai":
            try:
                from openai import AsyncOpenAI
            except ImportError:
                raise ImportError(
                    "OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            openai_client = AsyncOpenAI(api_key=api_key)

            try:
                response = await openai_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with OpenAI: {str(e)}") from e

        raise ValueError(
            f"Unsupported client-side embedding provider: {embedding_provider}. "
            "Currently supported providers are 'azure_openai' and 'openai'."
        )
# [docs]
class BaseAzureAISearchTool(
    BaseTool[SearchQuery, SearchResults], Component[AzureAISearchConfig], EmbeddingProvider, ABC
):
    """Abstract base class for Azure AI Search tools.

    Defines the interface and functionality shared by all Azure AI Search
    tools: configuration management, lazy client initialisation, query
    execution with optional result caching, and the abstract methods that
    subclasses must implement.

    Attributes:
        search_config: Configuration parameters for the search service.

    Note:
        This is an abstract base class and should not be instantiated directly.
        Use a concrete implementation or the factory methods on AzureAISearchTool.
    """

    component_config_schema = AzureAISearchConfig
    component_provider_override = "autogen_ext.tools.azure.BaseAzureAISearchTool"

    def __init__(
        self,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        description: Optional[str] = None,
        api_version: str = DEFAULT_API_VERSION,
        query_type: Literal["simple", "full", "semantic", "vector"] = "simple",
        search_fields: Optional[List[str]] = None,
        select_fields: Optional[List[str]] = None,
        vector_fields: Optional[List[str]] = None,
        top: Optional[int] = None,
        filter: Optional[str] = None,
        semantic_config_name: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
        embedding_provider: Optional[str] = None,
        embedding_model: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_api_version: Optional[str] = None,
        openai_endpoint: Optional[str] = None,
    ):
        """Initialize the Azure AI Search tool.

        Args:
            name (str): Name of this tool instance.
            endpoint (str): Full URL of the Azure AI Search service.
            index_name (str): Name of the search index to query.
            credential (Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]):
                Azure credential used for authentication. A dict must contain an ``api_key`` entry.
            description (Optional[str]): Optional description of the tool's purpose.
                A default mentioning the index name is generated when omitted.
            api_version (str): Azure AI Search API version to use.
            query_type (Literal["simple", "full", "semantic", "vector"]): Type of search to perform.
            search_fields (Optional[List[str]]): Document fields to search.
            select_fields (Optional[List[str]]): Fields to return in the search results.
            vector_fields (Optional[List[str]]): Fields used for vector search.
            top (Optional[int]): Maximum number of results to return.
            filter (Optional[str]): OData filter expression used to refine results.
            semantic_config_name (Optional[str]): Semantic configuration name for enhanced results.
            enable_caching (bool): Whether to cache search results.
            cache_ttl_seconds (int): How long cached results stay valid, in seconds.
            embedding_provider (Optional[str]): Provider name for client-side embeddings.
            embedding_model (Optional[str]): Model name for client-side embeddings.
            openai_api_key (Optional[str]): API key for OpenAI / Azure OpenAI embeddings.
            openai_api_version (Optional[str]): API version for Azure OpenAI embeddings.
            openai_endpoint (Optional[str]): Endpoint URL for Azure OpenAI embeddings.

        Raises:
            ImportError: If the Azure Search SDK could not be imported.
            ValueError: If ``credential`` is a dict without an ``api_key`` entry.
            TypeError: If ``credential`` is of an unsupported type.
        """
        if not has_azure_search:
            raise ImportError(
                "Azure Search SDK is required but not installed. "
                "Please install it with: pip install azure-search-documents>=11.4.0"
            )

        if description is None:
            description = (
                f"Search for information in the {index_name} index using Azure AI Search. "
                f"Supports full-text search with optional filters and semantic capabilities."
            )

        super().__init__(
            args_type=SearchQuery,
            return_type=SearchResults,
            name=name,
            description=description,
        )

        processed_credential = self._process_credential(credential)

        self.search_config: AzureAISearchConfig = AzureAISearchConfig(
            name=name,
            description=description,
            endpoint=endpoint,
            index_name=index_name,
            credential=processed_credential,
            api_version=api_version,
            query_type=query_type,
            search_fields=search_fields,
            select_fields=select_fields,
            vector_fields=vector_fields,
            top=top,
            filter=filter,
            semantic_config_name=semantic_config_name,
            enable_caching=enable_caching,
            cache_ttl_seconds=cache_ttl_seconds,
            embedding_provider=embedding_provider,
            embedding_model=embedding_model,
            openai_api_key=openai_api_key,
            openai_api_version=openai_api_version,
            openai_endpoint=openai_endpoint,
        )

        self._endpoint = endpoint
        self._index_name = index_name
        self._credential = processed_credential
        self._api_version = api_version
        # The client is created lazily on first use (see _get_client).
        self._client: Optional[SearchClient] = None
        # Maps a cache key to {"results": [...], "timestamp": <epoch seconds>}.
        self._cache: Dict[str, Dict[str, Any]] = {}

        if self.search_config.api_version == "2023-11-01" and self.search_config.vector_fields:
            warning_message = (
                f"When explicitly setting api_version='{self.search_config.api_version}' for vector search: "
                f"If client-side embedding is NOT configured (e.g., `embedding_model` is not set), "
                f"this tool defaults to service-side vectorization (VectorizableTextQuery), which may fail or have limitations with this API version. "
                f"If client-side embedding IS configured, the tool will use VectorizedQuery, which is generally compatible. "
                f"For robust vector search, consider omitting api_version (recommended to use SDK default) or use a newer API version."
            )
            logger.warning(warning_message)

    async def close(self) -> None:
        """Explicitly close the Azure SearchClient, if one was created.

        Closing is best-effort: a failure is logged rather than raised, and the
        cached client reference is cleared either way.
        """
        if self._client is None:
            return
        try:
            await self._client.close()
        except Exception as e:
            # Do not let cleanup failures propagate, but leave a trace of them.
            logger.warning("Error while closing Azure SearchClient: %s", e)
        finally:
            self._client = None

    def _process_credential(
        self, credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]
    ) -> Union[AzureKeyCredential, AsyncTokenCredential]:
        """Normalise the credential into a type accepted by the async SearchClient.

        A dict credential containing ``api_key`` is converted to an AzureKeyCredential.

        Args:
            credential: The credential, either as an object or as a dict.

        Returns:
            A credential object of a supported type.

        Raises:
            ValueError: If a dict credential does not contain ``api_key``.
            TypeError: If the credential is not a supported type.
        """
        if isinstance(credential, dict):
            if "api_key" in credential:
                return AzureKeyCredential(credential["api_key"])
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")

        if isinstance(credential, (AzureKeyCredential, AsyncTokenCredential)):
            return credential

        raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")

    async def _get_client(self) -> SearchClient:
        """Return the search client for the configured index, creating it on first use.

        Returns:
            SearchClient: The initialised search client.

        Raises:
            ValueError: If creating the client fails (index not found,
                authentication or permission errors, or any other error).
        """
        if self._client is not None:
            return self._client

        try:
            self._client = SearchClient(
                endpoint=self.search_config.endpoint,
                index_name=self.search_config.index_name,
                credential=self.search_config.credential,
                api_version=self.search_config.api_version,
            )
            return self._client
        except ResourceNotFoundError as e:
            raise ValueError(f"Index '{self.search_config.index_name}' not found in Azure AI Search service.") from e
        except HttpResponseError as e:
            if e.status_code == 401:
                raise ValueError("Authentication failed. Please check your credentials.") from e
            elif e.status_code == 403:
                raise ValueError("Permission denied to access this index.") from e
            else:
                raise ValueError(f"Error connecting to Azure AI Search: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Unexpected error initializing search client: {str(e)}") from e

    async def run(
        self, args: Union[str, Dict[str, Any], SearchQuery], cancellation_token: Optional[CancellationToken] = None
    ) -> SearchResults:
        """Execute a search against the Azure AI Search index.

        Args:
            args: The search query as text, as a dict with a ``query`` key, or as a SearchQuery.
            cancellation_token: Optional token used to cancel the operation.

        Returns:
            SearchResults: Container holding the search results and their metadata.

        Raises:
            ValueError: If the query is an empty string or has an unsupported format.
            ValueError: If the search fails (authentication errors, missing index, other errors).
            asyncio.CancelledError: If the operation is cancelled.
        """
        if isinstance(args, str):
            if not args.strip():
                raise ValueError("Search query cannot be empty")
            search_query = SearchQuery(query=args)
        elif isinstance(args, dict) and "query" in args:
            search_query = SearchQuery(query=args["query"])
        elif isinstance(args, SearchQuery):
            search_query = args
        else:
            raise ValueError("Invalid search query format. Expected string, dict with 'query', or SearchQuery")

        if cancellation_token is not None and cancellation_token.is_cancelled():
            raise asyncio.CancelledError("Operation cancelled")

        cache_key = ""
        if self.search_config.enable_caching:
            cache_key_parts = [
                search_query.query,
                str(self.search_config.top),
                self.search_config.query_type,
                ",".join(sorted(self.search_config.search_fields or [])),
                ",".join(sorted(self.search_config.select_fields or [])),
                ",".join(sorted(self.search_config.vector_fields or [])),
                str(self.search_config.filter or ""),
                str(self.search_config.semantic_config_name or ""),
            ]
            # Empty parts are dropped before joining.
            cache_key = ":".join(filter(None, cache_key_parts))
            if cache_key in self._cache:
                cache_entry = self._cache[cache_key]
                cache_age = time.time() - cache_entry["timestamp"]
                if cache_age < self.search_config.cache_ttl_seconds:
                    logger.debug(f"Using cached results for query: {search_query.query}")
                    # Hand back fresh SearchResult objects rather than the cached instances.
                    return SearchResults(
                        results=[
                            SearchResult(score=r.score, content=r.content, metadata=r.metadata)
                            for r in cache_entry["results"]
                        ]
                    )

        try:
            search_kwargs: Dict[str, Any] = {}

            # Text-search arguments are omitted for pure vector queries.
            if self.search_config.query_type != "vector":
                search_kwargs["search_text"] = search_query.query
                search_kwargs["query_type"] = self.search_config.query_type

                if self.search_config.search_fields:
                    search_kwargs["search_fields"] = self.search_config.search_fields  # type: ignore[assignment]

                if self.search_config.query_type == "semantic" and self.search_config.semantic_config_name:
                    search_kwargs["semantic_configuration_name"] = self.search_config.semantic_config_name

            if self.search_config.select_fields:
                search_kwargs["select"] = self.search_config.select_fields  # type: ignore[assignment]
            if self.search_config.filter:
                search_kwargs["filter"] = str(self.search_config.filter)
            if self.search_config.top is not None:
                search_kwargs["top"] = self.search_config.top  # type: ignore[assignment]

            if self.search_config.vector_fields:
                if not search_query.query:
                    raise ValueError("Query text cannot be empty for vector search operations")

                use_client_side_embeddings = bool(
                    self.search_config.embedding_model and self.search_config.embedding_provider
                )

                vector_queries: List[Union[VectorizedQuery, VectorizableTextQuery]] = []
                if use_client_side_embeddings:
                    from azure.search.documents.models import VectorizedQuery

                    # Embed once and reuse the vector for every configured field.
                    embedding_vector: List[float] = await self._get_embedding(search_query.query)
                    for field_spec in self.search_config.vector_fields:
                        fields = field_spec if isinstance(field_spec, str) else ",".join(field_spec)
                        vector_queries.append(
                            VectorizedQuery(
                                vector=embedding_vector,
                                k_nearest_neighbors=self.search_config.top or 5,
                                fields=fields,
                                kind="vector",
                            )
                        )
                else:
                    from azure.search.documents.models import VectorizableTextQuery

                    # No client-side embedding configured: send the raw text and
                    # let the service vectorize it.
                    for field in self.search_config.vector_fields:
                        fields = field if isinstance(field, str) else ",".join(field)
                        vector_queries.append(
                            VectorizableTextQuery(  # type: ignore
                                text=search_query.query,
                                k_nearest_neighbors=self.search_config.top or 5,
                                fields=fields,
                                kind="vectorizable",
                            )
                        )

                search_kwargs["vector_queries"] = vector_queries  # type: ignore[assignment]

            client = await self._get_client()
            search_results: SearchResultsIterable = await client.search(**search_kwargs)  # type: ignore[arg-type]

            results: List[SearchResult] = []
            async for doc in search_results:
                # Cancellation is polled between documents by asking the token
                # directly; no helper task is created, so nothing is left
                # pending on the event loop after this call returns.
                if cancellation_token is not None and cancellation_token.is_cancelled():
                    raise asyncio.CancelledError("Operation was cancelled")

                try:
                    metadata: Dict[str, Any] = {}
                    content: Dict[str, Any] = {}

                    # Keys starting with "@" or "_" go to metadata; the rest is content.
                    for key, value in doc.items():
                        if isinstance(key, str) and key.startswith(("@", "_")):
                            metadata[key] = value
                        else:
                            content[str(key)] = value

                    score = float(metadata.get("@search.score", 0.0))
                    results.append(SearchResult(score=score, content=content, metadata=metadata))
                except Exception as e:
                    # A malformed document is skipped instead of failing the whole search.
                    logger.warning(f"Error processing search document: {e}")
                    continue

            if self.search_config.enable_caching:
                self._cache[cache_key] = {"results": results, "timestamp": time.time()}

            return SearchResults(results=results)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            error_msg = str(e)
            if isinstance(e, HttpResponseError):
                if hasattr(e, "message") and e.message:
                    error_msg = e.message

            # Classify the failure from its message text.
            if "not found" in error_msg.lower():
                raise ValueError(f"Index '{self.search_config.index_name}' not found.") from e
            elif "unauthorized" in error_msg.lower() or "401" in error_msg:
                raise ValueError(f"Authentication failed: {error_msg}") from e
            else:
                raise ValueError(f"Error from Azure AI Search: {error_msg}") from e

    def _to_config(self) -> AzureAISearchConfig:
        """Return the configuration object of this instance."""
        return self.search_config

    @property
    def schema(self) -> ToolSchema:
        """Return the schema of this tool: a single required string parameter ``query``."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "Search query text"}},
                "required": ["query"],
                "additionalProperties": False,
            },
            "strict": True,
        }

    def return_value_as_string(self, value: SearchResults) -> str:
        """Convert search results to a string representation.

        Args:
            value: The search results to render.

        Returns:
            ``"No results found."`` when there are no results, otherwise one
            line per result with its 1-based index, its score rounded to two
            decimals and its content as ``key: value`` pairs.
        """
        if not value.results:
            return "No results found."

        result_strings: List[str] = []
        for i, result in enumerate(value.results, 1):
            # ``!s`` renders None as "None", so no special case is needed.
            content_str = ", ".join(f"{k}: {v!s}" for k, v in result.content.items())
            result_strings.append(f"Result {i} (Score: {result.score:.2f}): {content_str}")

        return "\n".join(result_strings)

    @classmethod
    def _validate_config(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> None:
        """Validate a configuration dict for a specific search type.

        Args:
            config_dict: Configuration parameters to validate.
            search_type: The kind of search the configuration is meant for.

        Raises:
            TypeError: If the credential is a plain string.
            ValueError: If a dict credential lacks ``api_key``, if the
                configuration model rejects the values, or if the fields
                required by ``search_type`` are missing or empty.
        """
        credential = config_dict.get("credential")

        if isinstance(credential, str):
            raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")
        if isinstance(credential, dict) and "api_key" not in credential:
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")

        # Building the config model runs its own field validation.
        try:
            _ = AzureAISearchConfig(**config_dict)
        except Exception as e:
            raise ValueError(f"Invalid configuration: {str(e)}") from e

        if search_type == "vector":
            if not config_dict.get("vector_fields"):
                raise ValueError("vector_fields must contain at least one field name for vector search")
        elif search_type == "hybrid":
            if not config_dict.get("vector_fields"):
                raise ValueError("vector_fields must contain at least one field name for hybrid search")
            if not config_dict.get("search_fields"):
                raise ValueError("search_fields must contain at least one field name for hybrid search")

    @classmethod
    @abstractmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "BaseAzureAISearchTool":
        """Create a tool instance from a configuration object.

        This is an abstract method that subclasses must implement.

        Raises:
            NotImplementedError: Always; the message depends on whether it is
                called on the base class itself or on a subclass.
        """
        if cls is BaseAzureAISearchTool:
            raise NotImplementedError(
                "BaseAzureAISearchTool is an abstract base class and cannot be instantiated directly. "
                "Use a concrete implementation like AzureAISearchTool."
            )
        raise NotImplementedError("Subclasses must implement _from_config")

    @abstractmethod
    async def _get_embedding(self, query: str) -> List[float]:
        """Generate an embedding vector for the query text.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.
        """
        raise NotImplementedError("Subclasses must implement _get_embedding")
# Context-local flag that AzureAISearchTool._from_config and _create_from_params
# set to True while they construct an instance, then reset afterwards.
# NOTE(review): nothing in the visible source reads this flag - confirm whether
# a constructor guard is meant to consult it.
_allow_private_constructor = ContextVar("_allow_private_constructor", default=False)
# [docs]
class AzureAISearchTool(EmbeddingProviderMixin, BaseAzureAISearchTool):
    """Azure AI Search tool for querying Azure search indexes.

    This tool provides a simplified interface for querying Azure AI Search
    indexes with several search methods. The recommended way to create an
    instance is through the factory method tailored to the search type:

    1.  **Full-text search**: for traditional keyword-based search, Lucene
        queries, or semantically re-ranked results.

        - Use `AzureAISearchTool.create_full_text_search()`
        - Supported `query_type`: "simple" (keyword), "full" (Lucene), "semantic".

    2.  **Vector search**: for pure similarity search based on vector embeddings.

        - Use `AzureAISearchTool.create_vector_search()`

    3.  **Hybrid search**: combines vector search with full-text or semantic
        search to get the benefits of both.

        - Use `AzureAISearchTool.create_hybrid_search()`
        - The text component can be set to "simple", "full" or "semantic" via `query_type`.

    Each factory method applies the defaults and validation appropriate for the
    chosen search strategy.

    .. warning::
        If you set `query_type="semantic"`, you must also provide a valid
        `semantic_config_name`. That configuration must already be set up in
        your Azure AI Search index.
    """

    component_provider_override = "autogen_ext.tools.azure.AzureAISearchTool"

    @classmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "AzureAISearchTool":
        """Create a tool instance from a configuration object.

        Args:
            config: Configuration object holding the tool settings.

        Returns:
            AzureAISearchTool: An initialised tool instance.
        """
        token = _allow_private_constructor.set(True)
        try:
            instance = cls(
                name=config.name,
                # Pass a missing description through unchanged so the constructor
                # generates its default text instead of storing an empty string.
                description=config.description,
                endpoint=config.endpoint,
                index_name=config.index_name,
                credential=config.credential,
                api_version=config.api_version,
                query_type=config.query_type,
                search_fields=config.search_fields,
                select_fields=config.select_fields,
                vector_fields=config.vector_fields,
                top=config.top,
                filter=config.filter,
                semantic_config_name=config.semantic_config_name,
                enable_caching=config.enable_caching,
                cache_ttl_seconds=config.cache_ttl_seconds,
                embedding_provider=config.embedding_provider,
                embedding_model=config.embedding_model,
                openai_api_key=config.openai_api_key,
                openai_api_version=config.openai_api_version,
                openai_endpoint=config.openai_endpoint,
            )
            return instance
        finally:
            _allow_private_constructor.reset(token)

    @classmethod
    def _create_from_params(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> "AzureAISearchTool":
        """Private helper that validates the parameters and then creates an instance.

        Args:
            config_dict: Dictionary of configuration parameters.
            search_type: The search type to validate against.

        Returns:
            A configured AzureAISearchTool instance.
        """
        cls._validate_config(config_dict, search_type)

        token = _allow_private_constructor.set(True)
        try:
            return cls(**config_dict)
        finally:
            _allow_private_constructor.reset(token)

    @classmethod
    def create_full_text_search(
        cls,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        description: Optional[str] = None,
        api_version: Optional[str] = None,
        query_type: Literal["simple", "full", "semantic"] = "simple",
        search_fields: Optional[List[str]] = None,
        select_fields: Optional[List[str]] = None,
        top: Optional[int] = 5,
        filter: Optional[str] = None,
        semantic_config_name: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
    ) -> "AzureAISearchTool":
        """Create a tool for traditional text search.

        This factory method creates an AzureAISearchTool for full-text search,
        supporting keyword matching, Lucene syntax and semantic search.

        Args:
            name: Name of the tool instance.
            endpoint: Full URL of the Azure AI Search service.
            index_name: Name of the search index to query.
            credential: Azure credential used for authentication (API key or token).
            description: Optional description of the tool's purpose.
            api_version: Azure AI Search API version to use; the module default is used when omitted.
            query_type: Type of text search to perform:

                • **simple**: Basic keyword search that matches exact terms and their variants
                • **full**: Advanced search using Lucene query syntax, for complex queries
                • **semantic**: AI-powered search that understands meaning and context, with enhanced relevance ranking
            search_fields: Document fields to search.
            select_fields: Fields to return in the search results.
            top: Maximum number of results to return (default: 5).
            filter: OData filter expression used to refine results.
            semantic_config_name: Semantic configuration name (required when query_type is "semantic").
            enable_caching: Whether to cache search results.
            cache_ttl_seconds: How long cached results stay valid, in seconds.

        Returns:
            An initialised AzureAISearchTool for full-text search.

        Raises:
            ValueError: If query_type is "semantic" but no semantic_config_name is given.
            ValueError: If the parameters fail configuration validation.

        Example:
            .. code-block:: python

                from azure.core.credentials import AzureKeyCredential
                from autogen_ext.tools.azure import AzureAISearchTool

                # Basic keyword search
                tool = AzureAISearchTool.create_full_text_search(
                    name="doc-search",
                    endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                    index_name="<your-index>",  # Name of your search index
                    credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                    query_type="simple",  # Enable keyword search
                    search_fields=["content", "title"],  # Fields to search
                    select_fields=["content", "title", "url"],  # Optional: fields to return
                    top=5,
                )

                # Full-text (Lucene query) search
                full_text_tool = AzureAISearchTool.create_full_text_search(
                    name="doc-search",
                    endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                    index_name="<your-index>",  # Name of your search index
                    credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                    query_type="full",  # Enable Lucene query syntax
                    search_fields=["content", "title"],  # Fields to search
                    select_fields=["content", "title", "url"],  # Optional: fields to return
                    top=5,
                )

                # Semantic search with re-ranking
                # Note: make sure your index has a semantic configuration enabled
                semantic_tool = AzureAISearchTool.create_full_text_search(
                    name="semantic-search",
                    endpoint="https://your-search.search.windows.net",
                    index_name="<your-index>",
                    credential=AzureKeyCredential("<your-key>"),
                    query_type="semantic",  # Enable semantic ranking
                    semantic_config_name="<your-semantic-config>",  # Required for semantic search
                    search_fields=["content", "title"],  # Fields to search
                    select_fields=["content", "title", "url"],  # Optional: fields to return
                    top=5,
                )

                # The search tool can be used with an Agent
                # assistant = Agent("assistant", tools=[semantic_tool])
        """
        if query_type == "semantic" and not semantic_config_name:
            raise ValueError("semantic_config_name is required when query_type is 'semantic'")

        config_dict = {
            "name": name,
            "endpoint": endpoint,
            "index_name": index_name,
            "credential": credential,
            "description": description,
            "api_version": api_version or DEFAULT_API_VERSION,
            "query_type": query_type,
            "search_fields": search_fields,
            "select_fields": select_fields,
            "top": top,
            "filter": filter,
            "semantic_config_name": semantic_config_name,
            "enable_caching": enable_caching,
            "cache_ttl_seconds": cache_ttl_seconds,
        }

        return cls._create_from_params(config_dict, "full_text")

    @classmethod
    def create_vector_search(
        cls,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        vector_fields: List[str],
        description: Optional[str] = None,
        api_version: Optional[str] = None,
        select_fields: Optional[List[str]] = None,
        top: int = 5,
        filter: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
        embedding_provider: Optional[str] = None,
        embedding_model: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_api_version: Optional[str] = None,
        openai_endpoint: Optional[str] = None,
    ) -> "AzureAISearchTool":
        """Create a tool for pure vector / similarity search.

        This factory method creates an AzureAISearchTool for vector search,
        which matches documents by semantic similarity using vector embeddings.

        Args:
            name: Name of the tool instance.
            endpoint: Full URL of the Azure AI Search service.
            index_name: Name of the search index to query.
            credential: Azure credential used for authentication (API key or token).
            vector_fields: Fields used for vector search (required).
            description: Optional description of the tool's purpose.
            api_version: Azure AI Search API version to use; the module default is used when omitted.
            select_fields: Fields to return in the search results.
            top: Maximum number of results to return / the k in k-NN (default: 5).
            filter: OData filter expression used to refine results.
            enable_caching: Whether to cache search results.
            cache_ttl_seconds: How long cached results stay valid, in seconds.
            embedding_provider: Provider for client-side embeddings (e.g. 'azure_openai', 'openai').
            embedding_model: Model for client-side embeddings (e.g. 'text-embedding-ada-002').
            openai_api_key: API key for OpenAI / Azure OpenAI embeddings.
            openai_api_version: API version for Azure OpenAI embeddings.
            openai_endpoint: Endpoint URL for Azure OpenAI embeddings.

        Returns:
            An initialised AzureAISearchTool for vector search.

        Raises:
            ValueError: If vector_fields is empty.
            ValueError: If embedding_provider is 'azure_openai' (in any letter case) but no openai_endpoint is given.
            ValueError: If required parameters are missing or invalid.

        Example Usage:
            .. code-block:: python

                from azure.core.credentials import AzureKeyCredential
                from autogen_ext.tools.azure import AzureAISearchTool

                # Vector search with service-side vectorization
                tool = AzureAISearchTool.create_vector_search(
                    name="vector-search",
                    endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                    index_name="<your-index>",  # Name of your search index
                    credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                    vector_fields=["content_vector"],  # Name of your vector field
                    select_fields=["content", "title", "url"],  # Fields to return in the results
                    top=5,
                )

                # Vector search with Azure OpenAI embeddings
                azure_openai_tool = AzureAISearchTool.create_vector_search(
                    name="azure-openai-vector-search",
                    endpoint="https://your-search.search.windows.net",
                    index_name="<your-index>",
                    credential=AzureKeyCredential("<your-key>"),
                    vector_fields=["content_vector"],
                    embedding_provider="azure_openai",  # Use Azure OpenAI for embeddings
                    embedding_model="text-embedding-ada-002",  # Embedding model to use
                    openai_endpoint="https://your-openai.openai.azure.com",  # Your Azure OpenAI endpoint
                    openai_api_key="<your-openai-key>",  # Your Azure OpenAI key
                    openai_api_version="2024-02-15-preview",  # Azure OpenAI API version
                    select_fields=["content", "title", "url"],  # Fields to return in the results
                    top=5,
                )

                # Vector search with OpenAI embeddings
                openai_tool = AzureAISearchTool.create_vector_search(
                    name="openai-vector-search",
                    endpoint="https://your-search.search.windows.net",
                    index_name="<your-index>",
                    credential=AzureKeyCredential("<your-key>"),
                    vector_fields=["content_vector"],
                    embedding_provider="openai",  # Use OpenAI for embeddings
                    embedding_model="text-embedding-ada-002",  # Embedding model to use
                    openai_api_key="<your-openai-key>",  # Your OpenAI API key
                    select_fields=["content", "title", "url"],  # Fields to return in the results
                    top=5,
                )

                # Use the tool with an Agent
                # assistant = Agent("assistant", tools=[azure_openai_tool])
        """
        # Compare case-insensitively: embedding generation lowercases the provider
        # name, so this early check must accept the same spellings.
        if embedding_provider and embedding_provider.lower() == "azure_openai" and not openai_endpoint:
            raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'")

        config_dict = {
            "name": name,
            "endpoint": endpoint,
            "index_name": index_name,
            "credential": credential,
            "description": description,
            "api_version": api_version or DEFAULT_API_VERSION,
            "query_type": "vector",
            "select_fields": select_fields,
            "vector_fields": vector_fields,
            "top": top,
            "filter": filter,
            "enable_caching": enable_caching,
            "cache_ttl_seconds": cache_ttl_seconds,
            "embedding_provider": embedding_provider,
            "embedding_model": embedding_model,
            "openai_api_key": openai_api_key,
            "openai_api_version": openai_api_version,
            "openai_endpoint": openai_endpoint,
        }

        return cls._create_from_params(config_dict, "vector")

    @classmethod
    def create_hybrid_search(
        cls,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        vector_fields: List[str],
        search_fields: List[str],
        description: Optional[str] = None,
        api_version: Optional[str] = None,
        query_type: Literal["simple", "full", "semantic"] = "simple",
        select_fields: Optional[List[str]] = None,
        top: int = 5,
        filter: Optional[str] = None,
        semantic_config_name: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
        embedding_provider: Optional[str] = None,
        embedding_model: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_api_version: Optional[str] = None,
        openai_endpoint: Optional[str] = None,
    ) -> "AzureAISearchTool":
        """Create a tool that combines vector search and text search.

        This factory method creates an AzureAISearchTool configured for hybrid
        search, which combines vector similarity with traditional text search.

        Args:
            name: Name of the tool instance.
            endpoint: Full URL of the Azure AI Search service.
            index_name: Name of the search index to query.
            credential: Azure credential used for authentication (API key or token).
            vector_fields: Fields used for vector search (required).
            search_fields: Fields used for text search (required).
            description: Optional description of the tool's purpose.
            api_version: Azure AI Search API version to use; the module default is used when omitted.
            query_type: Type of text search to perform:

                • **simple**: Basic keyword search that matches exact terms and their variants
                • **full**: Advanced search using Lucene query syntax, for complex queries
                • **semantic**: AI-powered search that understands meaning and context, with enhanced relevance ranking
            select_fields: Fields to return in the search results.
            top: Maximum number of results to return (default: 5).
            filter: OData filter expression used to refine results.
            semantic_config_name: Semantic configuration name (required when query_type="semantic").
            enable_caching: Whether to cache search results.
            cache_ttl_seconds: How long cached results stay valid, in seconds.
            embedding_provider: Provider for client-side embeddings (e.g. 'azure_openai', 'openai').
            embedding_model: Model for client-side embeddings (e.g. 'text-embedding-ada-002').
            openai_api_key: API key for OpenAI / Azure OpenAI embeddings.
            openai_api_version: API version for Azure OpenAI embeddings.
            openai_endpoint: Endpoint URL for Azure OpenAI embeddings.

        Returns:
            An initialised AzureAISearchTool for hybrid search.

        Raises:
            ValueError: If vector_fields or search_fields is empty.
            ValueError: If query_type is "semantic" but no semantic_config_name is given.
            ValueError: If embedding_provider is 'azure_openai' (in any letter case) but no openai_endpoint is given.
            ValueError: If required parameters are missing or invalid.

        Example:
            .. code-block:: python

                from azure.core.credentials import AzureKeyCredential
                from autogen_ext.tools.azure import AzureAISearchTool

                # Basic hybrid search with service-side vectorization
                tool = AzureAISearchTool.create_hybrid_search(
                    name="hybrid-search",
                    endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                    index_name="<your-index>",  # Name of your search index
                    credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                    vector_fields=["content_vector"],  # Name of your vector field
                    search_fields=["content", "title"],  # Your searchable fields
                    top=5,
                )

                # Hybrid search with semantic ranking and Azure OpenAI embeddings
                semantic_tool = AzureAISearchTool.create_hybrid_search(
                    name="semantic-hybrid-search",
                    endpoint="https://your-search.search.windows.net",
                    index_name="<your-index>",
                    credential=AzureKeyCredential("<your-key>"),
                    vector_fields=["content_vector"],
                    search_fields=["content", "title"],
                    query_type="semantic",  # Enable semantic ranking
                    semantic_config_name="<your-semantic-config>",  # Your semantic configuration name
                    embedding_provider="azure_openai",  # Use Azure OpenAI for embeddings
                    embedding_model="text-embedding-ada-002",  # Embedding model to use
                    openai_endpoint="https://your-openai.openai.azure.com",  # Your Azure OpenAI endpoint
                    openai_api_key="<your-openai-key>",  # Your Azure OpenAI key
                    openai_api_version="2024-02-15-preview",  # Azure OpenAI API version
                    select_fields=["content", "title", "url"],  # Fields to return in the results
                    filter="language eq 'en'",  # Optional OData filter
                    top=5,
                )

                # The search tool can be used with an Agent
                # assistant = Agent("assistant", tools=[semantic_tool])
        """
        if query_type == "semantic" and not semantic_config_name:
            raise ValueError("semantic_config_name is required when query_type is 'semantic'")

        # Compare case-insensitively: embedding generation lowercases the provider
        # name, so this early check must accept the same spellings.
        if embedding_provider and embedding_provider.lower() == "azure_openai" and not openai_endpoint:
            raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'")

        config_dict = {
            "name": name,
            "endpoint": endpoint,
            "index_name": index_name,
            "credential": credential,
            "description": description,
            "api_version": api_version or DEFAULT_API_VERSION,
            "query_type": query_type,
            "search_fields": search_fields,
            "select_fields": select_fields,
            "vector_fields": vector_fields,
            "top": top,
            "filter": filter,
            "semantic_config_name": semantic_config_name,
            "enable_caching": enable_caching,
            "cache_ttl_seconds": cache_ttl_seconds,
            "embedding_provider": embedding_provider,
            "embedding_model": embedding_model,
            "openai_api_key": openai_api_key,
            "openai_api_version": openai_api_version,
            "openai_endpoint": openai_endpoint,
        }

        return cls._create_from_params(config_dict, "hybrid")