Skip to content

[ENH] Add schema support to collection configuration #4932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: jai/nomic-ef
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
EmbeddingFunction,
QueryConfig,
)
from chromadb.base_types import CollectionSchema, ValueType
from chromadb.utils.embedding_functions import (
known_embedding_functions,
register_embedding_function,
Expand Down Expand Up @@ -44,6 +45,7 @@ class CollectionConfiguration(TypedDict, total=True):
spann: Optional[SpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
query_embedding_function: Optional[EmbeddingFunction] # type: ignore
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]


def load_collection_configuration_from_json_str(
Expand Down Expand Up @@ -126,6 +128,7 @@ def load_collection_configuration_from_json(
spann=spann_config,
embedding_function=ef, # type: ignore
query_embedding_function=query_ef, # type: ignore
schema=config_json_map.get("schema"),
)


Expand All @@ -139,6 +142,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
spann_config = config.get("spann")
ef = config.get("embedding_function")
query_ef = config.get("query_embedding_function")
schema = config.get("schema")
else:
try:
hnsw_config = config.get_parameter("hnsw").value
Expand Down Expand Up @@ -211,6 +215,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
"spann": spann_config,
"embedding_function": ef_config,
"query_embedding_function": query_ef_config,
"schema": schema,
}


Expand Down Expand Up @@ -292,6 +297,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
spann: Optional[CreateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
query_config: Optional[QueryConfig]
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]


def create_collection_configuration_from_legacy_collection_metadata(
Expand Down Expand Up @@ -430,6 +436,7 @@ def create_collection_configuration_to_json(
"spann": spann_config,
"embedding_function": ef_config,
"query_config": query_config,
"schema": config.get("schema"),
}


Expand Down Expand Up @@ -502,6 +509,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
spann: Optional[UpdateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
query_config: Optional[QueryConfig]
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]


def update_collection_configuration_from_legacy_collection_metadata(
Expand Down Expand Up @@ -556,10 +564,17 @@ def update_collection_configuration_to_json(
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
schema = config.get("schema")
ef = config.get("embedding_function")
q = config.get("query_config")
query_config: Dict[str, Any] | None = None
if hnsw_config is None and spann_config is None and ef is None and q is None:
if (
hnsw_config is None
and spann_config is None
and ef is None
and q is None
and schema is None
):
return {}

if hnsw_config is not None:
Expand Down Expand Up @@ -601,6 +616,7 @@ def update_collection_configuration_to_json(
"spann": spann_config,
"embedding_function": ef_config,
"query_config": query_config,
"schema": schema,
}


Expand Down Expand Up @@ -764,14 +780,40 @@ def overwrite_collection_configuration(
ef_config[k] = v
query_ef = updated_embedding_function.build_from_config(ef_config)

existing_schema = existing_config.get("schema")
new_diff_schema = update_config.get("schema")
updated_schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None
if existing_schema is not None:
if new_diff_schema is not None:
updated_schema = overwrite_schema(existing_schema, new_diff_schema)
else:
updated_schema = existing_schema
else:
updated_schema = new_diff_schema

return CollectionConfiguration(
hnsw=updated_hnsw_config,
spann=updated_spann_config,
embedding_function=updated_embedding_function,
query_embedding_function=query_ef,
schema=updated_schema,
)


def overwrite_schema(
existing_schema: Dict[str, Dict[ValueType, CollectionSchema]],
new_diff_schema: Dict[str, Dict[ValueType, CollectionSchema]],
) -> Dict[str, Dict[ValueType, CollectionSchema]]:
"""Overwrite a schema with a new configuration"""
for new_key, new_value in new_diff_schema.items():
if new_key in existing_schema:
for value_type, new_schema in new_value.items():
existing_schema[new_key][value_type] = new_schema
else:
existing_schema[new_key] = new_value
return existing_schema


def validate_embedding_function_conflict_on_create(
embedding_function: Optional[EmbeddingFunction], # type: ignore
configuration_ef: Optional[EmbeddingFunction], # type: ignore
Expand Down
22 changes: 19 additions & 3 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ async def add(
ValueError: If you provide an id that already exists

"""
add_request = self._validate_and_prepare_add_request(

curr_schema = self._model.get_configuration().get("schema")

add_request, new_attributes = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)

if len(new_attributes.keys()) > 0:
await self.modify(configuration={"schema": new_attributes})

await self._client._add(
collection_id=self.id,
ids=add_request["ids"],
Expand Down Expand Up @@ -313,15 +320,20 @@ async def update(
Returns:
None
"""
update_request = self._validate_and_prepare_update_request(
curr_schema = self._model.get_configuration().get("schema")
update_request, new_attributes = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)

if len(new_attributes.keys()) > 0:
await self.modify(configuration={"schema": new_attributes})

await self._client._update(
collection_id=self.id,
ids=update_request["ids"],
Expand Down Expand Up @@ -358,14 +370,18 @@ async def upsert(
Returns:
None
"""
upsert_request = self._validate_and_prepare_upsert_request(
curr_schema = self._model.get_configuration().get("schema")
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)
if len(new_attributes.keys()) > 0:
await self.modify(configuration={"schema": new_attributes})

await self._client._upsert(
collection_id=self.id,
Expand Down
22 changes: 19 additions & 3 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,20 @@ def add(

"""

add_request = self._validate_and_prepare_add_request(
curr_schema = self._model.get_configuration().get("schema")
add_request, new_attributes = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)

if len(new_attributes.keys()) > 0:
self.modify(configuration={"schema": new_attributes})

self._client._add(
collection_id=self.id,
ids=add_request["ids"],
Expand Down Expand Up @@ -255,6 +260,7 @@ def modify(
# Note there is a race condition here where the metadata can be updated
# but another thread sees the cached local metadata.
# TODO: fixme

self._client._modify(
id=self.id,
new_name=name,
Expand Down Expand Up @@ -317,15 +323,20 @@ def update(
Returns:
None
"""
update_request = self._validate_and_prepare_update_request(
curr_schema = self._model.get_configuration().get("schema")
update_request, new_attributes = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)

if len(new_attributes.keys()) > 0:
self.modify(configuration={"schema": new_attributes})

self._client._update(
collection_id=self.id,
ids=update_request["ids"],
Expand Down Expand Up @@ -362,15 +373,20 @@ def upsert(
Returns:
None
"""
upsert_request = self._validate_and_prepare_upsert_request(
curr_schema = self._model.get_configuration().get("schema")
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
schema=curr_schema,
)

if len(new_attributes.keys()) > 0:
self.modify(configuration={"schema": new_attributes})

self._client._upsert(
collection_id=self.id,
ids=upsert_request["ids"],
Expand Down
Loading
Loading