1
1
from typing import TypedDict , Dict , Any , Optional , cast , get_args
2
2
import json
3
+ import copy
3
4
from chromadb .api .types import (
4
5
Space ,
5
6
CollectionMetadata ,
6
7
UpdateMetadata ,
7
8
EmbeddingFunction ,
9
+ QueryConfig ,
8
10
)
9
11
from chromadb .utils .embedding_functions import (
10
12
known_embedding_functions ,
@@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
41
43
hnsw : Optional [HNSWConfiguration ]
42
44
spann : Optional [SpannConfiguration ]
43
45
embedding_function : Optional [EmbeddingFunction ] # type: ignore
46
+ query_embedding_function : Optional [EmbeddingFunction ] # type: ignore
44
47
45
48
46
49
def load_collection_configuration_from_json_str (
@@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
64
67
spann_config = None
65
68
ef_config = None
66
69
70
+ query_ef = None
71
+
67
72
# Process vector index configuration (HNSW or SPANN)
68
73
if config_json_map .get ("hnsw" ) is not None :
69
74
hnsw_config = cast (HNSWConfiguration , config_json_map ["hnsw" ])
@@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
100
105
f"Could not build embedding function { ef_config ['name' ]} from config { ef_config ['config' ]} : { e } "
101
106
)
102
107
108
+ if config_json_map .get ("query_config" ) is not None :
109
+ query_config = config_json_map ["query_config" ]
110
+ query_ef_config = copy .deepcopy (ef_config )
111
+ query_ef = known_embedding_functions [ef_name ]
112
+ for k , v in query_config .items ():
113
+ query_ef_config ["config" ][k ] = v
114
+
115
+ try :
116
+ query_ef = query_ef .build_from_config (query_ef_config ["config" ]) # type: ignore
117
+ except Exception as e :
118
+ raise ValueError (
119
+ f"Could not build query embedding function { query_ef_config ['name' ]} from config { query_ef_config ['config' ]} : { e } "
120
+ )
103
121
else :
104
122
ef = None
105
123
106
124
return CollectionConfiguration (
107
125
hnsw = hnsw_config ,
108
126
spann = spann_config ,
109
127
embedding_function = ef , # type: ignore
128
+ query_embedding_function = query_ef , # type: ignore
110
129
)
111
130
112
131
@@ -119,6 +138,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
119
138
hnsw_config = config .get ("hnsw" )
120
139
spann_config = config .get ("spann" )
121
140
ef = config .get ("embedding_function" )
141
+ query_ef = config .get ("query_embedding_function" )
122
142
else :
123
143
try :
124
144
hnsw_config = config .get_parameter ("hnsw" ).value
@@ -148,11 +168,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148
168
if ef is None :
149
169
ef = None
150
170
ef_config = {"type" : "legacy" }
151
- return {
152
- "hnsw" : hnsw_config ,
153
- "spann" : spann_config ,
154
- "embedding_function" : ef_config ,
155
- }
156
171
157
172
if ef is not None :
158
173
try :
@@ -174,10 +189,28 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
174
189
ef = None
175
190
ef_config = {"type" : "legacy" }
176
191
192
+ query_ef_config : Dict [str , Any ] | None = None
193
+ if query_ef is not None :
194
+ try :
195
+ query_ef_config = {
196
+ "name" : query_ef .name (),
197
+ "type" : "known" ,
198
+ "config" : query_ef .get_config (),
199
+ }
200
+ except Exception as e :
201
+ warnings .warn (
202
+ f"legacy query embedding function config: { e } " ,
203
+ DeprecationWarning ,
204
+ stacklevel = 2 ,
205
+ )
206
+ query_ef = None
207
+ query_ef_config = {"type" : "legacy" }
208
+
177
209
return {
178
210
"hnsw" : hnsw_config ,
179
211
"spann" : spann_config ,
180
212
"embedding_function" : ef_config ,
213
+ "query_embedding_function" : query_ef_config ,
181
214
}
182
215
183
216
@@ -258,16 +291,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258
291
hnsw : Optional [CreateHNSWConfiguration ]
259
292
spann : Optional [CreateSpannConfiguration ]
260
293
embedding_function : Optional [EmbeddingFunction ] # type: ignore
261
-
262
-
263
- def load_collection_configuration_from_create_collection_configuration (
264
- config : CreateCollectionConfiguration ,
265
- ) -> CollectionConfiguration :
266
- return CollectionConfiguration (
267
- hnsw = config .get ("hnsw" ),
268
- spann = config .get ("spann" ),
269
- embedding_function = config .get ("embedding_function" ),
270
- )
294
+ query_config : Optional [QueryConfig ]
271
295
272
296
273
297
def create_collection_configuration_from_legacy_collection_metadata (
@@ -301,13 +325,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301
325
return CreateCollectionConfiguration (hnsw = hnsw_config )
302
326
303
327
304
- def load_create_collection_configuration_from_json_str (
305
- json_str : str ,
306
- ) -> CreateCollectionConfiguration :
307
- json_map = json .loads (json_str )
308
- return load_create_collection_configuration_from_json (json_map )
309
-
310
-
311
328
# TODO: make warnings prettier and add link to migration docs
312
329
def load_create_collection_configuration_from_json (
313
330
json_map : Dict [str , Any ]
@@ -353,6 +370,7 @@ def create_collection_configuration_to_json(
353
370
) -> Dict [str , Any ]:
354
371
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
355
372
ef_config : Dict [str , Any ] | None = None
373
+ query_config : Dict [str , Any ] | None = None
356
374
hnsw_config = config .get ("hnsw" )
357
375
spann_config = config .get ("spann" )
358
376
if hnsw_config is not None :
@@ -389,6 +407,15 @@ def create_collection_configuration_to_json(
389
407
"config" : ef .get_config (),
390
408
}
391
409
register_embedding_function (type (ef )) # type: ignore
410
+
411
+ q = config .get ("query_config" )
412
+ if q is not None :
413
+ if q .name () == ef .name ():
414
+ query_config = q .get_config ()
415
+ else :
416
+ raise ValueError (
417
+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
418
+ )
392
419
except Exception as e :
393
420
warnings .warn (
394
421
f"legacy embedding function config: { e } " ,
@@ -402,6 +429,7 @@ def create_collection_configuration_to_json(
402
429
"hnsw" : hnsw_config ,
403
430
"spann" : spann_config ,
404
431
"embedding_function" : ef_config ,
432
+ "query_config" : query_config ,
405
433
}
406
434
407
435
@@ -473,6 +501,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473
501
hnsw : Optional [UpdateHNSWConfiguration ]
474
502
spann : Optional [UpdateSpannConfiguration ]
475
503
embedding_function : Optional [EmbeddingFunction ] # type: ignore
504
+ query_config : Optional [QueryConfig ]
476
505
477
506
478
507
def update_collection_configuration_from_legacy_collection_metadata (
@@ -528,7 +557,9 @@ def update_collection_configuration_to_json(
528
557
hnsw_config = config .get ("hnsw" )
529
558
spann_config = config .get ("spann" )
530
559
ef = config .get ("embedding_function" )
531
- if hnsw_config is None and spann_config is None and ef is None :
560
+ q = config .get ("query_config" )
561
+ query_config : Dict [str , Any ] | None = None
562
+ if hnsw_config is None and spann_config is None and ef is None and q is None :
532
563
return {}
533
564
534
565
if hnsw_config is not None :
@@ -555,13 +586,21 @@ def update_collection_configuration_to_json(
555
586
"config" : ef .get_config (),
556
587
}
557
588
register_embedding_function (type (ef )) # type: ignore
589
+ if q is not None :
590
+ if q .name () == ef .name ():
591
+ query_config = q .get_config ()
592
+ else :
593
+ raise ValueError (
594
+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
595
+ )
558
596
else :
559
597
ef_config = None
560
598
561
599
return {
562
600
"hnsw" : hnsw_config ,
563
601
"spann" : spann_config ,
564
602
"embedding_function" : ef_config ,
603
+ "query_config" : query_config ,
565
604
}
566
605
567
606
@@ -710,10 +749,26 @@ def overwrite_collection_configuration(
710
749
else :
711
750
updated_embedding_function = update_ef
712
751
752
+ query_ef = None
753
+ if updated_embedding_function is not None :
754
+ q = update_config .get ("query_config" )
755
+ if q is not None :
756
+ if q .name () != updated_embedding_function .name ():
757
+ raise ValueError (
758
+ f"query config name { q .name ()} does not match embedding function name { updated_embedding_function .name ()} "
759
+ )
760
+ else :
761
+ ef_config = copy .deepcopy (updated_embedding_function .get_config ())
762
+ query_config = q .get_config ()
763
+ for k , v in query_config .items ():
764
+ ef_config [k ] = v
765
+ query_ef = updated_embedding_function .build_from_config (ef_config )
766
+
713
767
return CollectionConfiguration (
714
768
hnsw = updated_hnsw_config ,
715
769
spann = updated_spann_config ,
716
770
embedding_function = updated_embedding_function ,
771
+ query_embedding_function = query_ef ,
717
772
)
718
773
719
774
0 commit comments