Skip to content

[ENH] add query config on collection configuration #4901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 78 additions & 23 deletions chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TypedDict, Dict, Any, Optional, cast, get_args
import json
import copy
from chromadb.api.types import (
Space,
CollectionMetadata,
UpdateMetadata,
EmbeddingFunction,
QueryConfig,
)
from chromadb.utils.embedding_functions import (
known_embedding_functions,
Expand Down Expand Up @@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
hnsw: Optional[HNSWConfiguration]
spann: Optional[SpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
query_embedding_function: Optional[EmbeddingFunction] # type: ignore


def load_collection_configuration_from_json_str(
Expand All @@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
spann_config = None
ef_config = None

query_ef = None

# Process vector index configuration (HNSW or SPANN)
if config_json_map.get("hnsw") is not None:
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
Expand Down Expand Up @@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
)

if config_json_map.get("query_config") is not None:
query_config = config_json_map["query_config"]
query_ef_config = copy.deepcopy(ef_config)
query_ef = known_embedding_functions[ef_name]
for k, v in query_config.items():
query_ef_config["config"][k] = v

try:
query_ef = query_ef.build_from_config(query_ef_config["config"]) # type: ignore
except Exception as e:
raise ValueError(
f"Could not build query embedding function {query_ef_config['name']} from config {query_ef_config['config']}: {e}"
)
else:
ef = None

return CollectionConfiguration(
hnsw=hnsw_config,
spann=spann_config,
embedding_function=ef, # type: ignore
query_embedding_function=query_ef, # type: ignore
)


Expand All @@ -119,6 +138,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
ef = config.get("embedding_function")
query_ef = config.get("query_embedding_function")
else:
try:
hnsw_config = config.get_parameter("hnsw").value
Expand Down Expand Up @@ -148,11 +168,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
if ef is None:
ef = None
ef_config = {"type": "legacy"}
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}

if ef is not None:
try:
Expand All @@ -174,10 +189,28 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
ef = None
ef_config = {"type": "legacy"}

query_ef_config: Dict[str, Any] | None = None
if query_ef is not None:
try:
query_ef_config = {
"name": query_ef.name(),
"type": "known",
"config": query_ef.get_config(),
}
except Exception as e:
warnings.warn(
f"legacy query embedding function config: {e}",
DeprecationWarning,
stacklevel=2,
)
query_ef = None
query_ef_config = {"type": "legacy"}

return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
"query_embedding_function": query_ef_config,
}


