diff --git a/chromadb/api/collection_configuration.py b/chromadb/api/collection_configuration.py index bd573e81ba2..02425666289 100644 --- a/chromadb/api/collection_configuration.py +++ b/chromadb/api/collection_configuration.py @@ -8,6 +8,7 @@ EmbeddingFunction, QueryConfig, ) +from chromadb.base_types import CollectionSchema, ValueType from chromadb.utils.embedding_functions import ( known_embedding_functions, register_embedding_function, @@ -44,6 +45,7 @@ class CollectionConfiguration(TypedDict, total=True): spann: Optional[SpannConfiguration] embedding_function: Optional[EmbeddingFunction] # type: ignore query_embedding_function: Optional[EmbeddingFunction] # type: ignore + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] def load_collection_configuration_from_json_str( @@ -126,6 +128,7 @@ def load_collection_configuration_from_json( spann=spann_config, embedding_function=ef, # type: ignore query_embedding_function=query_ef, # type: ignore + schema=config_json_map.get("schema"), ) @@ -139,6 +142,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st spann_config = config.get("spann") ef = config.get("embedding_function") query_ef = config.get("query_embedding_function") + schema = config.get("schema") else: try: hnsw_config = config.get_parameter("hnsw").value @@ -211,6 +215,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st "spann": spann_config, "embedding_function": ef_config, "query_embedding_function": query_ef_config, + "schema": schema, } @@ -292,6 +297,7 @@ class CreateCollectionConfiguration(TypedDict, total=False): spann: Optional[CreateSpannConfiguration] embedding_function: Optional[EmbeddingFunction] # type: ignore query_config: Optional[QueryConfig] + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] def create_collection_configuration_from_legacy_collection_metadata( @@ -430,6 +436,7 @@ def create_collection_configuration_to_json( "spann": spann_config, "embedding_function": ef_config, "query_config": query_config, + "schema": config.get("schema"), } @@ -502,6 +509,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False): spann: Optional[UpdateSpannConfiguration] embedding_function: Optional[EmbeddingFunction] # type: ignore query_config: Optional[QueryConfig] + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] def update_collection_configuration_from_legacy_collection_metadata( @@ -556,10 +564,17 @@ def update_collection_configuration_to_json( """Convert an UpdateCollectionConfiguration to a JSON-serializable dict""" hnsw_config = config.get("hnsw") spann_config = config.get("spann") + schema = config.get("schema") ef = config.get("embedding_function") q = config.get("query_config") query_config: Dict[str, Any] | None = None - if hnsw_config is None and spann_config is None and ef is None and q is None: + if ( + hnsw_config is None + and spann_config is None + and ef is None + and q is None + and schema is None + ): return {} if hnsw_config is not None: @@ -601,6 +616,7 @@ def update_collection_configuration_to_json( "spann": spann_config, "embedding_function": ef_config, "query_config": query_config, + "schema": schema, } @@ -764,14 +780,40 @@ def overwrite_collection_configuration( ef_config[k] = v query_ef = updated_embedding_function.build_from_config(ef_config) + existing_schema = existing_config.get("schema") + new_diff_schema = update_config.get("schema") + updated_schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None + if existing_schema is not None: + if new_diff_schema is not None: + updated_schema = overwrite_schema(existing_schema, new_diff_schema) + else: + updated_schema = existing_schema + else: + updated_schema = new_diff_schema + return CollectionConfiguration( hnsw=updated_hnsw_config, spann=updated_spann_config, embedding_function=updated_embedding_function, query_embedding_function=query_ef, + schema=updated_schema, ) +def overwrite_schema( + existing_schema: Dict[str, Dict[ValueType, CollectionSchema]], + new_diff_schema: Dict[str, Dict[ValueType, CollectionSchema]], +) -> Dict[str, Dict[ValueType, CollectionSchema]]: + """Overwrite a schema with a new configuration""" + for new_key, new_value in new_diff_schema.items(): + if new_key in existing_schema: + for value_type, new_schema in new_value.items(): + existing_schema[new_key][value_type] = new_schema + else: + existing_schema[new_key] = new_value + return existing_schema + + def validate_embedding_function_conflict_on_create( embedding_function: Optional[EmbeddingFunction], # type: ignore configuration_ef: Optional[EmbeddingFunction], # type: ignore diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 29877cfb2fb..5da2d1476d6 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -60,15 +60,22 @@ async def add( ValueError: If you provide an id that already exists """ - add_request = self._validate_and_prepare_add_request( + + curr_schema = self._model.get_configuration().get("schema") + + add_request, new_attributes = self._validate_and_prepare_add_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + await self.modify(configuration={"schema": new_attributes}) + await self._client._add( collection_id=self.id, ids=add_request["ids"], @@ -313,15 +320,20 @@ async def update( Returns: None """ - update_request = self._validate_and_prepare_update_request( + curr_schema = self._model.get_configuration().get("schema") + update_request, new_attributes = self._validate_and_prepare_update_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + await self.modify(configuration={"schema": new_attributes}) + await self._client._update( collection_id=self.id, ids=update_request["ids"], @@ -358,14 +370,18 @@ async def upsert( Returns: None """ - upsert_request = self._validate_and_prepare_upsert_request( + curr_schema = self._model.get_configuration().get("schema") + upsert_request, new_attributes = self._validate_and_prepare_upsert_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + await self.modify(configuration={"schema": new_attributes}) await self._client._upsert( collection_id=self.id, diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index b42c2ff64ec..9457aa4b595 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -77,15 +77,20 @@ def add( """ - add_request = self._validate_and_prepare_add_request( + curr_schema = self._model.get_configuration().get("schema") + add_request, new_attributes = self._validate_and_prepare_add_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + self.modify(configuration={"schema": new_attributes}) + self._client._add( collection_id=self.id, ids=add_request["ids"], @@ -255,6 +260,7 @@ def modify( # Note there is a race condition here where the metadata can be updated # but another thread sees the cached local metadata. # TODO: fixme + self._client._modify( id=self.id, new_name=name, @@ -317,15 +323,20 @@ def update( Returns: None """ - update_request = self._validate_and_prepare_update_request( + curr_schema = self._model.get_configuration().get("schema") + update_request, new_attributes = self._validate_and_prepare_update_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + self.modify(configuration={"schema": new_attributes}) + self._client._update( collection_id=self.id, ids=update_request["ids"], @@ -362,15 +373,20 @@ def upsert( Returns: None """ - upsert_request = self._validate_and_prepare_upsert_request( + curr_schema = self._model.get_configuration().get("schema") + upsert_request, new_attributes = self._validate_and_prepare_upsert_request( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, + schema=curr_schema, ) + if len(new_attributes.keys()) > 0: + self.modify(configuration={"schema": new_attributes}) + self._client._upsert( collection_id=self.id, ids=upsert_request["ids"], diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 683b0582799..f21a0a06403 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -10,12 +10,14 @@ TypeVar, Union, cast, + Tuple, ) from chromadb.types import Metadata import numpy as np from uuid import UUID - import chromadb.utils.embedding_functions as ef +from chromadb.base_types import CollectionSchema, ValueType + from chromadb.api.types import ( URI, URIs, @@ -47,6 +49,7 @@ maybe_cast_one_to_many, normalize_base_record_set, normalize_insert_record_set, + update_schema_with_insert_record_set, validate_base_record_set, validate_ids, validate_include, @@ -204,7 +207,8 @@ def _validate_and_prepare_add_request( documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], - ) -> AddRequest: + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]], + ) -> Tuple[AddRequest, Dict[str, Dict[ValueType, CollectionSchema]]]: # Unpack add_records = normalize_insert_record_set( ids=ids, @@ -219,6 +223,10 @@ def _validate_and_prepare_add_request( validate_insert_record_set(record_set=add_records) validate_record_set_contains_any(record_set=add_records, contains_any={"ids"}) + new_attributes = update_schema_with_insert_record_set( + record_set=add_records, schema=schema + ) + # Prepare if add_records["embeddings"] is None: validate_record_set_for_embedding(record_set=add_records) @@ -226,12 +234,15 @@ def _validate_and_prepare_add_request( else: add_embeddings = add_records["embeddings"] - return AddRequest( - ids=add_records["ids"], - embeddings=add_embeddings, - metadatas=add_records["metadatas"], - documents=add_records["documents"], - uris=add_records["uris"], + return ( + AddRequest( + ids=add_records["ids"], + embeddings=add_embeddings, + metadatas=add_records["metadatas"], + documents=add_records["documents"], + uris=add_records["uris"], + ), + new_attributes, ) @validation_context("get") @@ -350,7 +361,8 @@ def _validate_and_prepare_update_request( documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], - ) -> UpdateRequest: + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]], + ) -> Tuple[UpdateRequest, Dict[str, Dict[ValueType, CollectionSchema]]]: # Unpack update_records = normalize_insert_record_set( ids=ids, @@ -363,6 +375,9 @@ def _validate_and_prepare_update_request( # Validate validate_insert_record_set(record_set=update_records) + new_attributes = update_schema_with_insert_record_set( + record_set=update_records, schema=schema + ) # Prepare if update_records["embeddings"] is None: @@ -380,12 +395,15 @@ def _validate_and_prepare_update_request( else: update_embeddings = update_records["embeddings"] - return UpdateRequest( - ids=update_records["ids"], - embeddings=update_embeddings, - metadatas=update_records["metadatas"], - documents=update_records["documents"], - uris=update_records["uris"], + return ( + UpdateRequest( + ids=update_records["ids"], + embeddings=update_embeddings, + metadatas=update_records["metadatas"], + documents=update_records["documents"], + uris=update_records["uris"], + ), + new_attributes, ) @validation_context("upsert") @@ -402,7 +420,8 @@ def _validate_and_prepare_upsert_request( documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, - ) -> UpsertRequest: + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None, + ) -> Tuple[UpsertRequest, Dict[str, Dict[ValueType, CollectionSchema]]]: # Unpack upsert_records = normalize_insert_record_set( ids=ids, @@ -415,7 +434,9 @@ def _validate_and_prepare_upsert_request( # Validate validate_insert_record_set(record_set=upsert_records) - + new_attributes = update_schema_with_insert_record_set( + record_set=upsert_records, schema=schema + ) # Prepare if upsert_records["embeddings"] is None: validate_record_set_for_embedding( @@ -425,12 +446,15 @@ def _validate_and_prepare_upsert_request( else: upsert_embeddings = upsert_records["embeddings"] - return UpsertRequest( - ids=upsert_records["ids"], - metadatas=upsert_records["metadatas"], - embeddings=upsert_embeddings, - documents=upsert_records["documents"], - uris=upsert_records["uris"], + return ( + UpsertRequest( + ids=upsert_records["ids"], + metadatas=upsert_records["metadatas"], + embeddings=upsert_embeddings, + documents=upsert_records["documents"], + uris=upsert_records["uris"], + ), + new_attributes, ) @validation_context("delete") diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 9f04274ccdf..13fda58c00e 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -294,6 +294,7 @@ def _modify( ) else: new_configuration_json_str = None + self.bindings.update_collection( str(id), new_name, new_metadata, new_configuration_json_str ) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 475d0561172..123f23285d3 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -27,6 +27,8 @@ Where, WhereDocumentOperator, WhereDocument, + CollectionSchema, + ValueType, ) from inspect import signature from tenacity import retry @@ -314,6 +316,47 @@ def validate_insert_record_set(record_set: InsertRecordSet) -> None: validate_metadatas(record_set["metadatas"]) +def update_schema_with_insert_record_set( + record_set: InsertRecordSet, + schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None, +) -> Dict[str, Dict[ValueType, CollectionSchema]]: + """ + Updates the schema with the insert record set. + """ + new_attributes: Dict[str, Dict[ValueType, CollectionSchema]] = {} + + if schema is None: + schema = {} + + if record_set["metadatas"] is not None: + for metadata in record_set["metadatas"]: + if metadata is not None: + for key, value in metadata.items(): + if key not in schema: + if value is not None: + type_name = type(value).__name__ + if key not in new_attributes: + new_attributes[key] = {} + if type_name == "str": + new_attributes[key]["string"] = { + "metadata_index": True, + } + elif type_name == "int": + new_attributes[key]["int"] = { + "metadata_index": True, + } + elif type_name == "float": + new_attributes[key]["float"] = { + "metadata_index": True, + } + elif type_name == "bool": + new_attributes[key]["boolean"] = { + "metadata_index": True, + } + + return new_attributes + + def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None: lengths = [len(lst) for lst in record_set.values() if lst is not None] # type: ignore[arg-type] diff --git a/chromadb/base_types.py b/chromadb/base_types.py index b966d273790..2c0e5ff2815 100644 --- a/chromadb/base_types.py +++ b/chromadb/base_types.py @@ -1,5 +1,5 @@ from typing import Dict, List, Mapping, Optional, Sequence, Union -from typing_extensions import Literal +from typing_extensions import Literal, TypedDict import numpy as np from numpy.typing import NDArray @@ -36,3 +36,9 @@ LogicalOperator, ] WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]] + +ValueType = Literal["int", "float", "string", "boolean"] + + +class CollectionSchema(TypedDict): + metadata_index: bool diff --git a/clients/js/packages/chromadb-core/src/generated/models.ts b/clients/js/packages/chromadb-core/src/generated/models.ts index f832cc9966b..698cf88885d 100644 --- a/clients/js/packages/chromadb-core/src/generated/models.ts +++ b/clients/js/packages/chromadb-core/src/generated/models.ts @@ -68,9 +68,16 @@ export namespace Api { embedding_function?: Api.EmbeddingFunctionConfiguration | null; hnsw?: Api.HnswConfiguration | null; query_config?: unknown; + schema?: { + [name: string]: { [name: string]: Api.CollectionSchema }; + } | null; spann?: Api.SpannConfiguration | null; } + export interface CollectionSchema { + metadata_index: boolean; + } + export interface CreateCollectionPayload { configuration?: Api.CollectionConfiguration | null; get_or_create?: boolean; @@ -337,6 +344,9 @@ export namespace Api { embedding_function?: Api.EmbeddingFunctionConfiguration | null; hnsw?: Api.UpdateHnswConfiguration | null; query_config?: unknown; + schema?: { + [name: string]: { [name: string]: Api.CollectionSchema }; + } | null; spann?: Api.SpannConfiguration | null; } diff --git a/clients/new-js/packages/chromadb/src/api/types.gen.ts b/clients/new-js/packages/chromadb/src/api/types.gen.ts index 07ebee6bd7e..feeb29d6ed8 100644 --- a/clients/new-js/packages/chromadb/src/api/types.gen.ts +++ b/clients/new-js/packages/chromadb/src/api/types.gen.ts @@ -33,9 +33,18 @@ export type CollectionConfiguration = { embedding_function?: null | EmbeddingFunctionConfiguration; hnsw?: null | HnswConfiguration; query_config?: unknown; + schema?: { + [key: string]: { + [key: string]: CollectionSchema; + }; + } | null; spann?: null | SpannConfiguration; }; +export type CollectionSchema = { + metadata_index: boolean; +}; + /** * CollectionUuid is a wrapper around Uuid to provide a type for the collection id. */ @@ -198,6 +207,11 @@ export type UpdateCollectionConfiguration = { embedding_function?: null | EmbeddingFunctionConfiguration; hnsw?: null | UpdateHnswConfiguration; query_config?: unknown; + schema?: { + [key: string]: { + [key: string]: CollectionSchema; + }; + } | null; spann?: null | SpannConfiguration; }; diff --git a/go/pkg/sysdb/coordinator/model/collection_configuration.go b/go/pkg/sysdb/coordinator/model/collection_configuration.go index 985ef1f17a7..0ba6af7e9b2 100644 --- a/go/pkg/sysdb/coordinator/model/collection_configuration.go +++ b/go/pkg/sysdb/coordinator/model/collection_configuration.go @@ -53,10 +53,15 @@ type SpannConfiguration struct { MergeThreshold int `json:"merge_threshold"` } +type CollectionSchema struct { + MetadataIndex bool `json:"metadata_index"` +} + type InternalCollectionConfiguration struct { - VectorIndex *VectorIndexConfiguration `json:"vector_index"` - EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"` - QueryConfig interface{} `json:"query_config,omitempty"` + VectorIndex *VectorIndexConfiguration `json:"vector_index"` + EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"` + QueryConfig interface{} `json:"query_config,omitempty"` + Schema map[string]map[string]CollectionSchema `json:"schema,omitempty"` } // DefaultHnswCollectionConfiguration returns a default configuration using HNSW @@ -126,7 +131,8 @@ type UpdateVectorIndexConfiguration struct { } type InternalUpdateCollectionConfiguration struct { - VectorIndex *UpdateVectorIndexConfiguration `json:"vector_index,omitempty"` - EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"` - QueryConfig interface{} `json:"query_config,omitempty"` + VectorIndex *UpdateVectorIndexConfiguration `json:"vector_index,omitempty"` + EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"` + QueryConfig interface{} `json:"query_config,omitempty"` + Schema map[string]map[string]CollectionSchema `json:"schema,omitempty"` } diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index a835e3d045f..12cb73056d2 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -858,6 +858,10 @@ func (tc *Catalog) updateCollectionConfiguration( existingConfig.QueryConfig = updateConfig.QueryConfig } + if updateConfig.Schema != nil { + existingConfig.Schema = mergeSchemas(existingConfig.Schema, updateConfig.Schema) + } + // Serialize updated config back to JSON updatedConfigBytes, err := json.Marshal(existingConfig) if err != nil { @@ -867,6 +871,27 @@ func (tc *Catalog) updateCollectionConfiguration( return &updatedConfigStr, nil } +func mergeSchemas(existingSchema map[string]map[string]model.CollectionSchema, newSchema map[string]map[string]model.CollectionSchema) map[string]map[string]model.CollectionSchema { + if newSchema == nil { + return existingSchema + } + if existingSchema == nil { + return newSchema + } + + for new_key, new_value := range newSchema { + if existingSchema[new_key] == nil { + existingSchema[new_key] = new_value + } else { + for update_value_type, update_collection_schema := range new_value { + existingSchema[new_key][update_value_type] = update_collection_schema + } + } + } + + return existingSchema +} + func (tc *Catalog) UpdateCollection(ctx context.Context, updateCollection *model.UpdateCollection, ts types.Timestamp) (*model.Collection, error) { log.Info("updating collection", zap.String("collectionId", updateCollection.ID.String())) var result *model.Collection diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 2addd07b6dc..8285c8f6e65 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -281,6 +281,7 @@ impl Bindings { spann: None, embedding_function: None, query_config: None, + schema: None, }, self.frontend.get_default_knn_index(), )?), diff --git a/rust/segment/src/distributed_hnsw.rs b/rust/segment/src/distributed_hnsw.rs index 4528228b095..b873905829c 100644 --- a/rust/segment/src/distributed_hnsw.rs +++ b/rust/segment/src/distributed_hnsw.rs @@ -435,6 +435,7 @@ pub mod test { ), embedding_function: None, query_config: None, + schema: None, }, ..Default::default() }; diff --git a/rust/segment/src/distributed_spann.rs b/rust/segment/src/distributed_spann.rs index 35c5ed124cf..4598f6e01f6 100644 --- a/rust/segment/src/distributed_spann.rs +++ b/rust/segment/src/distributed_spann.rs @@ -657,6 +657,7 @@ mod test { vector_index: chroma_types::VectorIndexConfiguration::Spann(params), embedding_function: None, query_config: None, + schema: None, }, metadata: None, dimension: None, @@ -887,6 +888,7 @@ mod test { vector_index: chroma_types::VectorIndexConfiguration::Spann(params), embedding_function: None, query_config: None, + schema: None, }, ..Default::default() }; diff --git a/rust/sysdb/src/sqlite.rs b/rust/sysdb/src/sqlite.rs index 2a844e8e7a9..bb1e7424e85 100644 --- a/rust/sysdb/src/sqlite.rs +++ b/rust/sysdb/src/sqlite.rs @@ -373,7 +373,9 @@ impl SqliteSysDb { let collections = collections.unwrap(); let collection = collections.into_iter().next().unwrap(); let mut existing_configuration = collection.config; - existing_configuration.update(&configuration); + existing_configuration + .update(&configuration) + .map_err(|e| UpdateCollectionError::Internal(e.boxed()))?; configuration_json_str = Some( serde_json::to_string(&existing_configuration) .map_err(UpdateCollectionError::Configuration)?, @@ -1363,6 +1365,7 @@ mod tests { spann: None, embedding_function: None, query_config: None, + schema: None, }), ) .await diff --git a/rust/types/src/collection_configuration.rs b/rust/types/src/collection_configuration.rs index afe8646396a..5df5b31ef83 100644 --- a/rust/types/src/collection_configuration.rs +++ b/rust/types/src/collection_configuration.rs @@ -4,9 +4,31 @@ use crate::{ }; use chroma_error::{ChromaError, ErrorCodes}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; use thiserror::Error; use utoipa::ToSchema; +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, ToSchema, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum ValueType { + Int, + Float, + String, + Boolean, +} + +impl fmt::Display for ValueType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValueType::Int => write!(f, "int"), + ValueType::Float => write!(f, "float"), + ValueType::String => write!(f, "string"), + ValueType::Boolean => write!(f, "boolean"), + } + } +} + #[derive(Deserialize, Serialize, Clone, Debug, Copy)] pub enum KnnIndex { #[serde(alias = "hnsw")] @@ -19,6 +41,11 @@ pub fn default_default_knn_index() -> KnnIndex { KnnIndex::Hnsw } +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, ToSchema)] +pub struct CollectionSchema { + pub metadata_index: bool, +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, ToSchema)] #[serde(tag = "type")] pub enum EmbeddingFunctionConfiguration { @@ -86,6 +113,7 @@ pub struct InternalCollectionConfiguration { pub vector_index: VectorIndexConfiguration, pub embedding_function: Option, pub query_config: Option, + pub schema: Option>>, } impl InternalCollectionConfiguration { @@ -97,6 +125,7 @@ impl InternalCollectionConfiguration { vector_index: VectorIndexConfiguration::Hnsw(hnsw), embedding_function: None, query_config: None, + schema: None, }) } @@ -105,6 +134,7 @@ impl InternalCollectionConfiguration { vector_index: VectorIndexConfiguration::Hnsw(InternalHnswConfiguration::default()), embedding_function: None, query_config: None, + schema: None, } } @@ -113,6 +143,7 @@ impl InternalCollectionConfiguration { vector_index: VectorIndexConfiguration::Spann(InternalSpannConfiguration::default()), embedding_function: None, query_config: None, + schema: None, } } @@ -155,10 +186,13 @@ impl InternalCollectionConfiguration { } } - pub fn update(&mut self, configuration: &UpdateCollectionConfiguration) { + pub fn update( + &mut self, + update_configuration: &UpdateCollectionConfiguration, + ) -> Result<(), UpdateCollectionConfigurationToInternalConfigurationError> { // Update vector_index if it exists in the update configuration - if let Some(hnsw_config) = &configuration.hnsw { + if let Some(hnsw_config) = &update_configuration.hnsw { if let VectorIndexConfiguration::Hnsw(current_config) = &mut self.vector_index { // Update only the non-None fields from the update configuration if let Some(ef_search) = hnsw_config.ef_search { @@ -181,7 +215,7 @@ impl InternalCollectionConfiguration { } } } - if let Some(spann_config) = &configuration.spann { + if let Some(spann_config) = &update_configuration.spann { if let VectorIndexConfiguration::Spann(current_config) = &mut self.vector_index { if let Some(search_nprobe) = spann_config.search_nprobe { current_config.search_nprobe = search_nprobe; @@ -203,12 +237,32 @@ impl InternalCollectionConfiguration { } } // Update embedding_function if it exists in the update configuration - if let Some(embedding_function) = &configuration.embedding_function { + if let Some(embedding_function) = &update_configuration.embedding_function { self.embedding_function = Some(embedding_function.clone()); } - if let Some(query_config) = &configuration.query_config { + if let Some(query_config) = &update_configuration.query_config { self.query_config = Some(query_config.clone()); } + if let Some(update_schema) = &update_configuration.schema { + if let Some(current_schema) = &mut self.schema { + for (update_key, update_value) in update_schema { + if let Some(current_value) = current_schema.get_mut(update_key) { + for (update_value_type, update_collection_schema) in update_value { + current_value.insert( + update_value_type.clone(), + update_collection_schema.clone(), + ); + } + } else { + current_schema.insert(update_key.clone(), update_value.clone()); + } + } + } else { + self.schema = Some(update_schema.clone()); + } + } + + Ok(()) } pub fn try_from_config( @@ -231,6 +285,7 @@ impl InternalCollectionConfiguration { vector_index: VectorIndexConfiguration::Spann(internal_config), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) }, KnnIndex::Hnsw => { @@ -239,6 +294,7 @@ impl InternalCollectionConfiguration { vector_index: hnsw.into(), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } } @@ -257,6 +313,7 @@ impl InternalCollectionConfiguration { vector_index: VectorIndexConfiguration::Hnsw(internal_config), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } KnnIndex::Spann => { @@ -265,6 +322,7 @@ impl InternalCollectionConfiguration { vector_index: spann.into(), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } } @@ -278,6 +336,7 @@ impl InternalCollectionConfiguration { vector_index, embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } } @@ -288,6 +347,8 @@ impl TryFrom for InternalCollectionConfiguration { type Error = CollectionConfigurationToInternalConfigurationError; fn try_from(value: CollectionConfiguration) -> Result { + // validate the schema + validate_schema(&value.schema)?; match (value.hnsw, value.spann) { (Some(_), Some(_)) => Err(Self::Error::MultipleVectorIndexConfigurations), (Some(hnsw), None) => { @@ -296,6 +357,7 @@ impl TryFrom for InternalCollectionConfiguration { vector_index: hnsw.into(), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } (None, Some(spann)) => { @@ -304,27 +366,45 @@ impl TryFrom for InternalCollectionConfiguration { vector_index: spann.into(), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }) } (None, None) => Ok(InternalCollectionConfiguration { vector_index: InternalHnswConfiguration::default().into(), embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, }), } } } +fn validate_schema( + schema: &Option>>, +) -> Result<(), CollectionConfigurationToInternalConfigurationError> { + // get list of keys, any duplicates are invalid + if let Some(schema) = schema { + let keys = schema.keys().collect::>(); + if keys.len() != schema.len() { + return Err(CollectionConfigurationToInternalConfigurationError::SchemaDuplicateKeys); + } + } + Ok(()) +} + #[derive(Debug, Error)] pub enum CollectionConfigurationToInternalConfigurationError { #[error("Multiple vector index configurations provided")] MultipleVectorIndexConfigurations, + #[error("Schema duplicate keys")] + SchemaDuplicateKeys, } impl ChromaError for CollectionConfigurationToInternalConfigurationError { fn code(&self) -> ErrorCodes { match self { Self::MultipleVectorIndexConfigurations => ErrorCodes::InvalidArgument, + Self::SchemaDuplicateKeys => ErrorCodes::InvalidArgument, } } } @@ -336,6 +416,7 @@ pub struct CollectionConfiguration { pub spann: Option, pub embedding_function: Option, pub query_config: Option, + pub schema: Option>>, } impl From for CollectionConfiguration { @@ -351,6 +432,7 @@ impl From for CollectionConfiguration { }, embedding_function: value.embedding_function, query_config: value.query_config, + schema: value.schema, } } } @@ -378,12 +460,15 @@ impl From for UpdateVectorIndexConfiguration { pub enum UpdateCollectionConfigurationToInternalConfigurationError { #[error("Multiple vector index configurations provided")] MultipleVectorIndexConfigurations, + #[error("Schema value type mismatch: existing: {0}, updated: {1}")] + SchemaValueTypeMismatch(ValueType, ValueType), } impl ChromaError for UpdateCollectionConfigurationToInternalConfigurationError { fn code(&self) -> ErrorCodes { match self { Self::MultipleVectorIndexConfigurations => ErrorCodes::InvalidArgument, + Self::SchemaValueTypeMismatch(_, _) => ErrorCodes::InvalidArgument, } } } @@ -395,6 +480,7 @@ pub struct UpdateCollectionConfiguration { pub spann: Option, pub embedding_function: Option, pub query_config: Option, + pub schema: Option>>, } #[cfg(test)] @@ -444,6 +530,7 @@ mod tests { }), embedding_function: None, query_config: None, + schema: None, }; let overridden_config = config @@ -473,6 +560,7 @@ mod tests { spann: None, embedding_function: None, query_config: None, + schema: None, }; let internal_config_result = @@ -503,6 +591,7 @@ mod tests { spann: None, embedding_function: None, query_config: None, + schema: None, }; let internal_config_result = @@ -537,6 +626,7 @@ mod tests { spann: Some(spann_config.clone()), embedding_function: None, query_config: None, + schema: None, }; let internal_config_result = @@ -568,6 +658,7 @@ mod tests { spann: Some(spann_config.clone()), embedding_function: None, query_config: None, + schema: None, }; let internal_config_result =