Skip to content

Commit f913467

Browse files
committed
Add vector_db_id to chunk metadata
Adding unit tests
1 parent eed25fc commit f913467

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

llama_stack/providers/inline/tool_runtime/rag/memory.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,18 @@ async def query(
131131
for vector_db_id in vector_db_ids
132132
]
133133
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
134-
chunks = [c for r in results for c in r.chunks]
135-
scores = [s for r in results for s in r.scores]
134+
135+
chunks = []
136+
scores = []
137+
138+
for vector_db_id, result in zip(vector_db_ids, results):
139+
for chunk, score in zip(result.chunks, result.scores):
140+
if not hasattr(chunk, "metadata") or chunk.metadata is None:
141+
chunk.metadata = {}
142+
chunk.metadata["vector_db_id"] = vector_db_id
143+
144+
chunks.append(chunk)
145+
scores.append(score)
136146

137147
if not chunks:
138148
return RAGQueryResult(content=None)

tests/unit/rag/test_rag_query.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,72 @@ async def test_query_accepts_valid_modes(self):
7777
# Test that invalid mode raises an error
7878
with pytest.raises(ValueError):
7979
RAGQueryConfig(mode="wrong_mode")
80+
81+
@pytest.mark.asyncio
82+
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
83+
84+
rag_tool = MemoryToolRuntimeImpl(
85+
config=MagicMock(),
86+
vector_io_api=MagicMock(),
87+
inference_api=MagicMock(),
88+
)
89+
90+
vector_db_ids = ["db1", "db2"]
91+
92+
# Fake chunks from each DB
93+
chunk_metadata1 = ChunkMetadata(
94+
document_id="doc1",
95+
chunk_id="chunk1",
96+
source="test_source1",
97+
metadata_token_count=5,
98+
)
99+
chunk1 = Chunk(
100+
content="chunk from db1",
101+
metadata={"vector_db_id": "db1", "document_id": "doc1"},
102+
stored_chunk_id="c1",
103+
chunk_metadata=chunk_metadata1,
104+
)
105+
106+
chunk_metadata2 = ChunkMetadata(
107+
document_id="doc2",
108+
chunk_id="chunk2",
109+
source="test_source2",
110+
metadata_token_count=5,
111+
)
112+
chunk2 = Chunk(
113+
content="chunk from db2",
114+
metadata={"vector_db_id": "db2", "document_id": "doc2"},
115+
stored_chunk_id="c2",
116+
chunk_metadata=chunk_metadata2,
117+
)
118+
119+
rag_tool.vector_io_api.query_chunks = AsyncMock(
120+
side_effect=[
121+
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
122+
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
123+
]
124+
)
125+
126+
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
127+
returned_chunks = result.metadata["chunks"]
128+
returned_scores = result.metadata["scores"]
129+
returned_doc_ids = result.metadata["document_ids"]
130+
131+
assert returned_chunks == ["chunk from db1", "chunk from db2"]
132+
assert returned_scores == (0.9, 0.8)
133+
assert returned_doc_ids == ["doc1", "doc2"]
134+
135+
# Parse metadata from query result
136+
def parse_metadata(s):
137+
import ast, re
138+
match = re.search(r"Metadata:\s*(\{.*\})", s)
139+
if not match:
140+
raise ValueError(f"No metadata found in string: {s}")
141+
return ast.literal_eval(match.group(1))
142+
143+
returned_metadata = [
144+
parse_metadata(item.text)["vector_db_id"]
145+
for item in result.content
146+
if "Metadata:" in item.text
147+
]
148+
assert returned_metadata == ["db1", "db2"]

0 commit comments

Comments
 (0)