Azure Cognitive Search datastore (openai#244)
* Add Azure Cognitive Search as data store

* Vectors support

* Pull only index names when checking for existing indexes

* Remove the extra url prefix option

* Actually use semantic mode (L2 reranking) when enabled

* Docs

* Add user agent to search clients

* Use async for search I/O

* Use search client package reference from Azure SDK dev feed

* PR feedback

* Encode keys to avoid restrictions on valid characters

* Updating removable dependencies list
pablocastro authored May 12, 2023
1 parent 9ca9552 commit 361a0ad
Showing 8 changed files with 652 additions and 12 deletions.
14 changes: 12 additions & 2 deletions README.md
@@ -43,6 +43,7 @@ This README provides detailed information on how to set up, develop, and deploy
- [Redis](#redis)
- [Llama Index](#llamaindex)
- [Chroma](#chroma)
- [Azure Cognitive Search](#azure-cognitive-search)
- [Running the API Locally](#running-the-api-locally)
- [Testing a Localhost Plugin in ChatGPT](#testing-a-localhost-plugin-in-chatgpt)
- [Personalization](#personalization)
@@ -135,6 +136,11 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
export CHROMA_PERSISTENCE_DIR=<your_chroma_persistence_directory>
export CHROMA_HOST=<your_chroma_host>
export CHROMA_PORT=<your_chroma_port>
# Azure Cognitive Search
export AZURESEARCH_SERVICE=<your_search_service_name>
export AZURESEARCH_INDEX=<your_search_index_name>
export AZURESEARCH_API_KEY=<your_api_key> # optional; uses key-free managed identity if not set
```

10. Run the API locally: `poetry run start`
@@ -248,7 +254,7 @@ The API requires the following environment variables to work:

| Name | Required | Description |
| ---------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, or `redis`. |
| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, or `azuresearch`. |
| `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). |
| `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using the `text-embedding-ada-002` model. You can get an API key by creating an account on [OpenAI](https://openai.com/). |

@@ -305,6 +311,10 @@ For detailed setup instructions, refer to [`/docs/providers/llama/setup.md`](/do

[Chroma](https://trychroma.com) is an AI-native open-source embedding database designed to make getting started as easy as possible. Chroma runs in-memory, or in a client-server setup. It supports metadata and keyword filtering out of the box. For detailed instructions, refer to [`/docs/providers/chroma/setup.md`](/docs/providers/chroma/setup.md).

#### Azure Cognitive Search

[Azure Cognitive Search](https://azure.microsoft.com/products/search/) is a complete cloud retrieval service that supports vector search, text search, and hybrid search (vectors and text combined to yield the best of both approaches). It also offers an [optional L2 re-ranking step](https://learn.microsoft.com/azure/search/semantic-search-overview) to further improve result quality. For detailed setup instructions, refer to [`/docs/providers/azuresearch/setup.md`](/docs/providers/azuresearch/setup.md).
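
In addition to the required variables above, the datastore reads a few optional settings, all defined in `datastore/providers/azuresearch_datastore.py` in this commit. A minimal sketch of enabling semantic re-ranking and tuning the index (the values shown are placeholders):

```
# Name of a semantic configuration to enable L2 re-ranking; created automatically when the index is new
export AZURESEARCH_SEMANTIC_CONFIG=default
# Query language for semantic ranking (defaults to en-us)
export AZURESEARCH_LANGUAGE=en-us
# Set to any non-empty value to skip hybrid search and run pure vector queries
export AZURESEARCH_DISABLE_HYBRID=1
# Embedding dimensions (defaults to 1536, the size of text-embedding-ada-002 vectors)
export AZURESEARCH_DIMENSIONS=1536
```

The index field names can also be overridden via the `AZURESEARCH_FIELDS_*` variables read at the top of the datastore module.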

### Running the API locally

To run the API locally, you first need to set the requisite environment variables with the `export` command:
@@ -448,7 +458,7 @@ The scripts are:

While the ChatGPT Retrieval Plugin is designed to provide a flexible solution for semantic search and retrieval, it does have some limitations:

- **Keyword search limitations**: The embeddings generated by the `text-embedding-ada-002` model may not always be effective at capturing exact keyword matches. As a result, the plugin might not return the most relevant results for queries that rely heavily on specific keywords. Some vector databases, like Pinecone and Weaviate, use hybrid search and might perform better for keyword searches.
- **Keyword search limitations**: The embeddings generated by the `text-embedding-ada-002` model may not always be effective at capturing exact keyword matches. As a result, the plugin might not return the most relevant results for queries that rely heavily on specific keywords. Some vector databases, like Pinecone, Weaviate, and Azure Cognitive Search, use hybrid search and might perform better for keyword searches.
- **Sensitive data handling**: The plugin does not automatically detect or filter sensitive data. It is the responsibility of the developers to ensure that they have the necessary authorization to include content in the Retrieval Plugin and that the content complies with data privacy requirements.
- **Scalability**: The performance of the plugin may vary depending on the chosen vector database provider and the size of the dataset. Some providers may offer better scalability and performance than others.
- **Language support**: The plugin currently uses OpenAI's `text-embedding-ada-002` model, which is optimized for use in English. However, it is still robust enough to generate good results for a variety of languages.
4 changes: 4 additions & 0 deletions datastore/factory.py
@@ -39,6 +39,10 @@ async def get_datastore() -> DataStore:
from datastore.providers.qdrant_datastore import QdrantDataStore

return QdrantDataStore()
case "azuresearch":
from datastore.providers.azuresearch_datastore import AzureSearchDataStore

return AzureSearchDataStore()
case _:
raise ValueError(
f"Unsupported vector database: {datastore}. "
258 changes: 258 additions & 0 deletions datastore/providers/azuresearch_datastore.py
@@ -0,0 +1,258 @@
import asyncio
import os
import re
import time
import base64
from typing import Dict, List, Optional, Union
from datastore.datastore import DataStore
from models.models import DocumentChunk, DocumentChunkMetadata, DocumentChunkWithScore, DocumentMetadataFilter, Query, QueryResult, QueryWithEmbedding
from loguru import logger
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import Vector, QueryType
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import *
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential as DefaultAzureCredentialSync
from azure.identity.aio import DefaultAzureCredential

AZURESEARCH_SERVICE = os.environ.get("AZURESEARCH_SERVICE")
AZURESEARCH_INDEX = os.environ.get("AZURESEARCH_INDEX")
AZURESEARCH_API_KEY = os.environ.get("AZURESEARCH_API_KEY")
AZURESEARCH_SEMANTIC_CONFIG = os.environ.get("AZURESEARCH_SEMANTIC_CONFIG")
AZURESEARCH_LANGUAGE = os.environ.get("AZURESEARCH_LANGUAGE", "en-us")
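# Any non-empty value disables the text half of hybrid queries, falling back to pure vector search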
AZURESEARCH_DISABLE_HYBRID = os.environ.get("AZURESEARCH_DISABLE_HYBRID")
AZURESEARCH_DIMENSIONS = int(os.environ.get("AZURESEARCH_DIMENSIONS", 1536)) # Default to OpenAI's ada-002 embedding model vector size
assert AZURESEARCH_SERVICE is not None
assert AZURESEARCH_INDEX is not None

# Allow overriding field names for Azure Search
FIELDS_ID = os.environ.get("AZURESEARCH_FIELDS_ID", "id")
FIELDS_TEXT = os.environ.get("AZURESEARCH_FIELDS_TEXT", "text")
FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_EMBEDDING", "embedding")
FIELDS_DOCUMENT_ID = os.environ.get("AZURESEARCH_FIELDS_DOCUMENT_ID", "document_id")
FIELDS_SOURCE = os.environ.get("AZURESEARCH_FIELDS_SOURCE", "source")
FIELDS_SOURCE_ID = os.environ.get("AZURESEARCH_FIELDS_SOURCE_ID", "source_id")
FIELDS_URL = os.environ.get("AZURESEARCH_FIELDS_URL", "url")
FIELDS_CREATED_AT = os.environ.get("AZURESEARCH_FIELDS_CREATED_AT", "created_at")
FIELDS_AUTHOR = os.environ.get("AZURESEARCH_FIELDS_AUTHOR", "author")

MAX_UPLOAD_BATCH_SIZE = 1000
MAX_DELETE_BATCH_SIZE = 1000

class AzureSearchDataStore(DataStore):
def __init__(self):
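        # The async SearchClient handles all data-plane I/O (upserts, queries, deletes); the sync
        # SearchIndexClient below is only used at startup to create the index if it doesn't exist yet.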
self.client = SearchClient(
endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net",
index_name=AZURESEARCH_INDEX,
credential=AzureSearchDataStore._create_credentials(True),
user_agent="retrievalplugin"
)

mgmt_client = SearchIndexClient(
endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net",
credential=AzureSearchDataStore._create_credentials(False),
user_agent="retrievalplugin"
)
if AZURESEARCH_INDEX not in [name for name in mgmt_client.list_index_names()]:
self._create_index(mgmt_client)
else:
logger.info(f"Using existing index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}")

async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
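        # Flatten the chunks into Azure Search documents and upload them in batches of MAX_UPLOAD_BATCH_SIZE;
        # chunk ids are base64-encoded below because Azure Search restricts the characters allowed in keys.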
azdocuments: List[Dict] = []

async def upload():
r = await self.client.upload_documents(documents=azdocuments)
count = sum(1 for rr in r if rr.succeeded)
logger.info(f"Upserted {count} chunks out of {len(azdocuments)}")
if count < len(azdocuments):
raise Exception(f"Failed to upload {len(azdocuments) - count} chunks")

ids = []
for document_id, document_chunks in chunks.items():
ids.append(document_id)
for chunk in document_chunks:
azdocuments.append({
# base64-encode the id string to stay within Azure Search's valid characters for keys
FIELDS_ID: base64.urlsafe_b64encode(bytes(chunk.id, "utf-8")).decode("ascii"),
FIELDS_TEXT: chunk.text,
FIELDS_EMBEDDING: chunk.embedding,
FIELDS_DOCUMENT_ID: document_id,
FIELDS_SOURCE: chunk.metadata.source,
FIELDS_SOURCE_ID: chunk.metadata.source_id,
FIELDS_URL: chunk.metadata.url,
FIELDS_CREATED_AT: chunk.metadata.created_at,
FIELDS_AUTHOR: chunk.metadata.author,
})

if len(azdocuments) >= MAX_UPLOAD_BATCH_SIZE:
await upload()
azdocuments = []

if len(azdocuments) > 0:
await upload()

return ids

async def delete(self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None) -> bool:
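        # Filter-based (and delete_all) deletion runs in batches: search for matching ids, delete them, and
        # retry until the index reports no remaining matches; explicit ids are handled afterwards by
        # recursing with a document_id filter.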
filter = None if delete_all else self._translate_filter(filter)
if delete_all or filter is not None:
deleted = set()
while True:
search_result = await self.client.search(None, filter=filter, top=MAX_DELETE_BATCH_SIZE, include_total_count=True, select=FIELDS_ID)
if await search_result.get_count() == 0:
break
documents = [{ FIELDS_ID: d[FIELDS_ID] } async for d in search_result if d[FIELDS_ID] not in deleted]
if len(documents) > 0:
logger.info(f"Deleting {len(documents)} chunks " + ("using a filter" if filter is not None else "using delete_all"))
del_result = await self.client.delete_documents(documents=documents)
if not all([rr.succeeded for rr in del_result]):
raise Exception("Failed to delete documents")
deleted.update([d[FIELDS_ID] for d in documents])
else:
                    # Only already-deleted ids came back; wait briefly for the index to refresh, then retry
                    await asyncio.sleep(0.25)

if ids is not None and len(ids) > 0:
for id in ids:
logger.info(f"Deleting chunks for document id {id}")
await self.delete(filter=DocumentMetadataFilter(document_id=id))

return True

async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
"""
Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores.
"""
return await asyncio.gather(*(self._single_query(query) for query in queries))

async def _single_query(self, query: QueryWithEmbedding) -> QueryResult:
"""
Takes in a single query and filters and returns a query result with matching document chunks and scores.
"""
filter = self._translate_filter(query.filter) if query.filter is not None else None
try:
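            # Over-sample the vector matches (2x top_k) when a filter is present so enough candidates
            # survive filtering; the text part of the hybrid query is skipped if hybrid search is disabled.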
k = query.top_k if filter is None else query.top_k * 2
q = query.query if not AZURESEARCH_DISABLE_HYBRID else None
            if AZURESEARCH_SEMANTIC_CONFIG is not None and not AZURESEARCH_DISABLE_HYBRID:
r = await self.client.search(
q,
filter=filter,
top=query.top_k,
vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING),
query_type=QueryType.SEMANTIC,
query_language=AZURESEARCH_LANGUAGE,
semantic_configuration_name=AZURESEARCH_SEMANTIC_CONFIG)
else:
r = await self.client.search(
q,
filter=filter,
top=query.top_k,
vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING))
results: List[DocumentChunkWithScore] = []
async for hit in r:
f = lambda field: hit.get(field) if field != "-" else None
results.append(DocumentChunkWithScore(
id=hit[FIELDS_ID],
text=hit[FIELDS_TEXT],
metadata=DocumentChunkMetadata(
document_id=f(FIELDS_DOCUMENT_ID),
source=f(FIELDS_SOURCE),
source_id=f(FIELDS_SOURCE_ID),
url=f(FIELDS_URL),
created_at=f(FIELDS_CREATED_AT),
author=f(FIELDS_AUTHOR)
),
score=hit["@search.score"]
))

return QueryResult(query=query.query, results=results)
except Exception as e:
raise Exception(f"Error querying the index: {e}")

@staticmethod
def _translate_filter(filter: DocumentMetadataFilter) -> str:
"""
Translates a DocumentMetadataFilter into an Azure Search filter string
"""
if filter is None:
return None

escape = lambda s: s.replace("'", "''")

# regex to validate dates are in OData format
date_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

filter_list = []
if filter.document_id is not None:
filter_list.append(f"{FIELDS_DOCUMENT_ID} eq '{escape(filter.document_id)}'")
if filter.source is not None:
filter_list.append(f"{FIELDS_SOURCE} eq '{escape(filter.source)}'")
if filter.source_id is not None:
filter_list.append(f"{FIELDS_SOURCE_ID} eq '{escape(filter.source_id)}'")
if filter.author is not None:
filter_list.append(f"{FIELDS_AUTHOR} eq '{escape(filter.author)}'")
if filter.start_date is not None:
if not date_re.match(filter.start_date):
raise ValueError(f"start_date must be in OData format, got {filter.start_date}")
filter_list.append(f"{FIELDS_CREATED_AT} ge {filter.start_date}")
if filter.end_date is not None:
if not date_re.match(filter.end_date):
raise ValueError(f"end_date must be in OData format, got {filter.end_date}")
filter_list.append(f"{FIELDS_CREATED_AT} le {filter.end_date}")
return " and ".join(filter_list) if len(filter_list) > 0 else None

def _create_index(self, mgmt_client: SearchIndexClient):
"""
Creates an Azure Cognitive Search index, including a semantic search configuration if a name is specified for it
"""
logger.info(
f"Creating index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}" +
(f" with semantic search configuration {AZURESEARCH_SEMANTIC_CONFIG}" if AZURESEARCH_SEMANTIC_CONFIG is not None else "")
)
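        # The schema mirrors DocumentChunk: a key field, searchable text, a vector field sized by
        # AZURESEARCH_DIMENSIONS, the metadata fields used for filtering, plus an optional semantic
        # configuration and an HNSW vector search configuration.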
mgmt_client.create_index(
SearchIndex(
name=AZURESEARCH_INDEX,
fields=[
SimpleField(name=FIELDS_ID, type=SearchFieldDataType.String, key=True),
SearchableField(name=FIELDS_TEXT, type=SearchFieldDataType.String, analyzer_name="standard.lucene"),
SearchField(name=FIELDS_EMBEDDING, type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
hidden=False, searchable=True, filterable=False, sortable=False, facetable=False,
dimensions=AZURESEARCH_DIMENSIONS, vector_search_configuration="default"),
SimpleField(name=FIELDS_DOCUMENT_ID, type=SearchFieldDataType.String, filterable=True, sortable=True),
SimpleField(name=FIELDS_SOURCE, type=SearchFieldDataType.String, filterable=True, sortable=True),
SimpleField(name=FIELDS_SOURCE_ID, type=SearchFieldDataType.String, filterable=True, sortable=True),
SimpleField(name=FIELDS_URL, type=SearchFieldDataType.String),
SimpleField(name=FIELDS_CREATED_AT, type=SearchFieldDataType.DateTimeOffset, filterable=True, sortable=True),
SimpleField(name=FIELDS_AUTHOR, type=SearchFieldDataType.String, filterable=True, sortable=True)
],
semantic_settings=None if AZURESEARCH_SEMANTIC_CONFIG is None else SemanticSettings(
configurations=[SemanticConfiguration(
name=AZURESEARCH_SEMANTIC_CONFIG,
prioritized_fields=PrioritizedFields(
title_field=None, prioritized_content_fields=[SemanticField(field_name=FIELDS_TEXT)]
)
)]
),
vector_search=VectorSearch(
algorithm_configurations=[
VectorSearchAlgorithmConfiguration(
name="default",
kind="hnsw",
# Could change to dotproduct for OpenAI's embeddings since they normalize vectors to unit length
hnsw_parameters=HnswParameters(metric="cosine")
)
]
)
)
)

@staticmethod
def _create_credentials(use_async: bool) -> Union[AzureKeyCredential, DefaultAzureCredential, DefaultAzureCredentialSync]:
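        # Use key-based auth when AZURESEARCH_API_KEY is set; otherwise fall back to DefaultAzureCredential
        # (managed identity, Azure CLI login, etc.), with the async variant for the data-plane client.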
if AZURESEARCH_API_KEY is None:
logger.info("Using DefaultAzureCredential for Azure Search, make sure local identity or managed identity are set up appropriately")
credential = DefaultAzureCredential() if use_async else DefaultAzureCredentialSync()
else:
logger.info("Using an API key to authenticate with Azure Search")
credential = AzureKeyCredential(AZURESEARCH_API_KEY)
return credential