|
30 | 30 | from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin |
31 | 31 | from llama_stack.providers.utils.memory.vector_store import ( |
32 | 32 | RERANKER_TYPE_RRF, |
33 | | - RERANKER_TYPE_WEIGHTED, |
34 | 33 | ChunkForDeletion, |
35 | 34 | EmbeddingIndex, |
36 | 35 | VectorDBWithIndex, |
37 | 36 | ) |
| 37 | +from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator |
38 | 38 |
|
39 | 39 | logger = get_logger(name=__name__, category="vector_io") |
40 | 40 |
|
@@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path): |
66 | 66 | return connection |
67 | 67 |
|
68 | 68 |
|
69 | | -def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: |
70 | | - """Normalize scores to [0,1] range using min-max normalization.""" |
71 | | - if not scores: |
72 | | - return {} |
73 | | - min_score = min(scores.values()) |
74 | | - max_score = max(scores.values()) |
75 | | - score_range = max_score - min_score |
76 | | - if score_range > 0: |
77 | | - return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()} |
78 | | - return dict.fromkeys(scores, 1.0) |
79 | | - |
80 | | - |
def _weighted_rerank(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.5,
) -> dict[str, float]:
    """Fuse vector and keyword scores with a weighted average.

    Both score maps are min-max normalized first; *alpha* weights the
    keyword contribution and ``1 - alpha`` the vector contribution. A
    document present in only one map scores 0.0 for the missing side.
    """
    norm_keyword = _normalize_scores(keyword_scores)
    norm_vector = _normalize_scores(vector_scores)

    combined: dict[str, float] = {}
    for doc_id in vector_scores.keys() | keyword_scores.keys():
        keyword_part = alpha * norm_keyword.get(doc_id, 0.0)
        vector_part = (1 - alpha) * norm_vector.get(doc_id, 0.0)
        combined[doc_id] = keyword_part + vector_part
    return combined
96 | | - |
97 | | - |
98 | | -def _rrf_rerank( |
99 | | - vector_scores: dict[str, float], |
100 | | - keyword_scores: dict[str, float], |
101 | | - impact_factor: float = 60.0, |
102 | | -) -> dict[str, float]: |
103 | | - """ReRanker that uses Reciprocal Rank Fusion.""" |
104 | | - # Convert scores to ranks |
105 | | - vector_ranks = { |
106 | | - doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True)) |
107 | | - } |
108 | | - keyword_ranks = { |
109 | | - doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True)) |
110 | | - } |
111 | | - |
112 | | - all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) |
113 | | - rrf_scores = {} |
114 | | - for doc_id in all_ids: |
115 | | - vector_rank = vector_ranks.get(doc_id, float("inf")) |
116 | | - keyword_rank = keyword_ranks.get(doc_id, float("inf")) |
117 | | - # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank |
118 | | - rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank)) |
119 | | - return rrf_scores |
120 | | - |
121 | | - |
122 | 69 | def _make_sql_identifier(name: str) -> str: |
123 | 70 | return re.sub(r"[^a-zA-Z0-9_]", "_", name) |
124 | 71 |
|
@@ -398,14 +345,10 @@ async def query_hybrid( |
398 | 345 | for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False) |
399 | 346 | } |
400 | 347 |
|
401 | | - # Combine scores using the specified reranker |
402 | | - if reranker_type == RERANKER_TYPE_WEIGHTED: |
403 | | - alpha = reranker_params.get("alpha", 0.5) |
404 | | - combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha) |
405 | | - else: |
406 | | - # Default to RRF for None, RRF, or any unknown types |
407 | | - impact_factor = reranker_params.get("impact_factor", 60.0) |
408 | | - combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor) |
| 348 | + # Combine scores using the reranking utility |
| 349 | + combined_scores = WeightedInMemoryAggregator.combine_search_results( |
| 350 | + vector_scores, keyword_scores, reranker_type, reranker_params |
| 351 | + ) |
409 | 352 |
|
410 | 353 | # Sort by combined score and get top k results |
411 | 354 | sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True) |
|
0 commit comments