From 0d3c789dcbeccaac61924f2c05502e96567fd578 Mon Sep 17 00:00:00 2001 From: Pablo Castro Date: Mon, 15 May 2023 16:31:34 -0700 Subject: [PATCH] Fix min number of candidates for L2 reranking and env name for overriding the embedding field name (#260) --- datastore/providers/azuresearch_datastore.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/datastore/providers/azuresearch_datastore.py b/datastore/providers/azuresearch_datastore.py index 4ae0182cc..3852258e3 100644 --- a/datastore/providers/azuresearch_datastore.py +++ b/datastore/providers/azuresearch_datastore.py @@ -28,7 +28,7 @@ # Allow overriding field names for Azure Search FIELDS_ID = os.environ.get("AZURESEARCH_FIELDS_ID", "id") FIELDS_TEXT = os.environ.get("AZURESEARCH_FIELDS_TEXT", "text") -FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_TEXT", "embedding") +FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_EMBEDDING", "embedding") FIELDS_DOCUMENT_ID = os.environ.get("AZURESEARCH_FIELDS_DOCUMENT_ID", "document_id") FIELDS_SOURCE = os.environ.get("AZURESEARCH_FIELDS_SOURCE", "source") FIELDS_SOURCE_ID = os.environ.get("AZURESEARCH_FIELDS_SOURCE_ID", "source_id") @@ -132,14 +132,16 @@ async def _single_query(self, query: QueryWithEmbedding) -> QueryResult: """ filter = self._translate_filter(query.filter) if query.filter is not None else None try: - k = query.top_k if filter is None else query.top_k * 2 + vector_top_k = query.top_k if filter is None else query.top_k * 2 q = query.query if not AZURESEARCH_DISABLE_HYBRID else None if AZURESEARCH_SEMANTIC_CONFIG != None and not AZURESEARCH_DISABLE_HYBRID: + # Ensure we're feeding a good number of candidates to the L2 reranker + vector_top_k = max(50, vector_top_k) r = await self.client.search( q, filter=filter, top=query.top_k, - vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING), + vector=Vector(value=query.embedding, k=vector_top_k, fields=FIELDS_EMBEDDING), query_type=QueryType.SEMANTIC, query_language=AZURESEARCH_LANGUAGE, semantic_configuration_name=AZURESEARCH_SEMANTIC_CONFIG) @@ -148,7 +150,7 @@ async def _single_query(self, query: QueryWithEmbedding) -> QueryResult: q, filter=filter, top=query.top_k, - vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING)) + vector=Vector(value=query.embedding, k=vector_top_k, fields=FIELDS_EMBEDDING)) results: List[DocumentChunkWithScore] = [] async for hit in r: f = lambda field: hit.get(field) if field != "-" else None