Skip to content

Commit c190fc0

Browse files
committed
[ENH] Add schema support to collection configuration
1 parent 480ead9 commit c190fc0

File tree

14 files changed

+294
-37
lines changed

14 files changed

+294
-37
lines changed

chromadb/api/collection_configuration.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
EmbeddingFunction,
99
QueryConfig,
1010
)
11+
from chromadb.base_types import CollectionSchema, ValueType
1112
from chromadb.utils.embedding_functions import (
1213
known_embedding_functions,
1314
register_embedding_function,
@@ -44,6 +45,7 @@ class CollectionConfiguration(TypedDict, total=True):
4445
spann: Optional[SpannConfiguration]
4546
embedding_function: Optional[EmbeddingFunction] # type: ignore
4647
query_embedding_function: Optional[EmbeddingFunction] # type: ignore
48+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
4749

4850

4951
def load_collection_configuration_from_json_str(
@@ -126,6 +128,7 @@ def load_collection_configuration_from_json(
126128
spann=spann_config,
127129
embedding_function=ef, # type: ignore
128130
query_embedding_function=query_ef, # type: ignore
131+
schema=config_json_map.get("schema"),
129132
)
130133

131134

@@ -139,6 +142,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
139142
spann_config = config.get("spann")
140143
ef = config.get("embedding_function")
141144
query_ef = config.get("query_embedding_function")
145+
schema = config.get("schema")
142146
else:
143147
try:
144148
hnsw_config = config.get_parameter("hnsw").value
@@ -211,6 +215,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
211215
"spann": spann_config,
212216
"embedding_function": ef_config,
213217
"query_embedding_function": query_ef_config,
218+
"schema": schema,
214219
}
215220

216221

@@ -292,6 +297,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
292297
spann: Optional[CreateSpannConfiguration]
293298
embedding_function: Optional[EmbeddingFunction] # type: ignore
294299
query_config: Optional[QueryConfig]
300+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
295301

296302

297303
def create_collection_configuration_from_legacy_collection_metadata(
@@ -430,6 +436,7 @@ def create_collection_configuration_to_json(
430436
"spann": spann_config,
431437
"embedding_function": ef_config,
432438
"query_config": query_config,
439+
"schema": config.get("schema"),
433440
}
434441

435442

@@ -502,6 +509,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
502509
spann: Optional[UpdateSpannConfiguration]
503510
embedding_function: Optional[EmbeddingFunction] # type: ignore
504511
query_config: Optional[QueryConfig]
512+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
505513

506514

507515
def update_collection_configuration_from_legacy_collection_metadata(
@@ -556,10 +564,17 @@ def update_collection_configuration_to_json(
556564
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
557565
hnsw_config = config.get("hnsw")
558566
spann_config = config.get("spann")
567+
schema = config.get("schema")
559568
ef = config.get("embedding_function")
560569
q = config.get("query_config")
561570
query_config: Dict[str, Any] | None = None
562-
if hnsw_config is None and spann_config is None and ef is None and q is None:
571+
if (
572+
hnsw_config is None
573+
and spann_config is None
574+
and ef is None
575+
and q is None
576+
and schema is None
577+
):
563578
return {}
564579

565580
if hnsw_config is not None:
@@ -601,6 +616,7 @@ def update_collection_configuration_to_json(
601616
"spann": spann_config,
602617
"embedding_function": ef_config,
603618
"query_config": query_config,
619+
"schema": schema,
604620
}
605621

606622

@@ -764,14 +780,40 @@ def overwrite_collection_configuration(
764780
ef_config[k] = v
765781
query_ef = updated_embedding_function.build_from_config(ef_config)
766782

783+
existing_schema = existing_config.get("schema")
784+
new_diff_schema = update_config.get("schema")
785+
updated_schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None
786+
if existing_schema is not None:
787+
if new_diff_schema is not None:
788+
updated_schema = overwrite_schema(existing_schema, new_diff_schema)
789+
else:
790+
updated_schema = existing_schema
791+
else:
792+
updated_schema = new_diff_schema
793+
767794
return CollectionConfiguration(
768795
hnsw=updated_hnsw_config,
769796
spann=updated_spann_config,
770797
embedding_function=updated_embedding_function,
771798
query_embedding_function=query_ef,
799+
schema=updated_schema,
772800
)
773801

774802

803+
def overwrite_schema(
804+
existing_schema: Dict[str, Dict[ValueType, CollectionSchema]],
805+
new_diff_schema: Dict[str, Dict[ValueType, CollectionSchema]],
806+
) -> Dict[str, Dict[ValueType, CollectionSchema]]:
807+
"""Overwrite a schema with a new configuration"""
808+
for new_key, new_value in new_diff_schema.items():
809+
if new_key in existing_schema:
810+
for value_type, new_schema in new_value.items():
811+
existing_schema[new_key][value_type] = new_schema
812+
else:
813+
existing_schema[new_key] = new_value
814+
return existing_schema
815+
816+
775817
def validate_embedding_function_conflict_on_create(
776818
embedding_function: Optional[EmbeddingFunction], # type: ignore
777819
configuration_ef: Optional[EmbeddingFunction], # type: ignore

chromadb/api/models/AsyncCollection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,22 @@ async def add(
6060
ValueError: If you provide an id that already exists
6161
6262
"""
63-
add_request = self._validate_and_prepare_add_request(
63+
64+
curr_schema = self._model.get_configuration().get("schema")
65+
66+
add_request, new_attributes = self._validate_and_prepare_add_request(
6467
ids=ids,
6568
embeddings=embeddings,
6669
metadatas=metadatas,
6770
documents=documents,
6871
images=images,
6972
uris=uris,
73+
schema=curr_schema,
7074
)
7175

76+
if len(new_attributes.keys()) > 0:
77+
await self.modify(configuration={"schema": new_attributes})
78+
7279
await self._client._add(
7380
collection_id=self.id,
7481
ids=add_request["ids"],
@@ -313,15 +320,20 @@ async def update(
313320
Returns:
314321
None
315322
"""
316-
update_request = self._validate_and_prepare_update_request(
323+
curr_schema = self._model.get_configuration().get("schema")
324+
update_request, new_attributes = self._validate_and_prepare_update_request(
317325
ids=ids,
318326
embeddings=embeddings,
319327
metadatas=metadatas,
320328
documents=documents,
321329
images=images,
322330
uris=uris,
331+
schema=curr_schema,
323332
)
324333

334+
if len(new_attributes.keys()) > 0:
335+
await self.modify(configuration={"schema": new_attributes})
336+
325337
await self._client._update(
326338
collection_id=self.id,
327339
ids=update_request["ids"],
@@ -358,14 +370,18 @@ async def upsert(
358370
Returns:
359371
None
360372
"""
361-
upsert_request = self._validate_and_prepare_upsert_request(
373+
curr_schema = self._model.get_configuration().get("schema")
374+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
362375
ids=ids,
363376
embeddings=embeddings,
364377
metadatas=metadatas,
365378
documents=documents,
366379
images=images,
367380
uris=uris,
381+
schema=curr_schema,
368382
)
383+
if len(new_attributes.keys()) > 0:
384+
await self.modify(configuration={"schema": new_attributes})
369385

370386
await self._client._upsert(
371387
collection_id=self.id,

chromadb/api/models/Collection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,20 @@ def add(
7777
7878
"""
7979

80-
add_request = self._validate_and_prepare_add_request(
80+
curr_schema = self._model.get_configuration().get("schema")
81+
add_request, new_attributes = self._validate_and_prepare_add_request(
8182
ids=ids,
8283
embeddings=embeddings,
8384
metadatas=metadatas,
8485
documents=documents,
8586
images=images,
8687
uris=uris,
88+
schema=curr_schema,
8789
)
8890

91+
if len(new_attributes.keys()) > 0:
92+
self.modify(configuration={"schema": new_attributes})
93+
8994
self._client._add(
9095
collection_id=self.id,
9196
ids=add_request["ids"],
@@ -255,6 +260,7 @@ def modify(
255260
# Note there is a race condition here where the metadata can be updated
256261
# but another thread sees the cached local metadata.
257262
# TODO: fixme
263+
258264
self._client._modify(
259265
id=self.id,
260266
new_name=name,
@@ -317,15 +323,20 @@ def update(
317323
Returns:
318324
None
319325
"""
320-
update_request = self._validate_and_prepare_update_request(
326+
curr_schema = self._model.get_configuration().get("schema")
327+
update_request, new_attributes = self._validate_and_prepare_update_request(
321328
ids=ids,
322329
embeddings=embeddings,
323330
metadatas=metadatas,
324331
documents=documents,
325332
images=images,
326333
uris=uris,
334+
schema=curr_schema,
327335
)
328336

337+
if len(new_attributes.keys()) > 0:
338+
self.modify(configuration={"schema": new_attributes})
339+
329340
self._client._update(
330341
collection_id=self.id,
331342
ids=update_request["ids"],
@@ -362,15 +373,20 @@ def upsert(
362373
Returns:
363374
None
364375
"""
365-
upsert_request = self._validate_and_prepare_upsert_request(
376+
curr_schema = self._model.get_configuration().get("schema")
377+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
366378
ids=ids,
367379
embeddings=embeddings,
368380
metadatas=metadatas,
369381
documents=documents,
370382
images=images,
371383
uris=uris,
384+
schema=curr_schema,
372385
)
373386

387+
if len(new_attributes.keys()) > 0:
388+
self.modify(configuration={"schema": new_attributes})
389+
374390
self._client._upsert(
375391
collection_id=self.id,
376392
ids=upsert_request["ids"],

0 commit comments

Comments
 (0)