@@ -77,3 +77,72 @@ async def test_query_accepts_valid_modes(self):
77
77
# Test that invalid mode raises an error
78
78
with pytest .raises (ValueError ):
79
79
RAGQueryConfig (mode = "wrong_mode" )
80
+
81
+ @pytest .mark .asyncio
82
+ async def test_query_adds_vector_db_id_to_chunk_metadata (self ):
83
+
84
+ rag_tool = MemoryToolRuntimeImpl (
85
+ config = MagicMock (),
86
+ vector_io_api = MagicMock (),
87
+ inference_api = MagicMock (),
88
+ )
89
+
90
+ vector_db_ids = ["db1" , "db2" ]
91
+
92
+ # Fake chunks from each DB
93
+ chunk_metadata1 = ChunkMetadata (
94
+ document_id = "doc1" ,
95
+ chunk_id = "chunk1" ,
96
+ source = "test_source1" ,
97
+ metadata_token_count = 5 ,
98
+ )
99
+ chunk1 = Chunk (
100
+ content = "chunk from db1" ,
101
+ metadata = {"vector_db_id" : "db1" , "document_id" : "doc1" },
102
+ stored_chunk_id = "c1" ,
103
+ chunk_metadata = chunk_metadata1 ,
104
+ )
105
+
106
+ chunk_metadata2 = ChunkMetadata (
107
+ document_id = "doc2" ,
108
+ chunk_id = "chunk2" ,
109
+ source = "test_source2" ,
110
+ metadata_token_count = 5 ,
111
+ )
112
+ chunk2 = Chunk (
113
+ content = "chunk from db2" ,
114
+ metadata = {"vector_db_id" : "db2" , "document_id" : "doc2" },
115
+ stored_chunk_id = "c2" ,
116
+ chunk_metadata = chunk_metadata2 ,
117
+ )
118
+
119
+ rag_tool .vector_io_api .query_chunks = AsyncMock (
120
+ side_effect = [
121
+ QueryChunksResponse (chunks = [chunk1 ], scores = [0.9 ]),
122
+ QueryChunksResponse (chunks = [chunk2 ], scores = [0.8 ]),
123
+ ]
124
+ )
125
+
126
+ result = await rag_tool .query (content = "test" , vector_db_ids = vector_db_ids )
127
+ returned_chunks = result .metadata ["chunks" ]
128
+ returned_scores = result .metadata ["scores" ]
129
+ returned_doc_ids = result .metadata ["document_ids" ]
130
+
131
+ assert returned_chunks == ["chunk from db1" , "chunk from db2" ]
132
+ assert returned_scores == (0.9 , 0.8 )
133
+ assert returned_doc_ids == ["doc1" , "doc2" ]
134
+
135
+ # Parse metadata from query result
136
+ def parse_metadata (s ):
137
+ import ast , re
138
+ match = re .search (r"Metadata:\s*(\{.*\})" , s )
139
+ if not match :
140
+ raise ValueError (f"No metadata found in string: { s } " )
141
+ return ast .literal_eval (match .group (1 ))
142
+
143
+ returned_metadata = [
144
+ parse_metadata (item .text )["vector_db_id" ]
145
+ for item in result .content
146
+ if "Metadata:" in item .text
147
+ ]
148
+ assert returned_metadata == ["db1" , "db2" ]
0 commit comments