Skip to content

Commit d438fa5

Browse files
committed
[ENH] Add schema support to collection configuration
1 parent cc85a9a commit d438fa5

File tree

13 files changed

+301
-37
lines changed

13 files changed

+301
-37
lines changed

chromadb/api/collection_configuration.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
EmbeddingFunction,
99
QueryConfig,
1010
)
11+
from chromadb.base_types import CollectionSchema
1112
from chromadb.utils.embedding_functions import (
1213
known_embedding_functions,
1314
register_embedding_function,
@@ -44,6 +45,7 @@ class CollectionConfiguration(TypedDict, total=True):
4445
spann: Optional[SpannConfiguration]
4546
embedding_function: Optional[EmbeddingFunction] # type: ignore
4647
query_embedding_function: Optional[EmbeddingFunction] # type: ignore
48+
schema: Optional[Dict[str, CollectionSchema]]
4749

4850

4951
def load_collection_configuration_from_json_str(
@@ -126,6 +128,7 @@ def load_collection_configuration_from_json(
126128
spann=spann_config,
127129
embedding_function=ef, # type: ignore
128130
query_embedding_function=query_ef, # type: ignore
131+
schema=config_json_map.get("schema"),
129132
)
130133

131134

@@ -278,6 +281,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
278281
spann: Optional[CreateSpannConfiguration]
279282
embedding_function: Optional[EmbeddingFunction] # type: ignore
280283
query_config: Optional[QueryConfig]
284+
schema: Optional[Dict[str, CollectionSchema]]
281285

282286

283287
def create_collection_configuration_from_legacy_collection_metadata(
@@ -416,6 +420,7 @@ def create_collection_configuration_to_json(
416420
"spann": spann_config,
417421
"embedding_function": ef_config,
418422
"query_config": query_config,
423+
"schema": config.get("schema"),
419424
}
420425

421426

@@ -488,6 +493,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
488493
spann: Optional[UpdateSpannConfiguration]
489494
embedding_function: Optional[EmbeddingFunction] # type: ignore
490495
query_config: Optional[QueryConfig]
496+
schema: Optional[Dict[str, CollectionSchema]]
491497

492498

493499
def update_collection_configuration_from_legacy_collection_metadata(
@@ -587,6 +593,7 @@ def update_collection_configuration_to_json(
587593
"spann": spann_config,
588594
"embedding_function": ef_config,
589595
"query_config": query_config,
596+
"schema": config.get("schema"),
590597
}
591598

592599

@@ -750,14 +757,34 @@ def overwrite_collection_configuration(
750757
ef_config[k] = v
751758
query_ef = updated_embedding_function.build_from_config(ef_config)
752759

760+
existing_schema = existing_config.get("schema")
761+
new_diff_schema = update_config.get("schema")
762+
updated_schema: Optional[Dict[str, CollectionSchema]] = None
763+
if existing_schema is not None:
764+
if new_diff_schema is not None:
765+
updated_schema = overwrite_schema(existing_schema, new_diff_schema)
766+
else:
767+
updated_schema = existing_schema
768+
else:
769+
updated_schema = new_diff_schema
770+
753771
return CollectionConfiguration(
754772
hnsw=updated_hnsw_config,
755773
spann=updated_spann_config,
756774
embedding_function=updated_embedding_function,
757775
query_embedding_function=query_ef,
776+
schema=updated_schema,
758777
)
759778

760779

780+
def overwrite_schema(
781+
existing_schema: Dict[str, CollectionSchema],
782+
new_diff_schema: Dict[str, CollectionSchema],
783+
) -> Dict[str, CollectionSchema]:
784+
"""Overwrite a schema with a new configuration"""
785+
return {**existing_schema, **new_diff_schema}
786+
787+
761788
def validate_embedding_function_conflict_on_create(
762789
embedding_function: Optional[EmbeddingFunction], # type: ignore
763790
configuration_ef: Optional[EmbeddingFunction], # type: ignore

chromadb/api/models/AsyncCollection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ async def add(
6060
ValueError: If you provide an id that already exists
6161
6262
"""
63+
64+
# get metadatas, if they don't exist in schema yet add them, and do collection.modify()
65+
# also validate that the metadatas are valid for the schema
66+
6367
add_request = self._validate_and_prepare_add_request(
6468
ids=ids,
6569
embeddings=embeddings,

chromadb/api/models/Collection.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,23 @@ def add(
7777
7878
"""
7979

80-
add_request = self._validate_and_prepare_add_request(
80+
curr_schema = self._model.get_configuration().get("schema")
81+
add_request, new_attributes = self._validate_and_prepare_add_request(
8182
ids=ids,
8283
embeddings=embeddings,
8384
metadatas=metadatas,
8485
documents=documents,
8586
images=images,
8687
uris=uris,
88+
schema=curr_schema,
8789
)
8890

91+
if len(new_attributes.keys()) > 0:
92+
if curr_schema is None:
93+
curr_schema = {}
94+
curr_schema = {**curr_schema, **new_attributes}
95+
self.modify(configuration={"schema": curr_schema})
96+
8997
self._client._add(
9098
collection_id=self.id,
9199
ids=add_request["ids"],
@@ -317,7 +325,8 @@ def update(
317325
Returns:
318326
None
319327
"""
320-
update_request = self._validate_and_prepare_update_request(
328+
curr_schema = self._model.get_configuration().get("schema")
329+
update_request, new_attributes = self._validate_and_prepare_update_request(
321330
ids=ids,
322331
embeddings=embeddings,
323332
metadatas=metadatas,
@@ -326,6 +335,12 @@ def update(
326335
uris=uris,
327336
)
328337

338+
if len(new_attributes.keys()) > 0:
339+
if curr_schema is None:
340+
curr_schema = {}
341+
curr_schema = {**curr_schema, **new_attributes}
342+
self.modify(configuration={"schema": curr_schema})
343+
329344
self._client._update(
330345
collection_id=self.id,
331346
ids=update_request["ids"],
@@ -362,7 +377,8 @@ def upsert(
362377
Returns:
363378
None
364379
"""
365-
upsert_request = self._validate_and_prepare_upsert_request(
380+
curr_schema = self._model.get_configuration().get("schema")
381+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
366382
ids=ids,
367383
embeddings=embeddings,
368384
metadatas=metadatas,
@@ -371,6 +387,12 @@ def upsert(
371387
uris=uris,
372388
)
373389

390+
if len(new_attributes.keys()) > 0:
391+
if curr_schema is None:
392+
curr_schema = {}
393+
curr_schema = {**curr_schema, **new_attributes}
394+
self.modify(configuration={"schema": curr_schema})
395+
374396
self._client._upsert(
375397
collection_id=self.id,
376398
ids=upsert_request["ids"],

chromadb/api/models/CollectionCommon.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
TypeVar,
1111
Union,
1212
cast,
13+
Tuple,
1314
)
1415
from chromadb.types import Metadata
1516
import numpy as np
1617
from uuid import UUID
17-
1818
import chromadb.utils.embedding_functions as ef
19+
from chromadb.base_types import CollectionSchema
20+
1921
from chromadb.api.types import (
2022
URI,
2123
URIs,
@@ -47,6 +49,7 @@
4749
maybe_cast_one_to_many,
4850
normalize_base_record_set,
4951
normalize_insert_record_set,
52+
update_schema_with_insert_record_set,
5053
validate_base_record_set,
5154
validate_ids,
5255
validate_include,
@@ -204,7 +207,8 @@ def _validate_and_prepare_add_request(
204207
documents: Optional[OneOrMany[Document]],
205208
images: Optional[OneOrMany[Image]],
206209
uris: Optional[OneOrMany[URI]],
207-
) -> AddRequest:
210+
schema: Optional[Dict[str, CollectionSchema]],
211+
) -> Tuple[AddRequest, Dict[str, CollectionSchema]]:
208212
# Unpack
209213
add_records = normalize_insert_record_set(
210214
ids=ids,
@@ -216,22 +220,29 @@ def _validate_and_prepare_add_request(
216220
)
217221

218222
# Validate
219-
validate_insert_record_set(record_set=add_records)
223+
validate_insert_record_set(record_set=add_records, schema=schema)
220224
validate_record_set_contains_any(record_set=add_records, contains_any={"ids"})
221225

226+
new_attributes = update_schema_with_insert_record_set(
227+
record_set=add_records, schema=schema
228+
)
229+
222230
# Prepare
223231
if add_records["embeddings"] is None:
224232
validate_record_set_for_embedding(record_set=add_records)
225233
add_embeddings = self._embed_record_set(record_set=add_records)
226234
else:
227235
add_embeddings = add_records["embeddings"]
228236

229-
return AddRequest(
230-
ids=add_records["ids"],
231-
embeddings=add_embeddings,
232-
metadatas=add_records["metadatas"],
233-
documents=add_records["documents"],
234-
uris=add_records["uris"],
237+
return (
238+
AddRequest(
239+
ids=add_records["ids"],
240+
embeddings=add_embeddings,
241+
metadatas=add_records["metadatas"],
242+
documents=add_records["documents"],
243+
uris=add_records["uris"],
244+
),
245+
new_attributes,
235246
)
236247

237248
@validation_context("get")
@@ -350,7 +361,8 @@ def _validate_and_prepare_update_request(
350361
documents: Optional[OneOrMany[Document]],
351362
images: Optional[OneOrMany[Image]],
352363
uris: Optional[OneOrMany[URI]],
353-
) -> UpdateRequest:
364+
schema: Optional[Dict[str, CollectionSchema]],
365+
) -> Tuple[UpdateRequest, Dict[str, CollectionSchema]]:
354366
# Unpack
355367
update_records = normalize_insert_record_set(
356368
ids=ids,
@@ -362,7 +374,10 @@ def _validate_and_prepare_update_request(
362374
)
363375

364376
# Validate
365-
validate_insert_record_set(record_set=update_records)
377+
validate_insert_record_set(record_set=update_records, schema=schema)
378+
new_attributes = update_schema_with_insert_record_set(
379+
record_set=update_records, schema=schema
380+
)
366381

367382
# Prepare
368383
if update_records["embeddings"] is None:
@@ -380,12 +395,15 @@ def _validate_and_prepare_update_request(
380395
else:
381396
update_embeddings = update_records["embeddings"]
382397

383-
return UpdateRequest(
384-
ids=update_records["ids"],
385-
embeddings=update_embeddings,
386-
metadatas=update_records["metadatas"],
387-
documents=update_records["documents"],
388-
uris=update_records["uris"],
398+
return (
399+
UpdateRequest(
400+
ids=update_records["ids"],
401+
embeddings=update_embeddings,
402+
metadatas=update_records["metadatas"],
403+
documents=update_records["documents"],
404+
uris=update_records["uris"],
405+
),
406+
new_attributes,
389407
)
390408

391409
@validation_context("upsert")
@@ -402,7 +420,8 @@ def _validate_and_prepare_upsert_request(
402420
documents: Optional[OneOrMany[Document]] = None,
403421
images: Optional[OneOrMany[Image]] = None,
404422
uris: Optional[OneOrMany[URI]] = None,
405-
) -> UpsertRequest:
423+
schema: Optional[Dict[str, CollectionSchema]] = None,
424+
) -> Tuple[UpsertRequest, Dict[str, CollectionSchema]]:
406425
# Unpack
407426
upsert_records = normalize_insert_record_set(
408427
ids=ids,
@@ -414,8 +433,10 @@ def _validate_and_prepare_upsert_request(
414433
)
415434

416435
# Validate
417-
validate_insert_record_set(record_set=upsert_records)
418-
436+
validate_insert_record_set(record_set=upsert_records, schema=schema)
437+
new_attributes = update_schema_with_insert_record_set(
438+
record_set=upsert_records, schema=schema
439+
)
419440
# Prepare
420441
if upsert_records["embeddings"] is None:
421442
validate_record_set_for_embedding(
@@ -425,12 +446,15 @@ def _validate_and_prepare_upsert_request(
425446
else:
426447
upsert_embeddings = upsert_records["embeddings"]
427448

428-
return UpsertRequest(
429-
ids=upsert_records["ids"],
430-
metadatas=upsert_records["metadatas"],
431-
embeddings=upsert_embeddings,
432-
documents=upsert_records["documents"],
433-
uris=upsert_records["uris"],
449+
return (
450+
UpsertRequest(
451+
ids=upsert_records["ids"],
452+
metadatas=upsert_records["metadatas"],
453+
embeddings=upsert_embeddings,
454+
documents=upsert_records["documents"],
455+
uris=upsert_records["uris"],
456+
),
457+
new_attributes,
434458
)
435459

436460
@validation_context("delete")

0 commit comments

Comments
 (0)