Skip to content

Commit

Permalink
Fix min number of candidates for L2 reranking and env name for overri…
Browse files Browse the repository at this point in the history
…ding the embedding field name (openai#260)
  • Loading branch information
pablocastro authored May 15, 2023
1 parent f10c677 commit 0d3c789
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions datastore/providers/azuresearch_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# Allow overriding field names for Azure Search
FIELDS_ID = os.environ.get("AZURESEARCH_FIELDS_ID", "id")
FIELDS_TEXT = os.environ.get("AZURESEARCH_FIELDS_TEXT", "text")
FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_TEXT", "embedding")
FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_EMBEDDING", "embedding")
FIELDS_DOCUMENT_ID = os.environ.get("AZURESEARCH_FIELDS_DOCUMENT_ID", "document_id")
FIELDS_SOURCE = os.environ.get("AZURESEARCH_FIELDS_SOURCE", "source")
FIELDS_SOURCE_ID = os.environ.get("AZURESEARCH_FIELDS_SOURCE_ID", "source_id")
Expand Down Expand Up @@ -132,14 +132,16 @@ async def _single_query(self, query: QueryWithEmbedding) -> QueryResult:
"""
filter = self._translate_filter(query.filter) if query.filter is not None else None
try:
k = query.top_k if filter is None else query.top_k * 2
vector_top_k = query.top_k if filter is None else query.top_k * 2
q = query.query if not AZURESEARCH_DISABLE_HYBRID else None
if AZURESEARCH_SEMANTIC_CONFIG != None and not AZURESEARCH_DISABLE_HYBRID:
# Ensure we're feeding a good number of candidates to the L2 reranker
vector_top_k = max(50, vector_top_k)
r = await self.client.search(
q,
filter=filter,
top=query.top_k,
vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING),
vector=Vector(value=query.embedding, k=vector_top_k, fields=FIELDS_EMBEDDING),
query_type=QueryType.SEMANTIC,
query_language=AZURESEARCH_LANGUAGE,
semantic_configuration_name=AZURESEARCH_SEMANTIC_CONFIG)
Expand All @@ -148,7 +150,7 @@ async def _single_query(self, query: QueryWithEmbedding) -> QueryResult:
q,
filter=filter,
top=query.top_k,
vector=Vector(value=query.embedding, k=k, fields=FIELDS_EMBEDDING))
vector=Vector(value=query.embedding, k=vector_top_k, fields=FIELDS_EMBEDDING))
results: List[DocumentChunkWithScore] = []
async for hit in r:
f = lambda field: hit.get(field) if field != "-" else None
Expand Down

0 comments on commit 0d3c789

Please sign in to comment.