Expand Down Expand Up @@ -258,16 +291,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[CreateHNSWConfiguration]
spann: Optional[CreateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore


def load_collection_configuration_from_create_collection_configuration(
    config: CreateCollectionConfiguration,
) -> CollectionConfiguration:
    """Build a CollectionConfiguration from a create-time configuration.

    Copies the ``hnsw``, ``spann`` and ``embedding_function`` entries of
    ``config``. An entry that is absent from ``config`` is carried over as
    ``None`` because each value is read with ``dict.get``.
    """
    hnsw = config.get("hnsw")
    spann = config.get("spann")
    embedding_function = config.get("embedding_function")
    return CollectionConfiguration(
        hnsw=hnsw,
        spann=spann,
        embedding_function=embedding_function,
    )
query_config: Optional[QueryConfig]


def create_collection_configuration_from_legacy_collection_metadata(
Expand Down Expand Up @@ -301,13 +325,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
return CreateCollectionConfiguration(hnsw=hnsw_config)


def load_create_collection_configuration_from_json_str(
    json_str: str,
) -> CreateCollectionConfiguration:
    """Parse a JSON string into a CreateCollectionConfiguration.

    The string is decoded with ``json.loads`` and the decoded value is handed
    to ``load_create_collection_configuration_from_json``, whose result is
    returned unchanged.
    """
    return load_create_collection_configuration_from_json(json.loads(json_str))


# TODO: make warnings prettier and add link to migration docs
def load_create_collection_configuration_from_json(
json_map: Dict[str, Any]
Expand Down Expand Up @@ -353,6 +370,7 @@ def create_collection_configuration_to_json(
) -> Dict[str, Any]:
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
ef_config: Dict[str, Any] | None = None
query_config: Dict[str, Any] | None = None
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
if hnsw_config is not None:
Expand Down Expand Up @@ -389,6 +407,15 @@ def create_collection_configuration_to_json(
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore

q = config.get("query_config")
if q is not None:
if q.name() == ef.name():
query_config = q.get_config()
else:
raise ValueError(
f"query config name {q.name()} does not match embedding function name {ef.name()}"
)
except Exception as e:
warnings.warn(
f"legacy embedding function config: {e}",
Expand All @@ -402,6 +429,7 @@ def create_collection_configuration_to_json(
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
"query_config": query_config,
}


Expand Down Expand Up @@ -473,6 +501,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[UpdateHNSWConfiguration]
spann: Optional[UpdateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
query_config: Optional[QueryConfig]


def update_collection_configuration_from_legacy_collection_metadata(
Expand Down Expand Up @@ -528,7 +557,9 @@ def update_collection_configuration_to_json(
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
ef = config.get("embedding_function")
if hnsw_config is None and spann_config is None and ef is None:
q = config.get("query_config")
query_config: Dict[str, Any] | None = None
if hnsw_config is None and spann_config is None and ef is None and q is None:
return {}

if hnsw_config is not None:
Expand All @@ -555,13 +586,21 @@ def update_collection_configuration_to_json(
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore
if q is not None:
if q.name() == ef.name():
query_config = q.get_config()
else:
raise ValueError(
f"query config name {q.name()} does not match embedding function name {ef.name()}"
)
else:
ef_config = None

return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
"query_config": query_config,
}


Expand Down Expand Up @@ -710,10 +749,26 @@ def overwrite_collection_configuration(
else:
updated_embedding_function = update_ef

query_ef = None
if updated_embedding_function is not None:
q = update_config.get("query_config")
if q is not None:
if q.name() != updated_embedding_function.name():
raise ValueError(
f"query config name {q.name()} does not match embedding function name {updated_embedding_function.name()}"
)
else:
ef_config = copy.deepcopy(updated_embedding_function.get_config())
query_config = q.get_config()
for k, v in query_config.items():
ef_config[k] = v
query_ef = updated_embedding_function.build_from_config(ef_config)

return CollectionConfiguration(
hnsw=updated_hnsw_config,
spann=updated_spann_config,
embedding_function=updated_embedding_function,
query_embedding_function=query_ef,
)


Expand Down
24 changes: 19 additions & 5 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
# Prepare
if query_records["embeddings"] is None:
validate_record_set_for_embedding(record_set=query_records)
request_embeddings = self._embed_record_set(record_set=query_records)
request_embeddings = self._embed_record_set(
record_set=query_records, is_query=True
)
else:
request_embeddings = query_records["embeddings"]

Expand Down Expand Up @@ -531,7 +533,10 @@ def _update_model_after_modify_success(
)

def _embed_record_set(
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
self,
record_set: BaseRecordSet,
embeddable_fields: Optional[Set[str]] = None,
is_query: bool = False,
) -> Embeddings:
if embeddable_fields is None:
embeddable_fields = get_default_embeddable_record_set_fields()
Expand All @@ -545,21 +550,30 @@ def _embed_record_set(
"You must set a data loader on the collection if loading from URIs."
)
return self._embed(
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
is_query=is_query,
)
else:
return self._embed(input=record_set[field]) # type: ignore[literal-required]
return self._embed(
input=record_set[field], # type: ignore[literal-required]
is_query=is_query,
)
raise ValueError(
"Record does not contain any non-None fields that can be embedded."
f"Embeddable Fields: {embeddable_fields}"
f"Record Fields: {record_set}"
)

def _embed(self, input: Any) -> Embeddings:
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
if self._embedding_function is not None and not isinstance(
self._embedding_function, ef.DefaultEmbeddingFunction
):
return self._embedding_function(input=input)
if is_query:
config_ef = self.configuration.get("query_embedding_function")
if config_ef is not None:
return config_ef(input=input)

config_ef = self.configuration.get("embedding_function")
if config_ef is not None:
return config_ef(input=input)
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,16 @@ def is_legacy(self) -> bool:
return False


class QueryConfig:
    """Interface for query-time configuration objects.

    Subclasses are expected to override both ``name`` and ``get_config``.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced when instantiating. The base implementations therefore raise
    ``NotImplementedError`` so that a missing override fails loudly instead
    of handing the ``NotImplemented`` sentinel back to the caller.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the name identifying this query configuration.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return the query configuration as a dictionary.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError


def validate_embedding_function(
embedding_function: EmbeddingFunction[Embeddable],
) -> None:
Expand Down
2 changes: 2 additions & 0 deletions chromadb/utils/embedding_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
JinaQueryConfig,
)
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
VoyageAIEmbeddingFunction,
Expand Down Expand Up @@ -232,6 +233,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"JinaQueryConfig",
"MistralEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
Expand Down
23 changes: 22 additions & 1 deletion chromadb/utils/embedding_functions/jina_embedding_function.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.api.types import (
Embeddings,
Documents,
EmbeddingFunction,
Space,
QueryConfig,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, Optional
from typing_extensions import override
import os
import numpy as np
import warnings
Expand Down Expand Up @@ -206,3 +213,17 @@ def validate_config(config: Dict[str, Any]) -> None:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "jina")


class JinaQueryConfig(QueryConfig):
    """Query configuration reported under the name ``"jina"``.

    Holds an optional ``task`` value and exposes it through ``get_config``.
    """

    def __init__(self, task: Optional[str] = None):
        """Store the task value.

        Args:
            task: Value returned under the ``"task"`` key by ``get_config``.
                Defaults to ``None``.
        """
        self.task = task

    @override
    def name(self) -> str:
        """Return the fixed name ``"jina"``."""
        return "jina"

    # Marked @override like ``name``: both methods override QueryConfig.
    @override
    def get_config(self) -> Dict[str, Any]:
        """Return a dictionary with the single key ``"task"``.

        The key is present even when the stored task is ``None``.
        """
        return {
            "task": self.task,
        }
2 changes: 2 additions & 0 deletions clients/js/packages/chromadb-core/src/generated/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export namespace Api {
/**
 * Configuration of a collection. Every field is optional.
 * `query_config` is typed `unknown`, so it must be narrowed before use.
 */
export interface CollectionConfiguration {
embedding_function?: Api.EmbeddingFunctionConfiguration | null;
hnsw?: Api.HnswConfiguration | null;
query_config?: unknown;
spann?: Api.SpannConfiguration | null;
}

Expand Down Expand Up @@ -335,6 +336,7 @@ export namespace Api {
/**
 * Configuration fields accepted when updating a collection. Every field is
 * optional. `query_config` is typed `unknown`, so it must be narrowed before use.
 */
export interface UpdateCollectionConfiguration {
embedding_function?: Api.EmbeddingFunctionConfiguration | null;
hnsw?: Api.UpdateHnswConfiguration | null;
query_config?: unknown;
spann?: Api.SpannConfiguration | null;
}

Expand Down
2 changes: 2 additions & 0 deletions clients/new-js/packages/chromadb/src/api/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export type Collection = {
/**
 * Configuration of a collection. Every field is optional.
 * `query_config` is typed `unknown`, so it must be narrowed before use.
 */
export type CollectionConfiguration = {
embedding_function?: null | EmbeddingFunctionConfiguration;
hnsw?: null | HnswConfiguration;
query_config?: unknown;
spann?: null | SpannConfiguration;
};

Expand Down Expand Up @@ -196,6 +197,7 @@ export type SpannConfiguration = {
/**
 * Configuration fields accepted when updating a collection. Every field is
 * optional. `query_config` is typed `unknown`, so it must be narrowed before use.
 */
export type UpdateCollectionConfiguration = {
embedding_function?: null | EmbeddingFunctionConfiguration;
hnsw?: null | UpdateHnswConfiguration;
query_config?: unknown;
spann?: null | SpannConfiguration;
};

Expand Down
Loading
Loading