Skip to content

Commit 7eac9c8

Browse files
committed
[ENH] add query config on collection configuration
1 parent 1784fc9 commit 7eac9c8

File tree

14 files changed

+169
-29
lines changed

14 files changed

+169
-29
lines changed

chromadb/api/collection_configuration.py

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import TypedDict, Dict, Any, Optional, cast, get_args
22
import json
3+
import copy
34
from chromadb.api.types import (
45
Space,
56
CollectionMetadata,
67
UpdateMetadata,
78
EmbeddingFunction,
9+
QueryConfig,
810
)
911
from chromadb.utils.embedding_functions import (
1012
known_embedding_functions,
@@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
4143
hnsw: Optional[HNSWConfiguration]
4244
spann: Optional[SpannConfiguration]
4345
embedding_function: Optional[EmbeddingFunction] # type: ignore
46+
query_embedding_function: Optional[EmbeddingFunction] # type: ignore
4447

4548

4649
def load_collection_configuration_from_json_str(
@@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
6467
spann_config = None
6568
ef_config = None
6669

70+
query_ef = None
71+
6772
# Process vector index configuration (HNSW or SPANN)
6873
if config_json_map.get("hnsw") is not None:
6974
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
@@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
100105
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101106
)
102107

108+
if config_json_map.get("query_config") is not None:
109+
query_config = config_json_map["query_config"]
110+
query_ef_config = copy.deepcopy(ef_config)
111+
query_ef = known_embedding_functions[ef_name]
112+
for k, v in query_config.items():
113+
query_ef_config["config"][k] = v
114+
115+
try:
116+
query_ef = query_ef.build_from_config(query_ef_config["config"]) # type: ignore
117+
except Exception as e:
118+
raise ValueError(
119+
f"Could not build query embedding function {query_ef_config['name']} from config {query_ef_config['config']}: {e}"
120+
)
103121
else:
104122
ef = None
105123

106124
return CollectionConfiguration(
107125
hnsw=hnsw_config,
108126
spann=spann_config,
109127
embedding_function=ef, # type: ignore
128+
query_embedding_function=query_ef, # type: ignore
110129
)
111130

112131

@@ -119,6 +138,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
119138
hnsw_config = config.get("hnsw")
120139
spann_config = config.get("spann")
121140
ef = config.get("embedding_function")
141+
query_ef = config.get("query_embedding_function")
122142
else:
123143
try:
124144
hnsw_config = config.get_parameter("hnsw").value
@@ -148,11 +168,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148168
if ef is None:
149169
ef = None
150170
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156171

157172
if ef is not None:
158173
try:
@@ -174,10 +189,28 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
174189
ef = None
175190
ef_config = {"type": "legacy"}
176191

192+
query_ef_config: Dict[str, Any] | None = None
193+
if query_ef is not None:
194+
try:
195+
query_ef_config = {
196+
"name": query_ef.name(),
197+
"type": "known",
198+
"config": query_ef.get_config(),
199+
}
200+
except Exception as e:
201+
warnings.warn(
202+
f"legacy query embedding function config: {e}",
203+
DeprecationWarning,
204+
stacklevel=2,
205+
)
206+
query_ef = None
207+
query_ef_config = {"type": "legacy"}
208+
177209
return {
178210
"hnsw": hnsw_config,
179211
"spann": spann_config,
180212
"embedding_function": ef_config,
213+
"query_embedding_function": query_ef_config,
181214
}
182215

183216

@@ -258,16 +291,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258291
hnsw: Optional[CreateHNSWConfiguration]
259292
spann: Optional[CreateSpannConfiguration]
260293
embedding_function: Optional[EmbeddingFunction] # type: ignore
261-
262-
263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
294+
query_config: Optional[QueryConfig]
271295

272296

273297
def create_collection_configuration_from_legacy_collection_metadata(
@@ -301,13 +325,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301325
return CreateCollectionConfiguration(hnsw=hnsw_config)
302326

303327

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311328
# TODO: make warnings prettier and add link to migration docs
312329
def load_create_collection_configuration_from_json(
313330
json_map: Dict[str, Any]
@@ -353,6 +370,7 @@ def create_collection_configuration_to_json(
353370
) -> Dict[str, Any]:
354371
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
355372
ef_config: Dict[str, Any] | None = None
373+
query_config: Dict[str, Any] | None = None
356374
hnsw_config = config.get("hnsw")
357375
spann_config = config.get("spann")
358376
if hnsw_config is not None:
@@ -389,6 +407,15 @@ def create_collection_configuration_to_json(
389407
"config": ef.get_config(),
390408
}
391409
register_embedding_function(type(ef)) # type: ignore
410+
411+
q = config.get("query_config")
412+
if q is not None:
413+
if q.name() == ef.name():
414+
query_config = q.get_config()
415+
else:
416+
raise ValueError(
417+
f"query config name {q.name()} does not match embedding function name {ef.name()}"
418+
)
392419
except Exception as e:
393420
warnings.warn(
394421
f"legacy embedding function config: {e}",
@@ -402,6 +429,7 @@ def create_collection_configuration_to_json(
402429
"hnsw": hnsw_config,
403430
"spann": spann_config,
404431
"embedding_function": ef_config,
432+
"query_config": query_config,
405433
}
406434

407435

@@ -473,6 +501,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473501
hnsw: Optional[UpdateHNSWConfiguration]
474502
spann: Optional[UpdateSpannConfiguration]
475503
embedding_function: Optional[EmbeddingFunction] # type: ignore
504+
query_config: Optional[QueryConfig]
476505

477506

478507
def update_collection_configuration_from_legacy_collection_metadata(
@@ -528,7 +557,9 @@ def update_collection_configuration_to_json(
528557
hnsw_config = config.get("hnsw")
529558
spann_config = config.get("spann")
530559
ef = config.get("embedding_function")
531-
if hnsw_config is None and spann_config is None and ef is None:
560+
q = config.get("query_config")
561+
query_config: Dict[str, Any] | None = None
562+
if hnsw_config is None and spann_config is None and ef is None and q is None:
532563
return {}
533564

534565
if hnsw_config is not None:
@@ -555,13 +586,21 @@ def update_collection_configuration_to_json(
555586
"config": ef.get_config(),
556587
}
557588
register_embedding_function(type(ef)) # type: ignore
589+
if q is not None:
590+
if q.name() == ef.name():
591+
query_config = q.get_config()
592+
else:
593+
raise ValueError(
594+
f"query config name {q.name()} does not match embedding function name {ef.name()}"
595+
)
558596
else:
559597
ef_config = None
560598

561599
return {
562600
"hnsw": hnsw_config,
563601
"spann": spann_config,
564602
"embedding_function": ef_config,
603+
"query_config": query_config,
565604
}
566605

567606

@@ -710,10 +749,26 @@ def overwrite_collection_configuration(
710749
else:
711750
updated_embedding_function = update_ef
712751

752+
query_ef = None
753+
if updated_embedding_function is not None:
754+
q = update_config.get("query_config")
755+
if q is not None:
756+
if q.name() != updated_embedding_function.name():
757+
raise ValueError(
758+
f"query config name {q.name()} does not match embedding function name {updated_embedding_function.name()}"
759+
)
760+
else:
761+
ef_config = copy.deepcopy(updated_embedding_function.get_config())
762+
query_config = q.get_config()
763+
for k, v in query_config.items():
764+
ef_config[k] = v
765+
query_ef = updated_embedding_function.build_from_config(ef_config)
766+
713767
return CollectionConfiguration(
714768
hnsw=updated_hnsw_config,
715769
spann=updated_spann_config,
716770
embedding_function=updated_embedding_function,
771+
query_embedding_function=query_ef,
717772
)
718773

719774

chromadb/api/models/CollectionCommon.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,21 +550,30 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
559568
if self._embedding_function is not None and not isinstance(
560569
self._embedding_function, ef.DefaultEmbeddingFunction
561570
):
562571
return self._embedding_function(input=input)
572+
if is_query:
573+
config_ef = self.configuration.get("query_embedding_function")
574+
if config_ef is not None:
575+
return config_ef(input=input)
576+
563577
config_ef = self.configuration.get("embedding_function")
564578
if config_ef is not None:
565579
return config_ef(input=input)

chromadb/api/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,16 @@ def is_legacy(self) -> bool:
672672
return False
673673

674674

675+
class QueryConfig:
676+
@abstractmethod
677+
def name(self) -> str:
678+
return NotImplemented
679+
680+
@abstractmethod
681+
def get_config(self) -> Dict[str, Any]:
682+
return NotImplemented
683+
684+
675685
def validate_embedding_function(
676686
embedding_function: EmbeddingFunction[Embeddable],
677687
) -> None:

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -232,6 +233,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
232233
"OllamaEmbeddingFunction",
233234
"InstructorEmbeddingFunction",
234235
"JinaEmbeddingFunction",
236+
"JinaQueryConfig",
235237
"MistralEmbeddingFunction",
236238
"VoyageAIEmbeddingFunction",
237239
"ONNXMiniLM_L6_V2",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
1+
from chromadb.api.types import (
2+
Embeddings,
3+
Documents,
4+
EmbeddingFunction,
5+
Space,
6+
QueryConfig,
7+
)
28
from chromadb.utils.embedding_functions.schemas import validate_config_schema
39
from typing import List, Dict, Any, Union, Optional
10+
from typing_extensions import override
411
import os
512
import numpy as np
613
import warnings
@@ -206,3 +213,17 @@ def validate_config(config: Dict[str, Any]) -> None:
206213
ValidationError: If the configuration does not match the schema
207214
"""
208215
validate_config_schema(config, "jina")
216+
217+
218+
class JinaQueryConfig(QueryConfig):
219+
def __init__(self, task: Optional[str] = None):
220+
self.task = task
221+
222+
@override
223+
def name(self) -> str:
224+
return "jina"
225+
226+
def get_config(self) -> Dict[str, Any]:
227+
return {
228+
"task": self.task,
229+
}

clients/js/packages/chromadb-core/src/generated/models.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ export namespace Api {
6767
export interface CollectionConfiguration {
6868
embedding_function?: Api.EmbeddingFunctionConfiguration | null;
6969
hnsw?: Api.HnswConfiguration | null;
70+
query_config?: unknown;
7071
spann?: Api.SpannConfiguration | null;
7172
}
7273

@@ -335,6 +336,7 @@ export namespace Api {
335336
export interface UpdateCollectionConfiguration {
336337
embedding_function?: Api.EmbeddingFunctionConfiguration | null;
337338
hnsw?: Api.UpdateHnswConfiguration | null;
339+
query_config?: unknown;
338340
spann?: Api.SpannConfiguration | null;
339341
}
340342

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export type Collection = {
3232
export type CollectionConfiguration = {
3333
embedding_function?: null | EmbeddingFunctionConfiguration;
3434
hnsw?: null | HnswConfiguration;
35+
query_config?: unknown;
3536
spann?: null | SpannConfiguration;
3637
};
3738

@@ -196,6 +197,7 @@ export type SpannConfiguration = {
196197
export type UpdateCollectionConfiguration = {
197198
embedding_function?: null | EmbeddingFunctionConfiguration;
198199
hnsw?: null | UpdateHnswConfiguration;
200+
query_config?: unknown;
199201
spann?: null | SpannConfiguration;
200202
};
201203

0 commit comments

Comments
 (0)