From 361a0adff403b46d890fd86d6811c9cd32236279 Mon Sep 17 00:00:00 2001 From: Pablo Castro Date: Thu, 11 May 2023 18:11:26 -0700 Subject: [PATCH] Azure Cognitive Search datastore (#244) * Add Azure Cognitive Search as data store * Vectors support * Pull only index names when checking for existing indexes * Remove the extra url prefix option * Actually use semantic mode (L2 reranking) when enabled * Docs * Add user agent to search clients * Use async for search I/O * Use search client package reference from Azure SDK dev feed * PR feedback * Encode keys to avoid restrictions on valid characters * Updating removable dependencies list --- README.md | 14 +- datastore/factory.py | 4 + datastore/providers/azuresearch_datastore.py | 258 ++++++++++++++++++ .../removing-unused-dependencies.md | 17 +- docs/providers/azuresearch/setup.md | 29 ++ poetry.lock | 196 ++++++++++++- pyproject.toml | 7 + .../azuresearch/test_azuresearch_datastore.py | 139 ++++++++++ 8 files changed, 652 insertions(+), 12 deletions(-) create mode 100644 datastore/providers/azuresearch_datastore.py create mode 100644 docs/providers/azuresearch/setup.md create mode 100644 tests/datastore/providers/azuresearch/test_azuresearch_datastore.py diff --git a/README.md b/README.md index 47e7ae3db..8b8a1c119 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ This README provides detailed information on how to set up, develop, and deploy - [Redis](#redis) - [Llama Index](#llamaindex) - [Chroma](#chroma) + - [Azure Cognitive Search](#azure-cognitive-search) - [Running the API Locally](#running-the-api-locally) - [Testing a Localhost Plugin in ChatGPT](#testing-a-localhost-plugin-in-chatgpt) - [Personalization](#personalization) @@ -135,6 +136,11 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin: export CHROMA_PERSISTENCE_DIR= export CHROMA_HOST= export CHROMA_PORT= + + # Azure Cognitive Search + export AZURESEARCH_SERVICE= + export AZURESEARCH_INDEX= + export AZURESEARCH_API_KEY= (optional, uses key-free managed identity if not set) ``` 10. Run the API locally: `poetry run start` @@ -248,7 +254,7 @@ The API requires the following environment variables to work: | Name | Required | Description | | ---------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, or `redis`. | +| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`. | | `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). | | `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using the `text-embedding-ada-002` model. You can get an API key by creating an account on [OpenAI](https://openai.com/). | @@ -305,6 +311,10 @@ For detailed setup instructions, refer to [`/docs/providers/llama/setup.md`](/do [Chroma](https://trychroma.com) is an AI-native open-source embedding database designed to make getting started as easy as possible. 
Chroma runs in-memory, or in a client-server setup. It supports metadata and keyword filtering out of the box. For detailed instructions, refer to [`/docs/providers/chroma/setup.md`](/docs/providers/chroma/setup.md). +#### Azure Cognitive Search + +[Azure Cognitive Search](https://azure.microsoft.com/products/search/) is a complete retrieval cloud service that supports vector search, text search, and hybrid (vectors + text combined to yield the best of the two approaches). It also offers an [optional L2 re-ranking step](https://learn.microsoft.com/azure/search/semantic-search-overview) to further improve results quality. For detailed setup instructions, refer to [`/docs/providers/azuresearch/setup.md`](/docs/providers/azuresearch/setup.md) + ### Running the API locally To run the API locally, you first need to set the requisite environment variables with the `export` command: @@ -448,7 +458,7 @@ The scripts are: While the ChatGPT Retrieval Plugin is designed to provide a flexible solution for semantic search and retrieval, it does have some limitations: -- **Keyword search limitations**: The embeddings generated by the `text-embedding-ada-002` model may not always be effective at capturing exact keyword matches. As a result, the plugin might not return the most relevant results for queries that rely heavily on specific keywords. Some vector databases, like Pinecone and Weaviate, use hybrid search and might perform better for keyword searches. +- **Keyword search limitations**: The embeddings generated by the `text-embedding-ada-002` model may not always be effective at capturing exact keyword matches. As a result, the plugin might not return the most relevant results for queries that rely heavily on specific keywords. Some vector databases, like Pinecone, Weaviate and Azure Cognitive Search, use hybrid search and might perform better for keyword searches. - **Sensitive data handling**: The plugin does not automatically detect or filter sensitive data. It is the responsibility of the developers to ensure that they have the necessary authorization to include content in the Retrieval Plugin and that the content complies with data privacy requirements. - **Scalability**: The performance of the plugin may vary depending on the chosen vector database provider and the size of the dataset. Some providers may offer better scalability and performance than others. - **Language support**: The plugin currently uses OpenAI's `text-embedding-ada-002` model, which is optimized for use in English. However, it is still robust enough to generate good results for a variety of languages. diff --git a/datastore/factory.py b/datastore/factory.py index e6eb48edf..026798899 100644 --- a/datastore/factory.py +++ b/datastore/factory.py @@ -39,6 +39,10 @@ async def get_datastore() -> DataStore: from datastore.providers.qdrant_datastore import QdrantDataStore return QdrantDataStore() + case "azuresearch": + from datastore.providers.azuresearch_datastore import AzureSearchDataStore + + return AzureSearchDataStore() case _: raise ValueError( f"Unsupported vector database: {datastore}. 
" diff --git a/datastore/providers/azuresearch_datastore.py b/datastore/providers/azuresearch_datastore.py new file mode 100644 index 000000000..4ae0182cc --- /dev/null +++ b/datastore/providers/azuresearch_datastore.py @@ -0,0 +1,258 @@ +import asyncio +import os +import re +import time +import base64 +from typing import Dict, List, Optional, Union +from datastore.datastore import DataStore +from models.models import DocumentChunk, DocumentChunkMetadata, DocumentChunkWithScore, DocumentMetadataFilter, Query, QueryResult, QueryWithEmbedding +from loguru import logger +from azure.search.documents.aio import SearchClient +from azure.search.documents.models import Vector, QueryType +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import * +from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential as DefaultAzureCredentialSync +from azure.identity.aio import DefaultAzureCredential + +AZURESEARCH_SERVICE = os.environ.get("AZURESEARCH_SERVICE") +AZURESEARCH_INDEX = os.environ.get("AZURESEARCH_INDEX") +AZURESEARCH_API_KEY = os.environ.get("AZURESEARCH_API_KEY") +AZURESEARCH_SEMANTIC_CONFIG = os.environ.get("AZURESEARCH_SEMANTIC_CONFIG") +AZURESEARCH_LANGUAGE = os.environ.get("AZURESEARCH_LANGUAGE", "en-us") +AZURESEARCH_DISABLE_HYBRID = os.environ.get("AZURESEARCH_DISABLE_HYBRID") +AZURESEARCH_DIMENSIONS = os.environ.get("AZURESEARCH_DIMENSIONS", 1536) # Default to OpenAI's ada-002 embedding model vector size +assert AZURESEARCH_SERVICE is not None +assert AZURESEARCH_INDEX is not None + +# Allow overriding field names for Azure Search +FIELDS_ID = os.environ.get("AZURESEARCH_FIELDS_ID", "id") +FIELDS_TEXT = os.environ.get("AZURESEARCH_FIELDS_TEXT", "text") +FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_TEXT", "embedding") +FIELDS_DOCUMENT_ID = os.environ.get("AZURESEARCH_FIELDS_DOCUMENT_ID", "document_id") +FIELDS_SOURCE = os.environ.get("AZURESEARCH_FIELDS_SOURCE", "source") +FIELDS_SOURCE_ID = os.environ.get("AZURESEARCH_FIELDS_SOURCE_ID", "source_id") +FIELDS_URL = os.environ.get("AZURESEARCH_FIELDS_URL", "url") +FIELDS_CREATED_AT = os.environ.get("AZURESEARCH_FIELDS_CREATED_AT", "created_at") +FIELDS_AUTHOR = os.environ.get("AZURESEARCH_FIELDS_AUTHOR", "author") + +MAX_UPLOAD_BATCH_SIZE = 1000 +MAX_DELETE_BATCH_SIZE = 1000 + +class AzureSearchDataStore(DataStore): + def __init__(self): + self.client = SearchClient( + endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net", + index_name=AZURESEARCH_INDEX, + credential=AzureSearchDataStore._create_credentials(True), + user_agent="retrievalplugin" + ) + + mgmt_client = SearchIndexClient( + endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net", + credential=AzureSearchDataStore._create_credentials(False), + user_agent="retrievalplugin" + ) + if AZURESEARCH_INDEX not in [name for name in mgmt_client.list_index_names()]: + self._create_index(mgmt_client) + else: + logger.info(f"Using existing index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}") + + async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: + azdocuments: List[Dict] = [] + + async def upload(): + r = await self.client.upload_documents(documents=azdocuments) + count = sum(1 for rr in r if rr.succeeded) + logger.info(f"Upserted {count} chunks out of {len(azdocuments)}") + if count < len(azdocuments): + raise Exception(f"Failed to upload {len(azdocuments) - count} chunks") + + ids = [] + for document_id, document_chunks in 
+            ids.append(document_id)
+            for chunk in document_chunks:
+                azdocuments.append({
+                    # base64-encode the id string to stay within Azure Search's valid characters for keys
+                    FIELDS_ID: base64.urlsafe_b64encode(bytes(chunk.id, "utf-8")).decode("ascii"),
+                    FIELDS_TEXT: chunk.text,
+                    FIELDS_EMBEDDING: chunk.embedding,
+                    FIELDS_DOCUMENT_ID: document_id,
+                    FIELDS_SOURCE: chunk.metadata.source,
+                    FIELDS_SOURCE_ID: chunk.metadata.source_id,
+                    FIELDS_URL: chunk.metadata.url,
+                    FIELDS_CREATED_AT: chunk.metadata.created_at,
+                    FIELDS_AUTHOR: chunk.metadata.author,
+                })
+
+                if len(azdocuments) >= MAX_UPLOAD_BATCH_SIZE:
+                    await upload()
+                    azdocuments = []
+
+        if len(azdocuments) > 0:
+            await upload()
+
+        return ids
+
+    async def delete(self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None) -> bool:
+        filter = None if delete_all else self._translate_filter(filter)
+        if delete_all or filter is not None:
+            deleted = set()
+            while True:
+                search_result = await self.client.search(None, filter=filter, top=MAX_DELETE_BATCH_SIZE, include_total_count=True, select=FIELDS_ID)
+                if await search_result.get_count() == 0:
+                    break
+                documents = [{ FIELDS_ID: d[FIELDS_ID] } async for d in search_result if d[FIELDS_ID] not in deleted]
+                if len(documents) > 0:
+                    logger.info(f"Deleting {len(documents)} chunks " + ("using a filter" if filter is not None else "using delete_all"))
+                    del_result = await self.client.delete_documents(documents=documents)
+                    if not all([rr.succeeded for rr in del_result]):
+                        raise Exception("Failed to delete documents")
+                    deleted.update([d[FIELDS_ID] for d in documents])
+                else:
+                    # All repeats; delay a bit (without blocking the event loop) to let the index refresh and try again
+                    await asyncio.sleep(0.25)
+
+        if ids is not None and len(ids) > 0:
+            for id in ids:
+                logger.info(f"Deleting chunks for document id {id}")
+                await self.delete(filter=DocumentMetadataFilter(document_id=id))
+
+        return True
+
+    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
+        """
+        Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores.
+        """
+        return await asyncio.gather(*(self._single_query(query) for query in queries))
+
+    async def _single_query(self, query: QueryWithEmbedding) -> QueryResult:
+        """
+        Takes in a single query and filters and returns a query result with matching document chunks and scores.
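+
+        Issues a hybrid (text + vector) request unless AZURESEARCH_DISABLE_HYBRID is set, in which case
+        only vector similarity is used; when AZURESEARCH_SEMANTIC_CONFIG is set (and hybrid is enabled),
+        the request also asks the service for semantic (L2) re-ranking using that configuration.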
+ """ + filter = self._translate_filter(query.filter) if query.filter is not None else None + try: + k = query.top_k if filter is None else query.top_k * 2 + q = query.query if not AZURESEARCH_DISABLE_HYBRID else None + if AZURESEARCH_SEMANTIC_CONFIG != None and not AZURESEARCH_DISABLE_HYBRID: + r = await self.client.search( + q, + filter=filter, + top=query.top_k, + vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING), + query_type=QueryType.SEMANTIC, + query_language=AZURESEARCH_LANGUAGE, + semantic_configuration_name=AZURESEARCH_SEMANTIC_CONFIG) + else: + r = await self.client.search( + q, + filter=filter, + top=query.top_k, + vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING)) + results: List[DocumentChunkWithScore] = [] + async for hit in r: + f = lambda field: hit.get(field) if field != "-" else None + results.append(DocumentChunkWithScore( + id=hit[FIELDS_ID], + text=hit[FIELDS_TEXT], + metadata=DocumentChunkMetadata( + document_id=f(FIELDS_DOCUMENT_ID), + source=f(FIELDS_SOURCE), + source_id=f(FIELDS_SOURCE_ID), + url=f(FIELDS_URL), + created_at=f(FIELDS_CREATED_AT), + author=f(FIELDS_AUTHOR) + ), + score=hit["@search.score"] + )) + + return QueryResult(query=query.query, results=results) + except Exception as e: + raise Exception(f"Error querying the index: {e}") + + @staticmethod + def _translate_filter(filter: DocumentMetadataFilter) -> str: + """ + Translates a DocumentMetadataFilter into an Azure Search filter string + """ + if filter is None: + return None + + escape = lambda s: s.replace("'", "''") + + # regex to validate dates are in OData format + date_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z") + + filter_list = [] + if filter.document_id is not None: + filter_list.append(f"{FIELDS_DOCUMENT_ID} eq '{escape(filter.document_id)}'") + if filter.source is not None: + filter_list.append(f"{FIELDS_SOURCE} eq '{escape(filter.source)}'") + if filter.source_id is not None: + filter_list.append(f"{FIELDS_SOURCE_ID} eq '{escape(filter.source_id)}'") + if filter.author is not None: + filter_list.append(f"{FIELDS_AUTHOR} eq '{escape(filter.author)}'") + if filter.start_date is not None: + if not date_re.match(filter.start_date): + raise ValueError(f"start_date must be in OData format, got {filter.start_date}") + filter_list.append(f"{FIELDS_CREATED_AT} ge {filter.start_date}") + if filter.end_date is not None: + if not date_re.match(filter.end_date): + raise ValueError(f"end_date must be in OData format, got {filter.end_date}") + filter_list.append(f"{FIELDS_CREATED_AT} le {filter.end_date}") + return " and ".join(filter_list) if len(filter_list) > 0 else None + + def _create_index(self, mgmt_client: SearchIndexClient): + """ + Creates an Azure Cognitive Search index, including a semantic search configuration if a name is specified for it + """ + logger.info( + f"Creating index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}" + + (f" with semantic search configuration {AZURESEARCH_SEMANTIC_CONFIG}" if AZURESEARCH_SEMANTIC_CONFIG is not None else "") + ) + mgmt_client.create_index( + SearchIndex( + name=AZURESEARCH_INDEX, + fields=[ + SimpleField(name=FIELDS_ID, type=SearchFieldDataType.String, key=True), + SearchableField(name=FIELDS_TEXT, type=SearchFieldDataType.String, analyzer_name="standard.lucene"), + SearchField(name=FIELDS_EMBEDDING, type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + hidden=False, searchable=True, filterable=False, sortable=False, facetable=False, + dimensions=AZURESEARCH_DIMENSIONS, 
vector_search_configuration="default"), + SimpleField(name=FIELDS_DOCUMENT_ID, type=SearchFieldDataType.String, filterable=True, sortable=True), + SimpleField(name=FIELDS_SOURCE, type=SearchFieldDataType.String, filterable=True, sortable=True), + SimpleField(name=FIELDS_SOURCE_ID, type=SearchFieldDataType.String, filterable=True, sortable=True), + SimpleField(name=FIELDS_URL, type=SearchFieldDataType.String), + SimpleField(name=FIELDS_CREATED_AT, type=SearchFieldDataType.DateTimeOffset, filterable=True, sortable=True), + SimpleField(name=FIELDS_AUTHOR, type=SearchFieldDataType.String, filterable=True, sortable=True) + ], + semantic_settings=None if AZURESEARCH_SEMANTIC_CONFIG is None else SemanticSettings( + configurations=[SemanticConfiguration( + name=AZURESEARCH_SEMANTIC_CONFIG, + prioritized_fields=PrioritizedFields( + title_field=None, prioritized_content_fields=[SemanticField(field_name=FIELDS_TEXT)] + ) + )] + ), + vector_search=VectorSearch( + algorithm_configurations=[ + VectorSearchAlgorithmConfiguration( + name="default", + kind="hnsw", + # Could change to dotproduct for OpenAI's embeddings since they normalize vectors to unit length + hnsw_parameters=HnswParameters(metric="cosine") + ) + ] + ) + ) + ) + + @staticmethod + def _create_credentials(use_async: bool) -> Union[AzureKeyCredential, DefaultAzureCredential, DefaultAzureCredentialSync]: + if AZURESEARCH_API_KEY is None: + logger.info("Using DefaultAzureCredential for Azure Search, make sure local identity or managed identity are set up appropriately") + credential = DefaultAzureCredential() if use_async else DefaultAzureCredentialSync() + else: + logger.info("Using an API key to authenticate with Azure Search") + credential = AzureKeyCredential(AZURESEARCH_API_KEY) + return credential diff --git a/docs/deployment/removing-unused-dependencies.md b/docs/deployment/removing-unused-dependencies.md index 99dee16c1..dcdac20c7 100644 --- a/docs/deployment/removing-unused-dependencies.md +++ b/docs/deployment/removing-unused-dependencies.md @@ -4,13 +4,14 @@ Before deploying your app, you might want to remove unused dependencies from you Here are the packages you can remove for each vector database provider: -- **Pinecone:** Remove `weaviate-client`, `pymilvus`, `qdrant-client`, `redis`, `chromadb`, and `llama-index`. -- **Weaviate:** Remove `pinecone-client`, `pymilvus`, `qdrant-client`, `redis`, `chromadb`, and `llama-index`. -- **Zilliz:** Remove `pinecone-client`, `weaviate-client`, `qdrant-client`, `redis`, `chromadb`, and `llama-index`. -- **Milvus:** Remove `pinecone-client`, `weaviate-client`, `qdrant-client`, `redis`, `chromadb`, and `llama-index`. -- **Qdrant:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `redis`, `chromadb`, and `llama-index`. -- **Redis:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `chromadb`, and `llama-index`. -- **LlamaIndex:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `chromadb`, and `redis`. -- **Chroma:**: Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `llama-index`, and `redis`. +- **Pinecone:** Remove `weaviate-client`, `pymilvus`, `qdrant-client`, `redis`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`. +- **Weaviate:** Remove `pinecone-client`, `pymilvus`, `qdrant-client`, `redis`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`. 
+- **Zilliz:** Remove `pinecone-client`, `weaviate-client`, `qdrant-client`, `redis`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`.
+- **Milvus:** Remove `pinecone-client`, `weaviate-client`, `qdrant-client`, `redis`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`.
+- **Qdrant:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `redis`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`.
+- **Redis:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `chromadb`, `llama-index`, `azure-identity` and `azure-search-documents`.
+- **LlamaIndex:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `chromadb`, `redis`, `azure-identity` and `azure-search-documents`.
+- **Chroma:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `llama-index`, `redis`, `azure-identity` and `azure-search-documents`.
+- **Azure Cognitive Search:** Remove `pinecone-client`, `weaviate-client`, `pymilvus`, `qdrant-client`, `llama-index`, `redis` and `chromadb`.
 
 After removing the unnecessary packages from the `pyproject.toml` file, you don't need to run `poetry lock` and `poetry install` manually. The provided Dockerfile takes care of installing the required dependencies using the `requirements.txt` file generated by the `poetry export` command.
diff --git a/docs/providers/azuresearch/setup.md b/docs/providers/azuresearch/setup.md
new file mode 100644
index 000000000..e184c2621
--- /dev/null
+++ b/docs/providers/azuresearch/setup.md
@@ -0,0 +1,29 @@
+# Azure Cognitive Search
+
+[Azure Cognitive Search](https://azure.microsoft.com/products/search/) is a complete retrieval cloud service that supports vector search, text search, and hybrid (vectors + text combined to yield the best of the two approaches). Azure Cognitive Search also offers an [optional L2 re-ranking step](https://learn.microsoft.com/azure/search/semantic-search-overview) to further improve results quality.
+
+You can find the Azure Cognitive Search documentation [here](https://learn.microsoft.com/azure/search/search-what-is-azure-search). If you don't have an Azure account, you can start setting one up [here](https://azure.microsoft.com/).
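+
+As a rough illustration of what the plugin does against the service (this is only a sketch, not the plugin code itself; it assumes the preview `azure-search-documents` package referenced by this repository, an existing populated index with the default field names, key-based authentication, and a placeholder embedding), a hybrid query with optional semantic re-ranking maps to the Azure Search Python SDK roughly like this:
+
+```python
+import os
+from azure.core.credentials import AzureKeyCredential
+from azure.search.documents import SearchClient
+from azure.search.documents.models import Vector, QueryType
+
+# Connection details come from the environment variables described below.
+client = SearchClient(
+    endpoint=f"https://{os.environ['AZURESEARCH_SERVICE']}.search.windows.net",
+    index_name=os.environ["AZURESEARCH_INDEX"],
+    credential=AzureKeyCredential(os.environ["AZURESEARCH_API_KEY"]),
+)
+
+query_embedding = [0.0] * 1536  # placeholder; the plugin passes text-embedding-ada-002 vectors here
+results = client.search(
+    "example query text",  # keyword half of the hybrid query (omitted when hybrid search is disabled)
+    vector=Vector(value=query_embedding, k=3, fields="embedding"),
+    top=3,
+    # The three arguments below apply only when semantic search is enabled on the service
+    # and a semantic configuration exists on the index (see "Re-ranking" below).
+    query_type=QueryType.SEMANTIC,
+    query_language=os.environ.get("AZURESEARCH_LANGUAGE", "en-us"),
+    semantic_configuration_name=os.environ["AZURESEARCH_SEMANTIC_CONFIG"],
+)
+for doc in results:
+    print(doc["text"], doc["@search.score"])
+```
+
+The datastore added by this PR issues the equivalent call through the asynchronous `SearchClient` and fills in the embedding computed by the plugin.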
+ +## Environment variables + +| Name | Required | Description | Default | +| ---------------------------- | -------- | ------------------------------------------------------------------------------------- | ------------------- | +| `DATASTORE` | Yes | Datastore name, set to `azuresearch` | | +| `BEARER_TOKEN` | Yes | Secret token | | +| `OPENAI_API_KEY` | Yes | OpenAI API key | | +| `AZURESEARCH_SERVICE` | Yes | Name of your search service | | +| `AZURESEARCH_INDEX` | Yes | Name of your search index | | +| `AZURESEARCH_API_KEY` | No | Your API key, if using key-based auth instead of Azure managed identity |Uses managed identity| +| `AZURESEARCH_DISABLE_HYBRID` | No | Disable hybrid search and only use vector similarity |Use hybrid search | +| `AZURESEARCH_SEMANTIC_CONFIG`| No | Enable L2 re-ranking with this configuration name [see re-ranking below](#re-ranking) |L2 not enabled | +| `AZURESEARCH_LANGUAGE` | No | If using L2 re-ranking, language for queries/documents (valid values [listed here](https://learn.microsoft.com/rest/api/searchservice/preview-api/search-documents#queryLanguage)) |`en-us` | +| `AZURESEARCH_DIMENSIONS` | No | Vector size for embeddings |1536 (OpenAI's Ada002)| + +## Authentication Options + +* API key: this is enabled by default; you can obtain the key in the Azure Portal or using the Azure CLI. +* Managed identity: If the plugin is running in Azure, you can enable managed identity for the host and give that identity access to the service, without having to manage keys (avoiding secret storage, rotation, etc.). More details [here](https://learn.microsoft.com/azure/search/search-security-rbac). + +## Re-ranking + +Azure Cognitive Search offers the option to enable a second (L2) ranking step after retrieval to further improve results quality. This only applies when using text or hybrid search. Since it has latency and cost implications, if you want to try this option you need to explicitly [enable "semantic search"](https://learn.microsoft.com/azure/search/semantic-search-overview#enable-semantic-search) in your Cognitive Search service, and [create a semantic search configuration](https://learn.microsoft.com/azure/search/semantic-how-to-query-request#2---create-a-semantic-configuration) for your index. \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index b51d18045..d119e11aa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -206,6 +206,79 @@ files = [ [package.dependencies] cryptography = ">=3.2" +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + +[[package]] +name = "azure-core" +version = "1.26.4" +description = "Microsoft Azure Core Library for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"}, + {file = "azure_core-1.26.4-py3-none-any.whl", hash = "sha256:d9664b4bc2675d72fba461a285ac43ae33abb2967014a955bf136d9703a2ab3c"}, +] + +[package.dependencies] +requests = ">=2.18.4" +six = ">=1.11.0" +typing-extensions = ">=4.3.0" + +[package.extras] +aio = ["aiohttp (>=3.0)"] + +[[package]] +name = "azure-identity" +version = "1.12.0" +description = "Microsoft Azure Identity Library for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-identity-1.12.0.zip", hash = "sha256:7f9b1ae7d97ea7af3f38dd09305e19ab81a1e16ab66ea186b6579d85c1ca2347"}, + {file = "azure_identity-1.12.0-py3-none-any.whl", hash = "sha256:2a58ce4a209a013e37eaccfd5937570ab99e9118b3e1acf875eed3a85d541b92"}, +] + +[package.dependencies] +azure-core = ">=1.11.0,<2.0.0" +cryptography = ">=2.5" +msal = ">=1.12.0,<2.0.0" +msal-extensions = ">=0.3.0,<2.0.0" +six = ">=1.12.0" + +[[package]] +name = "azure-search-documents" +version = "11.4.0a20230509004" +description = "Microsoft Azure Cognitive Search Client Library for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-search-documents-11.4.0a20230509004.zip", hash = "sha256:6cca144573161a10aa0fcd13927264453e79c63be6a53cf2ec241c9c8c22f6b5"}, + {file = "azure_search_documents-11.4.0a20230509004-py3-none-any.whl", hash = "sha256:6215e9a4f9e935ff3eac1b7d5519c6c0789b4497eb11242d376911aaefbb0359"}, +] + +[package.dependencies] +azure-common = ">=1.1,<2.0" +azure-core = ">=1.24.0,<2.0.0" +isodate = ">=0.6.0" + +[package.source] +type = "legacy" +url = "https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-python/pypi/simple" +reference = "azure-sdk-dev" + [[package]] name = "backoff" version = "2.2.1" @@ -1392,6 +1465,21 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "isodate" +version = "0.6.1" +description = "An ISO 8601 date/time/duration parser and formatter" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "isodate-0.6.1-py2.py3-none-any.whl", hash = "sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96"}, + {file = "isodate-0.6.1.tar.gz", hash = "sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "jinja2" version = "3.1.2" @@ -1765,6 +1853,45 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "msal" +version = "1.22.0" +description = "The Microsoft Authentication Library (MSAL) for Python library enables your app to access the 
Microsoft Cloud by supporting authentication of users with Microsoft Azure Active Directory accounts (AAD) and Microsoft Accounts (MSA) using industry standard OAuth2 and OpenID Connect." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "msal-1.22.0-py2.py3-none-any.whl", hash = "sha256:9120b7eafdf061c92f7b3d744e5f325fca35873445fa8ffebb40b1086a13dd58"}, + {file = "msal-1.22.0.tar.gz", hash = "sha256:8a82f5375642c1625c89058018430294c109440dce42ea667d466c2cab520acd"}, +] + +[package.dependencies] +cryptography = ">=0.6,<43" +PyJWT = {version = ">=1.0.0,<3", extras = ["crypto"]} +requests = ">=2.0.0,<3" + +[package.extras] +broker = ["pymsalruntime (>=0.13.2,<0.14)"] + +[[package]] +name = "msal-extensions" +version = "1.0.0" +description = "Microsoft Authentication Library extensions (MSAL EX) provides a persistence API that can save your data on disk, encrypted on Windows, macOS and Linux. Concurrent data access will be coordinated by a file lock mechanism." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "msal-extensions-1.0.0.tar.gz", hash = "sha256:c676aba56b0cce3783de1b5c5ecfe828db998167875126ca4b47dc6436451354"}, + {file = "msal_extensions-1.0.0-py2.py3-none-any.whl", hash = "sha256:91e3db9620b822d0ed2b4d1850056a0f133cba04455e62f11612e40f5502f2ee"}, +] + +[package.dependencies] +msal = ">=0.4.1,<2.0.0" +portalocker = [ + {version = ">=1.0,<3", markers = "python_version >= \"3.5\" and platform_system != \"Windows\""}, + {version = ">=1.6,<3", markers = "python_version >= \"3.5\" and platform_system == \"Windows\""}, +] + [[package]] name = "multidict" version = "6.0.4" @@ -2226,6 +2353,26 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "portalocker" +version = "2.7.0" +description = "Wraps the portalocker recipe for easy usage" +category = "main" +optional = false +python-versions = ">=3.5" +files = [ + {file = "portalocker-2.7.0-py2.py3-none-any.whl", hash = "sha256:a07c5b4f3985c3cf4798369631fb7011adb498e2a46d8440efc75a8f29a0f983"}, + {file = "portalocker-2.7.0.tar.gz", hash = "sha256:032e81d534a88ec1736d03f780ba073f047a06c478b06e2937486f334e955c51"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + +[package.extras] +docs = ["sphinx (>=1.7.1)"] +redis = ["redis"] +tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)"] + [[package]] name = "posthog" version = "3.0.1" @@ -2381,6 +2528,27 @@ typing-extensions = ">=4.2.0" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] +[[package]] +name = "pyjwt" +version = "2.7.0" +description = "JSON Web Token implementation in Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.7.0-py3-none-any.whl", hash = "sha256:ba2b425b15ad5ef12f200dc67dd56af4e26de2331f965c5439994dad075876e1"}, + {file = "PyJWT-2.7.0.tar.gz", hash = "sha256:bd6ca4a3c4285c1a2d4349e5a035fdf8fb94e04ccd0fcbe6ba289dae9cc3e074"}, +] + +[package.dependencies] +cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""} + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = 
["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pymilvus" version = "2.2.8" @@ -2553,6 +2721,30 @@ files = [ {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, ] +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + [[package]] name = "pyyaml" version = "6.0" @@ -3924,4 +4116,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d4a5a453221f38f4f2ef316632490b6f45dd4100e632b77ba63fc7a8ec4a1b6d" +content-hash = "565f4035798b2616c26c04feee5e98ebf7630dec67b4df73adae876ffd52ec86" diff --git a/pyproject.toml b/pyproject.toml index fe2059f7f..bb7fa1705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,11 @@ authors = ["isafulf "] readme = "README.md" packages = [{include = "server"}] +[[tool.poetry.source]] +name = "azure-sdk-dev" +url = "https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-python/pypi/simple/" +secondary = true + [tool.poetry.dependencies] python = "^3.10" fastapi = "^0.92.0" @@ -28,6 +33,8 @@ pymilvus = "^2.2.2" qdrant-client = {version = "^1.0.4", python = "<3.12"} redis = "4.5.1" llama-index = "0.5.4" +azure-identity = "^1.12.0" +azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev"} [tool.poetry.scripts] start = "server.main:start" diff --git a/tests/datastore/providers/azuresearch/test_azuresearch_datastore.py b/tests/datastore/providers/azuresearch/test_azuresearch_datastore.py new file mode 100644 index 000000000..0a777d3b0 --- /dev/null +++ 
b/tests/datastore/providers/azuresearch/test_azuresearch_datastore.py @@ -0,0 +1,139 @@ +import pytest +import os +import time +from typing import Union +from azure.search.documents.indexes import SearchIndexClient +from models.models import DocumentMetadataFilter, Query, Source, Document, DocumentMetadata + +AZURESEARCH_TEST_INDEX = "testindex" +os.environ["AZURESEARCH_INDEX"] = AZURESEARCH_TEST_INDEX +if os.environ.get("AZURESEARCH_SERVICE") == None: + os.environ["AZURESEARCH_SERVICE"] = "invalid service name" # Will fail anyway if not set to a real service, but allows tests to be discovered + +import datastore.providers.azuresearch_datastore +from datastore.providers.azuresearch_datastore import AzureSearchDataStore + +@pytest.fixture(scope="module") +def azuresearch_mgmt_client(): + service = os.environ["AZURESEARCH_SERVICE"] + return SearchIndexClient( + endpoint=f"https://{service}.search.windows.net", + credential=AzureSearchDataStore._create_credentials(False) + ) + +def test_translate_filter(): + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter() + ) == None + + for field in ["document_id", "source", "source_id", "author"]: + value = Source.file if field == "source" else f"test_{field}" + needs_escaping_value = None if field == "source" else f"test'_{field}" + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter(**{field: value}) + ) == f"{field} eq '{value}'" + if needs_escaping_value != None: + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter(**{field: needs_escaping_value}) + ) == f"{field} eq 'test''_{field}'" + + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter( + document_id = "test_document_id", + source = Source.file, + source_id = "test_source_id", + author = "test_author" + ) + ) == "document_id eq 'test_document_id' and source eq 'file' and source_id eq 'test_source_id' and author eq 'test_author'" + + with pytest.raises(ValueError): + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter(start_date="2023-01-01") + ) + with pytest.raises(ValueError): + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter(end_date="2023-01-01") + ) + + assert AzureSearchDataStore._translate_filter( + DocumentMetadataFilter(start_date="2023-01-01T00:00:00Z", end_date="2023-01-02T00:00:00Z", document_id = "test_document_id") + ) == "document_id eq 'test_document_id' and created_at ge 2023-01-01T00:00:00Z and created_at le 2023-01-02T00:00:00Z" + +@pytest.mark.asyncio +async def test_lifecycle_hybrid(azuresearch_mgmt_client: SearchIndexClient): + datastore.providers.azuresearch_datastore.AZURESEARCH_DISABLE_HYBRID = None + datastore.providers.azuresearch_datastore.AZURESEARCH_SEMANTIC_CONFIG = None + await lifecycle(azuresearch_mgmt_client) + +@pytest.mark.asyncio +async def test_lifecycle_vectors_only(azuresearch_mgmt_client: SearchIndexClient): + datastore.providers.azuresearch_datastore.AZURESEARCH_DISABLE_HYBRID = "1" + datastore.providers.azuresearch_datastore.AZURESEARCH_SEMANTIC_CONFIG = None + await lifecycle(azuresearch_mgmt_client) + +@pytest.mark.asyncio +async def test_lifecycle_semantic(azuresearch_mgmt_client: SearchIndexClient): + datastore.providers.azuresearch_datastore.AZURESEARCH_DISABLE_HYBRID = None + datastore.providers.azuresearch_datastore.AZURESEARCH_SEMANTIC_CONFIG = "testsemconfig" + await lifecycle(azuresearch_mgmt_client) + +async def lifecycle(azuresearch_mgmt_client: SearchIndexClient): + if AZURESEARCH_TEST_INDEX in 
azuresearch_mgmt_client.list_index_names():
+        azuresearch_mgmt_client.delete_index(AZURESEARCH_TEST_INDEX)
+    assert AZURESEARCH_TEST_INDEX not in azuresearch_mgmt_client.list_index_names()
+    try:
+        store = AzureSearchDataStore()
+        index = azuresearch_mgmt_client.get_index(AZURESEARCH_TEST_INDEX)
+        assert index is not None
+
+        result = await store.upsert([
+            Document(
+                id="test_id_1",
+                text="test text",
+                metadata=DocumentMetadata(source=Source.file, source_id="test_source_id", author="test_author", created_at="2023-01-01T00:00:00Z", url="http://some-test-url/path")),
+            Document(
+                id="test_id_2+",
+                text="different",
+                metadata=DocumentMetadata(source=Source.file, source_id="test_source_id", author="test_author", created_at="2023-01-01T00:00:00Z", url="http://some-test-url/path"))])
+        assert len(result) == 2 and result[0] == "test_id_1" and result[1] == "test_id_2+"
+
+        # query in a loop in case we need to retry, since documents aren't searchable synchronously after updates
+        for _ in range(4):
+            time.sleep(0.25)
+            result = await store.query([Query(query="text")])
+            if len(result) > 0 and len(result[0].results) > 0:
+                break
+        assert len(result) == 1 and len(result[0].results) == 2
+        assert result[0].results[0].metadata.document_id == "test_id_1" and result[0].results[1].metadata.document_id == "test_id_2+"
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(source_id="test_source_id"))])
+        assert len(result) == 1 and len(result[0].results) == 2
+        assert result[0].results[0].metadata.document_id == "test_id_1" and result[0].results[1].metadata.document_id == "test_id_2+"
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(source_id="nonexisting_id"))])
+        assert len(result) == 1 and len(result[0].results) == 0
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(start_date="2023-01-02T00:00:00Z"))])
+        assert len(result) == 1 and len(result[0].results) == 0
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(start_date="2023-01-01T00:00:00Z"))])
+        assert len(result) == 1 and len(result[0].results) == 2
+        assert result[0].results[0].metadata.document_id == "test_id_1" and result[0].results[1].metadata.document_id == "test_id_2+"
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(end_date="2022-12-31T00:00:00Z"))])
+        assert len(result) == 1 and len(result[0].results) == 0
+
+        result = await store.query([Query(query="text", filter=DocumentMetadataFilter(end_date="2023-01-02T00:00:00Z"))])
+        assert len(result) == 1 and len(result[0].results) == 2
+        assert result[0].results[0].metadata.document_id == "test_id_1" and result[0].results[1].metadata.document_id == "test_id_2+"
+
+        # delete the documents, then query in a loop in case we need to retry, since deletions aren't visible synchronously either
+        assert await store.delete(["test_id_1", "test_id_2+"])
+        for _ in range(4):
+            time.sleep(0.25)
+            result = await store.query([Query(query="text")])
+            if len(result) > 0 and len(result[0].results) == 0:
+                break
+        assert len(result) == 1 and len(result[0].results) == 0
+    finally:
+        azuresearch_mgmt_client.delete_index(AZURESEARCH_TEST_INDEX)
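+
+# Optional convenience entry point (illustrative sketch, not required by the test suite): run this
+# module's tests directly with `python`, assuming pytest-asyncio is installed and AZURESEARCH_SERVICE
+# (and, if not using managed identity, AZURESEARCH_API_KEY) point at a real search service.
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))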