
Commit faf891b

refactor: use generic WeightedInMemoryAggregator for hybrid search in SQLiteVecIndex (#3303)
# What does this PR do?
Refactors `SQLiteVecIndex` to remove redundant reranking code and simplify hybrid search by using the generic `WeightedInMemoryAggregator`, which any vector DB provider can reuse. The same pattern was already adopted for `PGVectorIndex` in #3064.

CC: @franciscojavierarceo

## Test Plan
1. `./scripts/unit-tests.sh`
2. Integration tests in the CI workflow
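For context, the reranking step on the provider side is now a single call into the shared utility. Below is a minimal sketch of that call, assuming only the import path and the `combine_search_results(vector_scores, keyword_scores, reranker_type, reranker_params)` signature visible in the diff; the chunk IDs, scores, and the `{"impact_factor": 60.0}` parameters are illustrative (the parameter name comes from the removed helper's default).

```python
from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator

# Hypothetical per-chunk scores from the two retrieval passes.
vector_scores = {"chunk-1": 0.82, "chunk-2": 0.31}   # embedding similarity search
keyword_scores = {"chunk-2": 7.4, "chunk-3": 5.1}    # keyword / full-text search

# One call replaces the provider-local weighted/RRF branching.
combined = WeightedInMemoryAggregator.combine_search_results(
    vector_scores, keyword_scores, RERANKER_TYPE_RRF, {"impact_factor": 60.0}
)

# Rank chunks by the fused score, as query_hybrid does afterwards.
top_k = sorted(combined.items(), key=lambda x: x[1], reverse=True)[:5]
```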
1 parent 5c873d5 commit faf891b

File tree

1 file changed (+5, -62 lines)


llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

Lines changed: 5 additions & 62 deletions
@@ -30,11 +30,11 @@
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
-    RERANKER_TYPE_WEIGHTED,
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator

 logger = get_logger(name=__name__, category="vector_io")

@@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path):
     return connection


-def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
-    """Normalize scores to [0,1] range using min-max normalization."""
-    if not scores:
-        return {}
-    min_score = min(scores.values())
-    max_score = max(scores.values())
-    score_range = max_score - min_score
-    if score_range > 0:
-        return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
-    return dict.fromkeys(scores, 1.0)
-
-
-def _weighted_rerank(
-    vector_scores: dict[str, float],
-    keyword_scores: dict[str, float],
-    alpha: float = 0.5,
-) -> dict[str, float]:
-    """ReRanker that uses weighted average of scores."""
-    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
-    normalized_vector_scores = _normalize_scores(vector_scores)
-    normalized_keyword_scores = _normalize_scores(keyword_scores)
-
-    return {
-        doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
-        + ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
-        for doc_id in all_ids
-    }
-
-
-def _rrf_rerank(
-    vector_scores: dict[str, float],
-    keyword_scores: dict[str, float],
-    impact_factor: float = 60.0,
-) -> dict[str, float]:
-    """ReRanker that uses Reciprocal Rank Fusion."""
-    # Convert scores to ranks
-    vector_ranks = {
-        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
-    }
-    keyword_ranks = {
-        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
-    }
-
-    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
-    rrf_scores = {}
-    for doc_id in all_ids:
-        vector_rank = vector_ranks.get(doc_id, float("inf"))
-        keyword_rank = keyword_ranks.get(doc_id, float("inf"))
-        # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
-        rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
-    return rrf_scores
-
-
 def _make_sql_identifier(name: str) -> str:
     return re.sub(r"[^a-zA-Z0-9_]", "_", name)

@@ -398,14 +345,10 @@ async def query_hybrid(
             for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
         }

-        # Combine scores using the specified reranker
-        if reranker_type == RERANKER_TYPE_WEIGHTED:
-            alpha = reranker_params.get("alpha", 0.5)
-            combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
-        else:
-            # Default to RRF for None, RRF, or any unknown types
-            impact_factor = reranker_params.get("impact_factor", 60.0)
-            combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
+        # Combine scores using the reranking utility
+        combined_scores = WeightedInMemoryAggregator.combine_search_results(
+            vector_scores, keyword_scores, reranker_type, reranker_params
+        )

        # Sort by combined score and get top k results
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
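For reference, the math being consolidated is exactly what the deleted module-level helpers implemented: RRF converts each score list to ranks and sums `1 / (impact_factor + rank)`, while the weighted reranker min-max normalizes both score sets and blends them with `alpha`. The standalone sketch below re-derives the weighted path with made-up doc IDs and scores, purely to make the `alpha` blend concrete; the shared `WeightedInMemoryAggregator` is assumed to preserve these formulas (see #3064).

```python
def normalize(scores: dict[str, float]) -> dict[str, float]:
    """Min-max normalize to [0, 1]; identical scores all map to 1.0, as in the removed helper."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi > lo:
        return {doc_id: (s - lo) / (hi - lo) for doc_id, s in scores.items()}
    return dict.fromkeys(scores, 1.0)


def weighted_combine(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.5,
) -> dict[str, float]:
    """alpha weights the keyword side, (1 - alpha) the vector side, matching _weighted_rerank."""
    nv, nk = normalize(vector_scores), normalize(keyword_scores)
    return {
        doc_id: alpha * nk.get(doc_id, 0.0) + (1 - alpha) * nv.get(doc_id, 0.0)
        for doc_id in set(vector_scores) | set(keyword_scores)
    }


# Tiny worked example with a keyword-heavy blend (alpha=0.75).
print(weighted_combine({"a": 0.9, "b": 0.5}, {"b": 2.0, "c": 1.0}, alpha=0.75))
# {'a': 0.25, 'b': 0.75, 'c': 0.0}  (key order may vary)
```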

0 commit comments
