Skip to content

Commit fe517f1

Browse files
are-cesfranciscojavierarceo
authored andcommitted
feat: Add vector_db_id to chunk metadata (#3304)
# What does this PR do? When running RAG in a multi vector DB setting, it can be difficult to trace where retrieved chunks originate from. This PR adds the `vector_db_id` into each chunk’s metadata, making it easier to understand which database a given chunk came from. This is helpful for debugging and for analyzing retrieval behavior of multiple DBs. Relevant code: ```python for vector_db_id, result in zip(vector_db_ids, results): for chunk, score in zip(result.chunks, result.scores): if not hasattr(chunk, "metadata") or chunk.metadata is None: chunk.metadata = {} chunk.metadata["vector_db_id"] = vector_db_id chunks.append(chunk) scores.append(score) ``` ## Test Plan * Ran Llama Stack in debug mode. * Verified that `vector_db_id` was added to each chunk’s metadata. * Confirmed that the metadata was printed in the console when using the RAG tool. --------- Co-authored-by: are-ces <[email protected]> Co-authored-by: Francisco Arceo <[email protected]>
1 parent ce5885d commit fe517f1

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

llama_stack/providers/inline/tool_runtime/rag/memory.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,18 @@ async def query(
167167
for vector_db_id in vector_db_ids
168168
]
169169
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
170-
chunks = [c for r in results for c in r.chunks]
171-
scores = [s for r in results for s in r.scores]
170+
171+
chunks = []
172+
scores = []
173+
174+
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
175+
for chunk, score in zip(result.chunks, result.scores, strict=False):
176+
if not hasattr(chunk, "metadata") or chunk.metadata is None:
177+
chunk.metadata = {}
178+
chunk.metadata["vector_db_id"] = vector_db_id
179+
180+
chunks.append(chunk)
181+
scores.append(score)
172182

173183
if not chunks:
174184
return RAGQueryResult(content=None)
@@ -203,6 +213,7 @@ async def query(
203213
metadata_keys_to_exclude_from_context = [
204214
"token_count",
205215
"metadata_token_count",
216+
"vector_db_id",
206217
]
207218
metadata_for_context = {}
208219
for k in chunk_metadata_keys_to_include_from_context:
@@ -227,6 +238,7 @@ async def query(
227238
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
228239
"chunks": [c.content for c in chunks[: len(picked)]],
229240
"scores": scores[: len(picked)],
241+
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
230242
},
231243
)
232244

tests/unit/rag/test_rag_query.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,58 @@ async def test_query_accepts_valid_modes(self):
8181
# Test that invalid mode raises an error
8282
with pytest.raises(ValueError):
8383
RAGQueryConfig(mode="wrong_mode")
84+
85+
@pytest.mark.asyncio
86+
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
87+
rag_tool = MemoryToolRuntimeImpl(
88+
config=MagicMock(),
89+
vector_io_api=MagicMock(),
90+
inference_api=MagicMock(),
91+
)
92+
93+
vector_db_ids = ["db1", "db2"]
94+
95+
# Fake chunks from each DB
96+
chunk_metadata1 = ChunkMetadata(
97+
document_id="doc1",
98+
chunk_id="chunk1",
99+
source="test_source1",
100+
metadata_token_count=5,
101+
)
102+
chunk1 = Chunk(
103+
content="chunk from db1",
104+
metadata={"vector_db_id": "db1", "document_id": "doc1"},
105+
stored_chunk_id="c1",
106+
chunk_metadata=chunk_metadata1,
107+
)
108+
109+
chunk_metadata2 = ChunkMetadata(
110+
document_id="doc2",
111+
chunk_id="chunk2",
112+
source="test_source2",
113+
metadata_token_count=5,
114+
)
115+
chunk2 = Chunk(
116+
content="chunk from db2",
117+
metadata={"vector_db_id": "db2", "document_id": "doc2"},
118+
stored_chunk_id="c2",
119+
chunk_metadata=chunk_metadata2,
120+
)
121+
122+
rag_tool.vector_io_api.query_chunks = AsyncMock(
123+
side_effect=[
124+
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
125+
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
126+
]
127+
)
128+
129+
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
130+
returned_chunks = result.metadata["chunks"]
131+
returned_scores = result.metadata["scores"]
132+
returned_doc_ids = result.metadata["document_ids"]
133+
returned_vector_db_ids = result.metadata["vector_db_ids"]
134+
135+
assert returned_chunks == ["chunk from db1", "chunk from db2"]
136+
assert returned_scores == (0.9, 0.8)
137+
assert returned_doc_ids == ["doc1", "doc2"]
138+
assert returned_vector_db_ids == ["db1", "db2"]

0 commit comments

Comments
 (0)