Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions llama_stack/providers/inline/tool_runtime/rag/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,18 @@ async def query(
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]

chunks = []
scores = []

for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id

chunks.append(chunk)
scores.append(score)

if not chunks:
return RAGQueryResult(content=None)
Expand Down Expand Up @@ -167,6 +177,7 @@ async def query(
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"vector_db_id",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/providers/registry/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
pip_packages=["ibm_watson_machine_learning"],
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/watsonx/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
Expand Down
72 changes: 69 additions & 3 deletions tests/unit/rag/test_rag_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ async def test_query_chunk_metadata_handling(self):
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)

assert result is not None
expected_metadata_string = (
"Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
)
expected_metadata_string = "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1', 'vector_db_id': 'db1'}"
assert expected_metadata_string in result.content[1].text
assert result.content is not None

Expand All @@ -77,3 +75,71 @@ async def test_query_accepts_valid_modes(self):
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")

@pytest.mark.asyncio
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
    """Each chunk returned by query() must be tagged with its originating vector DB.

    Two fake vector DBs each yield one chunk; after querying, both the
    aggregated result metadata and the per-chunk "Metadata: {...}" text in
    the rendered context must carry the correct ``vector_db_id``.
    """
    rag_tool = MemoryToolRuntimeImpl(
        config=MagicMock(),
        vector_io_api=MagicMock(),
        inference_api=MagicMock(),
    )

    vector_db_ids = ["db1", "db2"]

    # One synthetic chunk per DB, pre-tagged with the DB it "came from".
    def _make_chunk(db_id, doc_id, chunk_id, source, stored_id, text):
        return Chunk(
            content=text,
            metadata={"vector_db_id": db_id, "document_id": doc_id},
            stored_chunk_id=stored_id,
            chunk_metadata=ChunkMetadata(
                document_id=doc_id,
                chunk_id=chunk_id,
                source=source,
                metadata_token_count=5,
            ),
        )

    chunk1 = _make_chunk("db1", "doc1", "chunk1", "test_source1", "c1", "chunk from db1")
    chunk2 = _make_chunk("db2", "doc2", "chunk2", "test_source2", "c2", "chunk from db2")

    # query_chunks is awaited once per vector DB, in order.
    rag_tool.vector_io_api.query_chunks = AsyncMock(
        side_effect=[
            QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
            QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
        ]
    )

    result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)

    assert result.metadata["chunks"] == ["chunk from db1", "chunk from db2"]
    # NOTE(review): scores are compared against a tuple (unlike chunks /
    # document_ids, which are lists) — presumably query() unzips a re-ranked
    # (score, chunk) pairing via zip(*...); confirm against memory.py.
    assert result.metadata["scores"] == (0.9, 0.8)
    assert result.metadata["document_ids"] == ["doc1", "doc2"]

    # Extract the "Metadata: {...}" dict embedded in each rendered context item.
    def parse_metadata(s):
        import ast
        import re

        found = re.search(r"Metadata:\s*(\{.*\})", s)
        if not found:
            raise ValueError(f"No metadata found in string: {s}")
        return ast.literal_eval(found.group(1))

    db_ids_in_context = [
        parse_metadata(item.text)["vector_db_id"] for item in result.content if "Metadata:" in item.text
    ]
    assert db_ids_in_context == ["db1", "db2"]
Loading