From 964072d37c11a1e5f5e465a2cf760035b1c6b875 Mon Sep 17 00:00:00 2001 From: groot Date: Wed, 5 Apr 2023 04:56:09 +0800 Subject: [PATCH] Refine Milvus datastore (#87) Signed-off-by: groot --- datastore/providers/milvus_datastore.py | 528 +++++++++++------- datastore/providers/zilliz_datastore.py | 391 +------------ docs/providers/milvus/setup.md | 23 +- docs/providers/zilliz/setup.md | 19 +- .../providers/milvus/test_milvus_datastore.py | 53 +- .../providers/zilliz/test_zilliz_datastore.py | 325 +---------- 6 files changed, 408 insertions(+), 931 deletions(-) diff --git a/datastore/providers/milvus_datastore.py b/datastore/providers/milvus_datastore.py index edb7a5240..202e86d55 100644 --- a/datastore/providers/milvus_datastore.py +++ b/datastore/providers/milvus_datastore.py @@ -1,3 +1,4 @@ +import json import os import asyncio @@ -33,24 +34,28 @@ MILVUS_PASSWORD = os.environ.get("MILVUS_PASSWORD") MILVUS_USE_SECURITY = False if MILVUS_PASSWORD is None else True +MILVUS_INDEX_PARAMS = os.environ.get("MILVUS_INDEX_PARAMS") +MILVUS_SEARCH_PARAMS = os.environ.get("MILVUS_SEARCH_PARAMS") +MILVUS_CONSISTENCY_LEVEL = os.environ.get("MILVUS_CONSISTENCY_LEVEL") + UPSERT_BATCH_SIZE = 100 OUTPUT_DIM = 1536 +EMBEDDING_FIELD = "embedding" class Required: pass - # The fields names that we are going to be storing within Milvus, the field declaration for schema creation, and the default value -SCHEMA = [ +SCHEMA_V1 = [ ( "pk", FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), Required, ), ( - "embedding", - FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=OUTPUT_DIM), + EMBEDDING_FIELD, + FieldSchema(name=EMBEDDING_FIELD, dtype=DataType.FLOAT_VECTOR, dim=OUTPUT_DIM), Required, ), ( @@ -91,13 +96,16 @@ class Required: ), ] +# V2 schema, remomve the "pk" field +SCHEMA_V2 = SCHEMA_V1[1:] +SCHEMA_V2[4][1].is_primary = True + class MilvusDataStore(DataStore): def __init__( self, create_new: Optional[bool] = False, - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, + consistency_level: str = "Bounded", ): """Create a Milvus DataStore. @@ -105,119 +113,168 @@ def __init__( Args: create_new (Optional[bool], optional): Whether to overwrite if collection already exists. Defaults to True. - index_params (Optional[dict], optional): Custom index params to use. Defaults to None. - search_params (Optional[dict], optional): Custom search params to use. Defaults to None. + consistency_level(str, optional): Specify the collection consistency level. + Defaults to "Bounded" for search performance. + Set to "Strong" in test cases for result validation. 
""" + # Overwrite the default consistency level by MILVUS_CONSISTENCY_LEVEL + self._consistency_level = MILVUS_CONSISTENCY_LEVEL or consistency_level + self._create_connection() + + self._create_collection(MILVUS_COLLECTION, create_new) # type: ignore + self._create_index() + + def _print_info(self, msg): + # TODO: logger + print(msg) + + def _print_err(self, msg): + # TODO: logger + print(msg) + + def _get_schema(self): + return SCHEMA_V1 if self._schema_ver == "V1" else SCHEMA_V2 - # # TODO: Auto infer the fields - # non_string_fields = [('embedding', List[float]), ('created_at', int)] - # fields_to_index = list(DocumentChunkMetadata.__fields__.keys()) - # fields_to_index = list(DocumentChunk.__fields__.keys()) - - # Set the index_params to passed in or the default - self.index_params = index_params - - # The default search params - self.default_search_params = { - "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, - "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, - "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, - "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, - "AUTOINDEX": {"metric_type": "L2", "params": {}}, - } - - # Check if the connection already exists + def _create_connection(self): try: - i = [ - connections.get_connection_addr(x[0]) - for x in connections.list_connections() - ].index({"host": MILVUS_HOST, "port": MILVUS_PORT}) - self.alias = connections.list_connections()[i][0] - except ValueError: + self.alias = "" + # Check if the connection already exists + for x in connections.list_connections(): + addr = connections.get_connection_addr(x[0]) + if x[1] and ('address' in addr) and (addr['address'] == "{}:{}".format(MILVUS_HOST, MILVUS_PORT)): + self.alias = x[0] + self._print_info("Reuse connection to Milvus server '{}:{}' with alias '{:s}'" + .format(MILVUS_HOST, MILVUS_PORT, self.alias)) + break + # Connect to the Milvus instance using the passed in Environment variables - self.alias = uuid4().hex - connections.connect( - alias=self.alias, - host=MILVUS_HOST, - port=MILVUS_PORT, - user=MILVUS_USER, # type: ignore - password=MILVUS_PASSWORD, # type: ignore - secure=MILVUS_USE_SECURITY, - ) - - self._create_collection(create_new) # type: ignore - - index_params = self.index_params or {} - - # Use in the passed in search params or the default for the specified index - self.search_params = ( - search_params or self.default_search_params[index_params["index_type"]] - ) + if len(self.alias) == 0: + self.alias = uuid4().hex + connections.connect( + alias=self.alias, + host=MILVUS_HOST, + port=MILVUS_PORT, + user=MILVUS_USER, # type: ignore + password=MILVUS_PASSWORD, # type: ignore + secure=MILVUS_USE_SECURITY, + ) + self._print_info("Create connection to Milvus server '{}:{}' with alias '{:s}'" + .format(MILVUS_HOST, MILVUS_PORT, self.alias)) + except Exception as e: + self._print_err("Failed to create connection to Milvus server '{}:{}', error: {}" + .format(MILVUS_HOST, MILVUS_PORT, e)) - def _create_collection(self, create_new: bool) -> None: + def _create_collection(self, collection_name, create_new: bool) -> None: """Create a collection based on environment and passed in variables. 
Args: create_new (bool): Whether to overwrite if collection already exists. """ - - # If the collection exists and create_new is True, drop the existing collection - if utility.has_collection(MILVUS_COLLECTION, using=self.alias) and create_new: - utility.drop_collection(MILVUS_COLLECTION, using=self.alias) - - # Check if the collection doesn't exist - if utility.has_collection(MILVUS_COLLECTION, using=self.alias) is False: - # If it doesn't exist use the field params from init to create a new schema - schema = [field[1] for field in SCHEMA] - schema = CollectionSchema(schema) - # Use the schema to create a new collection - self.col = Collection( - MILVUS_COLLECTION, - schema=schema, - consistency_level="Strong", - using=self.alias, - ) - else: - # If the collection exists, point to it - self.col = Collection( - MILVUS_COLLECTION, consistency_level="Strong", using=self.alias - ) # type: ignore - - # If no index on the collection, create one - if len(self.col.indexes) == 0: - if self.index_params is not None: - # Create an index on the 'embedding' field with the index params found in init - self.col.create_index("embedding", index_params=self.index_params) + try: + self._schema_ver = "V1" + # If the collection exists and create_new is True, drop the existing collection + if utility.has_collection(collection_name, using=self.alias) and create_new: + utility.drop_collection(collection_name, using=self.alias) + + # Check if the collection doesnt exist + if utility.has_collection(collection_name, using=self.alias) is False: + # If it doesnt exist use the field params from init to create a new schem + schema = [field[1] for field in SCHEMA_V2] + schema = CollectionSchema(schema) + # Use the schema to create a new collection + self.col = Collection( + collection_name, + schema=schema, + using=self.alias, + consistency_level=self._consistency_level, + ) + self._schema_ver = "V2" + self._print_info("Create Milvus collection '{}' with schema {} and consistency level {}" + .format(collection_name, self._schema_ver, self._consistency_level)) else: - # If no index param supplied, to first create an HNSW index for Milvus - try: - print("Attempting creation of Milvus default index") - i_p = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, - } - - self.col.create_index("embedding", index_params=i_p) - self.index_params = i_p - print("Creation of Milvus default index successful") - # If create fails, most likely due to being Zilliz Cloud instance, try to create an AutoIndex - except MilvusException: - print("Attempting creation of Zilliz Cloud default index") - i_p = {"metric_type": "L2", "index_type": "AUTOINDEX", "params": {}} - self.col.create_index("embedding", index_params=i_p) - self.index_params = i_p - print("Creation of Zilliz Cloud default index successful") - # If an index already exists, grab its params - else: - self.index_params = self.col.indexes[0].to_dict()["index_param"] - - self.col.load() + # If the collection exists, point to it + self.col = Collection( + collection_name, using=self.alias + ) # type: ignore + # Which sechma is used + for field in self.col.schema.fields: + if field.name == "id" and field.is_primary: + self._schema_ver = "V2" + break + self._print_info("Milvus collection '{}' already exists with schema {}" + .format(collection_name, self._schema_ver)) + except Exception as e: + self._print_err("Failed to create collection '{}', error: {}".format(collection_name, e)) + + def _create_index(self): + # TODO: verify index/search 
params passed by os.environ + self.index_params = MILVUS_INDEX_PARAMS or None + self.search_params = MILVUS_SEARCH_PARAMS or None + try: + # If no index on the collection, create one + if len(self.col.indexes) == 0: + if self.index_params is not None: + # Convert the string format to JSON format parameters passed by MILVUS_INDEX_PARAMS + self.index_params = json.loads(self.index_params) + self._print_info("Create Milvus index: {}".format(self.index_params)) + # Create an index on the 'embedding' field with the index params found in init + self.col.create_index(EMBEDDING_FIELD, index_params=self.index_params) + else: + # If no index param supplied, to first create an HNSW index for Milvus + try: + i_p = { + "metric_type": "IP", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + self._print_info("Attempting creation of Milvus '{}' index".format(i_p["index_type"])) + self.col.create_index(EMBEDDING_FIELD, index_params=i_p) + self.index_params = i_p + self._print_info("Creation of Milvus '{}' index successful".format(i_p["index_type"])) + # If create fails, most likely due to being Zilliz Cloud instance, try to create an AutoIndex + except MilvusException: + self._print_info("Attempting creation of Milvus default index") + i_p = {"metric_type": "IP", "index_type": "AUTOINDEX", "params": {}} + self.col.create_index(EMBEDDING_FIELD, index_params=i_p) + self.index_params = i_p + self._print_info("Creation of Milvus default index successful") + # If an index already exists, grab its params + else: + # How about if the first index is not vector index? + for index in self.col.indexes: + idx = index.to_dict() + if idx["field"] == EMBEDDING_FIELD: + self._print_info("Index already exists: {}".format(idx)) + self.index_params = idx['index_param'] + break + + self.col.load() + + if self.search_params is not None: + # Convert the string format to JSON format parameters passed by MILVUS_SEARCH_PARAMS + self.search_params = json.loads(self.search_params) + else: + # The default search params + metric_type = "IP" + if "metric_type" in self.index_params: + metric_type = self.index_params["metric_type"] + default_search_params = { + "IVF_FLAT": {"metric_type": metric_type, "params": {"nprobe": 10}}, + "IVF_SQ8": {"metric_type": metric_type, "params": {"nprobe": 10}}, + "IVF_PQ": {"metric_type": metric_type, "params": {"nprobe": 10}}, + "HNSW": {"metric_type": metric_type, "params": {"ef": 10}}, + "RHNSW_FLAT": {"metric_type": metric_type, "params": {"ef": 10}}, + "RHNSW_SQ": {"metric_type": metric_type, "params": {"ef": 10}}, + "RHNSW_PQ": {"metric_type": metric_type, "params": {"ef": 10}}, + "IVF_HNSW": {"metric_type": metric_type, "params": {"nprobe": 10, "ef": 10}}, + "ANNOY": {"metric_type": metric_type, "params": {"search_k": 10}}, + "AUTOINDEX": {"metric_type": metric_type, "params": {}}, + } + # Set the search params + self.search_params = default_search_params[self.index_params["index_type"]] + self._print_info("Milvus search parameters: {}".format(self.search_params)) + except Exception as e: + self._print_err("Failed to create index, error: {}".format(e)) async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """Upsert chunks into the datastore. @@ -231,44 +288,51 @@ async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: Returns: List[str]: The document_id's that were inserted. 
""" - # The doc id's to return for the upsert - doc_ids: List[str] = [] - # List to collect all the insert data - insert_data = [[] for _ in range(len(SCHEMA) - 1)] - # Go through each document chunklist and grab the data - for doc_id, chunk_list in chunks.items(): - # Append the doc_id to the list we are returning - doc_ids.append(doc_id) - # Examine each chunk in the chunklist - for chunk in chunk_list: - # Extract data from the chunk - list_of_data = self._get_values(chunk) - # Check if the data is valid - if list_of_data is not None: - # Append each field to the insert_data - for x in range(len(insert_data)): - insert_data[x].append(list_of_data[x]) - # Slice up our insert data into batches - batches = [ - insert_data[i : i + UPSERT_BATCH_SIZE] - for i in range(0, len(insert_data), UPSERT_BATCH_SIZE) - ] - - # Attempt to insert each batch into our collection - for batch in batches: - if len(batch[0]) != 0: - try: - print(f"Upserting batch of size {len(batch[0])}") - self.col.insert(batch) - print(f"Upserted batch successfully") - except Exception as e: - print(f"Error upserting batch: {e}") - raise e - - # This setting performs flushes after insert. Small insert == bad to use - # self.col.flush() + try: + # The doc id's to return for the upsert + doc_ids: List[str] = [] + # List to collect all the insert data, skip the "pk" for schema V1 + offset = 1 if self._schema_ver == "V1" else 0 + insert_data = [[] for _ in range(len(self._get_schema()) - offset)] + + # Go through each document chunklist and grab the data + for doc_id, chunk_list in chunks.items(): + # Append the doc_id to the list we are returning + doc_ids.append(doc_id) + # Examine each chunk in the chunklist + for chunk in chunk_list: + # Extract data from the chunk + list_of_data = self._get_values(chunk) + # Check if the data is valid + if list_of_data is not None: + # Append each field to the insert_data + for x in range(len(insert_data)): + insert_data[x].append(list_of_data[x]) + # Slice up our insert data into batches + batches = [ + insert_data[i : i + UPSERT_BATCH_SIZE] + for i in range(0, len(insert_data), UPSERT_BATCH_SIZE) + ] + + # Attempt to insert each batch into our collection + # batch data can work with both V1 and V2 schema + for batch in batches: + if len(batch[0]) != 0: + try: + self._print_info(f"Upserting batch of size {len(batch[0])}") + self.col.insert(batch) + self._print_info(f"Upserted batch successfully") + except Exception as e: + self._print_err(f"Failed to insert batch records, error: {e}") + raise e + + # This setting perfoms flushes after insert. Small insert == bad to use + # self.col.flush() + return doc_ids + except Exception as e: + self._print_err("Failed to insert records, error: {}".format(e)) + return [] - return doc_ids def _get_values(self, chunk: DocumentChunk) -> List[any] | None: # type: ignore """Convert the chunk into a list of values to insert whose indexes align with fields. 
@@ -294,13 +358,14 @@ def _get_values(self, chunk: DocumentChunk) -> List[any] | None: # type: ignore values["source"] = values["source"].value # List to collect data we will return ret = [] - # Grab data responding to each field excluding the hidden auto pk field - for key, _, default in SCHEMA[1:]: + # Grab data responding to each field, excluding the hidden auto pk field for schema V1 + offset = 1 if self._schema_ver == "V1" else 0 + for key, _, default in self._get_schema()[offset:]: # Grab the data at the key and default to our defaults set in init x = values.get(key) or default # If one of our required fields is missing, ignore the entire entry if x is Required: - print("Chunk " + values["id"] + " missing " + key + " skipping") + self._print_info("Chunk " + values["id"] + " missing " + key + " skipping") return None # Add the corresponding value if it passes the tests ret.append(x) @@ -322,53 +387,57 @@ async def _query( """ # Async to perform the query, adapted from pinecone implementation async def _single_query(query: QueryWithEmbedding) -> QueryResult: - - filter = None - # Set the filter to expression that is valid for Milvus - if query.filter is not None: - # Either a valid filter or None will be returned - filter = self._get_filter(query.filter) - - # Perform our search - res = self.col.search( - data=[query.embedding], - anns_field="embedding", - param=self.search_params, - limit=query.top_k, - expr=filter, - output_fields=[ - field[0] for field in SCHEMA[2:] - ], # Ignoring pk, embedding - ) - # Results that will hold our DocumentChunkWithScores - results = [] - # Parse every result for our search - for hit in res[0]: # type: ignore - # The distance score for the search result, falls under DocumentChunkWithScore - score = hit.score - # Our metadata info, falls under DocumentChunkMetadata - metadata = {} - # Grab the values that correspond to our fields, ignore pk and embedding. - for x in [field[0] for field in SCHEMA[2:]]: - metadata[x] = hit.entity.get(x) - # If the source isn't valid, convert to None - if metadata["source"] not in Source.__members__: - metadata["source"] = None - # Text falls under the DocumentChunk - text = metadata.pop("text") - # Id falls under the DocumentChunk - ids = metadata.pop("id") - chunk = DocumentChunkWithScore( - id=ids, - score=score, - text=text, - metadata=DocumentChunkMetadata(**metadata), + try: + filter = None + # Set the filter to expression that is valid for Milvus + if query.filter is not None: + # Either a valid filter or None will be returned + filter = self._get_filter(query.filter) + + # Perform our search + return_from = 2 if self._schema_ver == "V1" else 1 + res = self.col.search( + data=[query.embedding], + anns_field=EMBEDDING_FIELD, + param=self.search_params, + limit=query.top_k, + expr=filter, + output_fields=[ + field[0] for field in self._get_schema()[return_from:] + ], # Ignoring pk, embedding ) - results.append(chunk) + # Results that will hold our DocumentChunkWithScores + results = [] + # Parse every result for our search + for hit in res[0]: # type: ignore + # The distance score for the search result, falls under DocumentChunkWithScore + score = hit.score + # Our metadata info, falls under DocumentChunkMetadata + metadata = {} + # Grab the values that correspond to our fields, ignore pk and embedding. 
+ for x in [field[0] for field in self._get_schema()[return_from:]]: + metadata[x] = hit.entity.get(x) + # If the source isn't valid, convert to None + if metadata["source"] not in Source.__members__: + metadata["source"] = None + # Text falls under the DocumentChunk + text = metadata.pop("text") + # Id falls under the DocumentChunk + ids = metadata.pop("id") + chunk = DocumentChunkWithScore( + id=ids, + score=score, + text=text, + metadata=DocumentChunkMetadata(**metadata), + ) + results.append(chunk) - # TODO: decide on doing queries to grab the embedding itself, slows down performance as double query occurs + # TODO: decide on doing queries to grab the embedding itself, slows down performance as double query occurs - return QueryResult(query=query.query, results=results) + return QueryResult(query=query.query, results=results) + except Exception as e: + self._print_err("Failed to query, error: {}".format(e)) + return QueryResult(query=query.query, results=[]) results: List[QueryResult] = await asyncio.gather( *[_single_query(query) for query in queries] @@ -390,49 +459,74 @@ async def delete( """ # If deleting all, drop and create the new collection if delete_all: + coll_name = self.col.name + self._print_info("Delete the entire collection {} and create new one".format(coll_name)) # Release the collection from memory self.col.release() # Drop the collection self.col.drop() # Recreate the new collection - self._create_collection(True) + self._create_collection(coll_name, True) + self._create_index() return True # Keep track of how many we have deleted for later printing delete_count = 0 - - # Check if empty ids - if ids is not None: - if len(ids) != 0: + batch_size = 100 + pk_name = "pk" if self._schema_ver == "V1" else "id" + try: + # According to the api design, the ids is a list of document_id, + # document_id is not primary key, use query+delete to workaround, + # in future version we can delete by expression + if (ids is not None) and len(ids) > 0: # Add quotation marks around the string format id ids = ['"' + str(id) + '"' for id in ids] # Query for the pk's of entries that match id's ids = self.col.query(f"document_id in [{','.join(ids)}]") # Convert to list of pks - ids = [str(entry["pk"]) for entry in ids] # type: ignore - # Check to see if there are valid pk's to delete - if len(ids) != 0: - # Delete the entries for each pk - res = self.col.delete(f"pk in [{','.join(ids)}]") + pks = [str(entry[pk_name]) for entry in ids] # type: ignore + # for schema V2, the "id" is varchar, rewrite the expression + if self._schema_ver != "V1": + pks = ['"' + pk + '"' for pk in pks] + + # Delete by ids batch by batch(avoid too long expression) + self._print_info("Apply {:d} deletions to schema {:s}".format(len(pks), self._schema_ver)) + while len(pks) > 0: + batch_pks = pks[:batch_size] + pks = pks[batch_size:] + # Delete the entries batch by batch + res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") # Increment our deleted count delete_count += int(res.delete_count) # type: ignore + except Exception as e: + self._print_err("Failed to delete by ids, error: {}".format(e)) - # Check if empty filter - if filter is not None: - # Convert filter to milvus expression - filter = self._get_filter(filter) # type: ignore - # Check if there is anything to filter - if len(filter) != 0: # type: ignore - # Query for the pk's of entries that match filter - filter = self.col.query(filter) # type: ignore - # Convert to list of pks - filter = [str(entry["pk"]) for entry in filter] # type: ignore - # 
Check to see if there are valid pk's to delete + try: + # Check if empty filter + if filter is not None: + # Convert filter to milvus expression + filter = self._get_filter(filter) # type: ignore + # Check if there is anything to filter if len(filter) != 0: # type: ignore - # Delete the entries - res = self.col.delete(f"pk in [{','.join(filter)}]") # type: ignore - # Increment our delete count - delete_count += int(res.delete_count) # type: ignore + # Query for the pk's of entries that match filter + res = self.col.query(filter) # type: ignore + # Convert to list of pks + pks = [str(entry[pk_name]) for entry in res] # type: ignore + # for schema V2, the "id" is varchar, rewrite the expression + if self._schema_ver != "V1": + pks = ['"' + pk + '"' for pk in pks] + # Check to see if there are valid pk's to delete, delete batch by batch(avoid too long expression) + while len(pks) > 0: # type: ignore + batch_pks = pks[:batch_size] + pks = pks[batch_size:] + # Delete the entries batch by batch + res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") # type: ignore + # Increment our delete count + delete_count += int(res.delete_count) # type: ignore + except Exception as e: + self._print_err("Failed to delete by filter, error: {}".format(e)) + + self._print_info("{:d} records deleted".format(delete_count)) # This setting performs flushes after delete. Small delete == bad to use # self.col.flush() diff --git a/datastore/providers/zilliz_datastore.py b/datastore/providers/zilliz_datastore.py index 68ace1592..1db641f63 100644 --- a/datastore/providers/zilliz_datastore.py +++ b/datastore/providers/zilliz_datastore.py @@ -1,97 +1,25 @@ import os -import asyncio -from typing import Dict, List, Optional +from typing import Optional from pymilvus import ( - Collection, connections, - utility, - FieldSchema, - DataType, - CollectionSchema, ) from uuid import uuid4 - -from services.date import to_unix_timestamp -from datastore.datastore import DataStore -from models.models import ( - DocumentChunk, - DocumentChunkMetadata, - Source, - DocumentMetadataFilter, - QueryResult, - QueryWithEmbedding, - DocumentChunkWithScore, +from datastore.providers.milvus_datastore import ( + MilvusDataStore, ) + ZILLIZ_COLLECTION = os.environ.get("ZILLIZ_COLLECTION") or "c" + uuid4().hex ZILLIZ_URI = os.environ.get("ZILLIZ_URI") ZILLIZ_USER = os.environ.get("ZILLIZ_USER") ZILLIZ_PASSWORD = os.environ.get("ZILLIZ_PASSWORD") ZILLIZ_USE_SECURITY = False if ZILLIZ_PASSWORD is None else True +ZILLIZ_CONSISTENCY_LEVEL = os.environ.get("ZILLIZ_CONSISTENCY_LEVEL") -UPSERT_BATCH_SIZE = 100 -OUTPUT_DIM = 1536 - - -class Required: - pass - - -# The fields names that we are going to be storing within Zilliz Cloud, the field declaration for schema creation, and the default value -SCHEMA = [ - ( - "pk", - FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), - Required, - ), - ( - "embedding", - FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=OUTPUT_DIM), - Required, - ), - ( - "text", - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - Required, - ), - ( - "document_id", - FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=65535), - "", - ), - ( - "source_id", - FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=65535), - "", - ), - ( - "id", - FieldSchema( - name="id", - dtype=DataType.VARCHAR, - max_length=65535, - ), - "", - ), - ( - "source", - FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=65535), - "", - ), - ("url", 
FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=65535), ""), - ("created_at", FieldSchema(name="created_at", dtype=DataType.INT64), -1), - ( - "author", - FieldSchema(name="author", dtype=DataType.VARCHAR, max_length=65535), - "", - ), -] - - -class ZillizDataStore(DataStore): +class ZillizDataStore(MilvusDataStore): def __init__(self, create_new: Optional[bool] = False): """Create a Zilliz DataStore. @@ -100,12 +28,14 @@ def __init__(self, create_new: Optional[bool] = False): Args: create_new (Optional[bool], optional): Whether to overwrite if collection already exists. Defaults to True. """ + # Overwrite the default consistency level by MILVUS_CONSISTENCY_LEVEL + self._consistency_level = ZILLIZ_CONSISTENCY_LEVEL or "Bounded" + self._create_connection() - # # TODO: Auto infer the fields - # non_string_fields = [('embedding', List[float]), ('created_at', int)] - # fields_to_index = list(DocumentChunkMetadata.__fields__.keys()) - # fields_to_index = list(DocumentChunk.__fields__.keys()) + self._create_collection(ZILLIZ_COLLECTION, create_new) # type: ignore + self._create_index() + def _create_connection(self): # Check if the connection already exists try: i = [ @@ -117,293 +47,18 @@ def __init__(self, create_new: Optional[bool] = False): # Connect to the Zilliz instance using the passed in Environment variables self.alias = uuid4().hex connections.connect(alias=self.alias, uri=ZILLIZ_URI, user=ZILLIZ_USER, password=ZILLIZ_PASSWORD, secure=ZILLIZ_USE_SECURITY) # type: ignore + self._print_info("Connect to zilliz cloud server") - self._create_collection(create_new) # type: ignore - - def _create_collection(self, create_new: bool) -> None: - """Create a collection based on environment and passed in variables. - - Args: - create_new (bool): Whether to overwrite if collection already exists. - """ - - # If the collection exists and create_new is True, drop the existing collection - if utility.has_collection(ZILLIZ_COLLECTION, using=self.alias) and create_new: - utility.drop_collection(ZILLIZ_COLLECTION, using=self.alias) - - # Check if the collection doesn't exist - if utility.has_collection(ZILLIZ_COLLECTION, using=self.alias) is False: - # If it doesn't exist use the field params from init to create a new schema - schema = [field[1] for field in SCHEMA] - schema = CollectionSchema(schema) - # Use the schema to create a new collection - self.col = Collection( - ZILLIZ_COLLECTION, - schema=schema, - consistency_level="Strong", - using=self.alias, - ) - else: - # If the collection exists, point to it - self.col = Collection(ZILLIZ_COLLECTION, consistency_level="Strong", using=self.alias) # type: ignore - - # If no index on the collection, create one - if len(self.col.indexes) == 0: - i_p = {"metric_type": "L2", "index_type": "AUTOINDEX", "params": {}} - self.col.create_index("embedding", index_params=i_p) - - self.col.load() - - async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: - """Upsert chunks into the datastore. - - Args: - chunks (Dict[str, List[DocumentChunk]]): A list of DocumentChunks to insert - - Raises: - e: Error in upserting data. - - Returns: - List[str]: The document_id's that were inserted. 
- """ - # The doc id's to return for the upsert - doc_ids: List[str] = [] - # List to collect all the insert data - insert_data = [[] for _ in range(len(SCHEMA) - 1)] - # Go through each document chunklist and grab the data - for doc_id, chunk_list in chunks.items(): - # Append the doc_id to the list we are returning - doc_ids.append(doc_id) - # Examine each chunk in the chunklist - for chunk in chunk_list: - # Extract data from the chunk - list_of_data = self._get_values(chunk) - # Check if the data is valid - if list_of_data is not None: - # Append each field to the insert_data - for x in range(len(insert_data)): - insert_data[x].append(list_of_data[x]) - # Slice up our insert data into batches - batches = [ - insert_data[i : i + UPSERT_BATCH_SIZE] - for i in range(0, len(insert_data), UPSERT_BATCH_SIZE) - ] - - # Attempt to insert each batch into our collection - for batch in batches: - # Check if empty batch - if len(batch[0]) != 0: - try: - print(f"Upserting batch of size {len(batch[0])}") - self.col.insert(batch) - print(f"Upserted batch successfully") - except Exception as e: - print(f"Error upserting batch: {e}") - raise e - - # This setting performs flushes after insert. Small insert == bad to use - # self.col.flush() - - return doc_ids - - def _get_values(self, chunk: DocumentChunk) -> List[any] | None: # type: ignore - """Convert the chunk into a list of values to insert whose indexes align with fields. - - Args: - chunk (DocumentChunk): The chunk to convert. - - Returns: - List (any): The values to insert. - """ - # Convert DocumentChunk and its sub models to dict - values = chunk.dict() - # Unpack the metadata into the same dict - meta = values.pop("metadata") - values.update(meta) - - # Convert date to int timestamp form - if values["created_at"]: - values["created_at"] = to_unix_timestamp(values["created_at"]) - - # If source exists, change from Source object to the string value it holds - if values["source"]: - values["source"] = values["source"].value - # List to collect data we will return - ret = [] - # Grab data responding to each field excluding the hidden auto pk field - for key, _, default in SCHEMA[1:]: - # Grab the data at the key and default to our defaults set in init - x = values.get(key) or default - # If one of our required fields is missing, ignore the entire entry - if x is Required: - print("Chunk " + values["id"] + " missing " + key + " skipping") - return None - # Add the corresponding value if it passes the tests - ret.append(x) - return ret - - async def _query( - self, - queries: List[QueryWithEmbedding], - ) -> List[QueryResult]: - """Query the QueryWithEmbedding against the ZillizDocumentSearch - - Search the embedding and its filter in the collection. - - Args: - queries (List[QueryWithEmbedding]): The list of searches to perform. - - Returns: - List[QueryResult]: Results for each search. 
- """ - # Async to perform the query, adapted from pinecone implementation - async def _single_query(query: QueryWithEmbedding) -> QueryResult: - - filter = None - # Set the filter to expression that is valid for Zilliz - if query.filter != None: - # Either a valid filter or None will be returned - filter = self._get_filter(query.filter) - - # Perform our search - res = self.col.search( - data=[query.embedding], - anns_field="embedding", - param={"metric_type": "L2", "params": {}}, - limit=query.top_k, - expr=filter, - output_fields=[ - field[0] for field in SCHEMA[2:] - ], # Ignoring pk, embedding - ) - # Results that will hold our DocumentChunkWithScores - results = [] - # Parse every result for our search - for hit in res[0]: # type: ignore - # The distance score for the search result, falls under DocumentChunkWithScore - score = hit.score - # Our metadata info, falls under DocumentChunkMetadata - metadata = {} - # Grab the values that correspond to our fields, ignore pk and embedding. - for x in [field[0] for field in SCHEMA[2:]]: - metadata[x] = hit.entity.get(x) - # If the source isn't valid, convert to None - if metadata["source"] not in Source.__members__: - metadata["source"] = None - # Text falls under the DocumentChunk - text = metadata.pop("text") - # Id falls under the DocumentChunk - ids = metadata.pop("id") - chunk = DocumentChunkWithScore( - id=ids, - score=score, - text=text, - metadata=DocumentChunkMetadata(**metadata), - ) - results.append(chunk) - - # TODO: decide on doing queries to grab the embedding itself, slows down performance as double query occurs - - return QueryResult(query=query.query, results=results) - - results: List[QueryResult] = await asyncio.gather( - *[_single_query(query) for query in queries] - ) - return results - - async def delete( - self, - ids: Optional[List[str]] = None, - filter: Optional[DocumentMetadataFilter] = None, - delete_all: Optional[bool] = None, - ) -> bool: - """Delete the entities based either on the chunk_id of the vector, - - Args: - ids (Optional[List[str]], optional): The document_ids to delete. Defaults to None. - filter (Optional[DocumentMetadataFilter], optional): The filter to delete by. Defaults to None. - delete_all (Optional[bool], optional): Whether to drop the collection and recreate it. Defaults to None. 
- """ - # If deleting all, drop and create the new collection - if delete_all: - # Release the collection from memory - self.col.release() - # Drop the collection - self.col.drop() - # Recreate the new collection - self._create_collection(True) - return True - - # Keep track of how many we have deleted for later printing - delete_count = 0 - - # Check if empty ids - if ids != None: - if len(ids) != 0: - # Add quotation marks around the string format id - ids = ['"' + str(id) + '"' for id in ids] - # Query for the pk's of entries that match id's - ids = self.col.query(f"document_id in [{','.join(ids)}]") - # Convert to list of pks - ids = [str(entry["pk"]) for entry in ids] # type: ignore - # Check to see if there are valid pk's to delete - if len(ids) != 0: - # Delete the entries for each pk - res = self.col.delete(f"pk in [{','.join(ids)}]") - # Increment our deleted count - delete_count += int(res.delete_count) # type: ignore - - # Check if empty filter - if filter != None: - # Convert filter to Zilliz expression - filter = self._get_filter(filter) # type: ignore - # Check if there is anything to filter - if len(filter) != 0: # type: ignore - # Query for the pk's of entries that match filter - filter = self.col.query(filter) # type: ignore - # Convert to list of pks - filter = [str(entry["pk"]) for entry in filter] # type: ignore - # Check to see if there are valid pk's to delete - if len(filter) != 0: # type: ignore - # Delete the entries - res = self.col.delete(f"pk in [{','.join(filter)}]") # type: ignore - # Increment our delete count - delete_count += int(res.delete_count) # type: ignore - - # This setting performs flushes after delete. Small delete == bad to use - # self.col.flush() - - return True + def _create_index(self): + try: + # If no index on the collection, create one + if len(self.col.indexes) == 0: + self.index_params = {"metric_type": "IP", "index_type": "AUTOINDEX", "params": {}} + self.col.create_index("embedding", index_params=self.index_params) - def _get_filter(self, filter: DocumentMetadataFilter) -> Optional[str]: - """Converts a DocumentMetdataFilter to the expression that Zilliz takes. + self.col.load() + self.search_params = {"metric_type": "IP", "params": {}} + except Exception as e: + self._print_err("Failed to create index, error: {}".format(e)) - Args: - filter (DocumentMetadataFilter): The Filter to convert to Zilliz expression. - Returns: - Optional[str]: The filter if valid, otherwise None. 
- """ - filters = [] - # Go through all the fields and their values - for field, value in filter.dict().items(): - # Check if the Value is empty - if value is not None: - # Convert start_date to int and add greater than or equal logic - if field == "start_date": - filters.append( - "(created_at >= " + str(to_unix_timestamp(value)) + ")" - ) - # Convert end_date to int and add less than or equal logic - elif field == "end_date": - filters.append( - "(created_at <= " + str(to_unix_timestamp(value)) + ")" - ) - # Convert Source to its string value and check equivalency - elif field == "source": - filters.append("(" + field + ' == "' + str(value.value) + '")') - # Check equivalency of rest of string fields - else: - filters.append("(" + field + ' == "' + str(value) + '")') - # Join all our expressions with `and`` - return " and ".join(filters) diff --git a/docs/providers/milvus/setup.md b/docs/providers/milvus/setup.md index c13a088b6..eb0ca4b29 100644 --- a/docs/providers/milvus/setup.md +++ b/docs/providers/milvus/setup.md @@ -18,16 +18,19 @@ You can deploy and manage Milvus using Docker Compose, Helm, K8's Operator, or A **Environment Variables:** -| Name | Required | Description | -| ------------------- | -------- | ------------------------------------------------------ | -| `DATASTORE` | Yes | Datastore name, set to `milvus` | -| `BEARER_TOKEN` | Yes | Your bearer token | -| `OPENAI_API_KEY` | Yes | Your OpenAI API key | -| `MILVUS_COLLECTION` | Optional | Milvus collection name, defaults to a random UUID | -| `MILVUS_HOST` | Optional | Milvus host IP, defaults to `localhost` | -| `MILVUS_PORT` | Optional | Milvus port, defaults to `19530` | -| `MILVUS_USER` | Optional | Milvus username if RBAC is enabled, defaults to `None` | -| `MILVUS_PASSWORD` | Optional | Milvus password if required, defaults to `None` | +| Name | Required | Description | +|----------------------------| -------- |----------------------------------------------------------------------------------------------------------------------------------------------| +| `DATASTORE` | Yes | Datastore name, set to `milvus` | +| `BEARER_TOKEN` | Yes | Your bearer token | +| `OPENAI_API_KEY` | Yes | Your OpenAI API key | +| `MILVUS_COLLECTION` | Optional | Milvus collection name, defaults to a random UUID | +| `MILVUS_HOST` | Optional | Milvus host IP, defaults to `localhost` | +| `MILVUS_PORT` | Optional | Milvus port, defaults to `19530` | +| `MILVUS_USER` | Optional | Milvus username if RBAC is enabled, defaults to `None` | +| `MILVUS_PASSWORD` | Optional | Milvus password if required, defaults to `None` | +| `MILVUS_INDEX_PARAMS` | Optional | Custom index options for the collection, defaults to `{"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}` | +| `MILVUS_SEARCH_PARAMS` | Optional | Custom search options for the collection, defaults to `{"metric_type": "IP", "params": {"ef": 10}}` | +| `MILVUS_CONSISTENCY_LEVEL` | Optional | Data consistency level for the collection, defaults to `Bounded` | ## Running Milvus Integration Tests diff --git a/docs/providers/zilliz/setup.md b/docs/providers/zilliz/setup.md index 1735d059d..fddde3aa9 100644 --- a/docs/providers/zilliz/setup.md +++ b/docs/providers/zilliz/setup.md @@ -24,15 +24,16 @@ Zilliz Cloud is deployable in a few simple steps. 
First, create an account [here Environment Variables: -| Name | Required | Description | -| ------------------- | -------- | ------------------------------------------------- | -| `DATASTORE` | Yes | Datastore name, set to `zilliz` | -| `BEARER_TOKEN` | Yes | Your secret token | -| `OPENAI_API_KEY` | Yes | Your OpenAI API key | -| `ZILLIZ_COLLECTION` | Optional | Zilliz collection name. Defaults to a random UUID | -| `ZILLIZ_URI` | Yes | URI for the Zilliz instance | -| `ZILLIZ_USER` | Yes | Zilliz username | -| `ZILLIZ_PASSWORD` | Yes | Zilliz password | +| Name | Required | Description | +|----------------------------| -------- |------------------------------------------------------------------| +| `DATASTORE` | Yes | Datastore name, set to `zilliz` | +| `BEARER_TOKEN` | Yes | Your secret token | +| `OPENAI_API_KEY` | Yes | Your OpenAI API key | +| `ZILLIZ_COLLECTION` | Optional | Zilliz collection name. Defaults to a random UUID | +| `ZILLIZ_URI` | Yes | URI for the Zilliz instance | +| `ZILLIZ_USER` | Yes | Zilliz username | +| `ZILLIZ_PASSWORD` | Yes | Zilliz password | +| `ZILLIZ_CONSISTENCY_LEVEL` | Optional | Data consistency level for the collection, defaults to `Bounded` | ## Running Zilliz Integration Tests diff --git a/tests/datastore/providers/milvus/test_milvus_datastore.py b/tests/datastore/providers/milvus/test_milvus_datastore.py index b82b6de90..d15b1a007 100644 --- a/tests/datastore/providers/milvus/test_milvus_datastore.py +++ b/tests/datastore/providers/milvus/test_milvus_datastore.py @@ -19,9 +19,23 @@ @pytest.fixture def milvus_datastore(): - return MilvusDataStore() + return MilvusDataStore(consistency_level = "Strong") +def sample_embedding(one_element_poz: int): + embedding = [0] * OUTPUT_DIM + embedding[one_element_poz % OUTPUT_DIM] = 1 + return embedding + +def sample_embeddings(num: int, one_element_start: int = 0): + # since metric type is consine, we create vector contains only one element 1, others 0 + embeddings = [] + for x in range(num): + embedding = [0] * OUTPUT_DIM + embedding[(x + one_element_start) % OUTPUT_DIM] = 1 + embeddings.append(embedding) + return embeddings + @pytest.fixture def document_chunk_one(): doc_id = "zerp" @@ -42,7 +56,8 @@ def document_chunk_one(): "2021-01-21T10:00:00-02:00", ] authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3)] + + embeddings = sample_embeddings(len(texts)) for i in range(3): chunk = DocumentChunk( @@ -84,7 +99,7 @@ def document_chunk_two(): "3021-01-21T10:00:00-02:00", ] authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3)] + embeddings = sample_embeddings(len(texts)) for i in range(3): chunk = DocumentChunk( @@ -121,7 +136,7 @@ def document_chunk_two(): "6021-01-21T10:00:00-02:00", ] authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3, 6)] + embeddings = sample_embeddings(len(texts), 3) for i in range(3): chunk = DocumentChunk( @@ -150,6 +165,7 @@ async def test_upsert(milvus_datastore, document_chunk_one): assert res == list(document_chunk_one.keys()) milvus_datastore.col.flush() assert 3 == milvus_datastore.col.num_entities + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -160,6 +176,7 @@ async def test_reload(milvus_datastore, document_chunk_one, document_chunk_two): assert res == list(document_chunk_one.keys()) milvus_datastore.col.flush() assert 3 == milvus_datastore.col.num_entities + new_store = MilvusDataStore() another_in = {i: 
document_chunk_two[i] for i in document_chunk_two if i != res[0]} res = await new_store._upsert(another_in) @@ -168,10 +185,11 @@ async def test_reload(milvus_datastore, document_chunk_one, document_chunk_two): query = QueryWithEmbedding( query="lorem", top_k=10, - embedding=[0.5] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) + new_store.col.drop() @pytest.mark.asyncio @@ -185,12 +203,13 @@ async def test_upsert_query_all(milvus_datastore, document_chunk_two): query = QueryWithEmbedding( query="lorem", top_k=10, - embedding=[0.5] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 6 == len(query_results[0].results) + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -202,14 +221,15 @@ async def test_query_accuracy(milvus_datastore, document_chunk_one): query = QueryWithEmbedding( query="lorem", top_k=1, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 1 == len(query_results[0].results) - assert 0 == query_results[0].results[0].score + assert 1.0 == query_results[0].results[0].score assert "abc_123" == query_results[0].results[0].id + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -221,7 +241,7 @@ async def test_query_filter(milvus_datastore, document_chunk_one): query = QueryWithEmbedding( query="lorem", top_k=1, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), filter=DocumentMetadataFilter( start_date="2000-01-03T16:39:57-08:00", end_date="2010-01-03T16:39:57-08:00" ), @@ -230,8 +250,9 @@ async def test_query_filter(milvus_datastore, document_chunk_one): assert 1 == len(query_results) assert 1 == len(query_results[0].results) - assert 0 != query_results[0].results[0].score + assert 1.0 != query_results[0].results[0].score assert "def_456" == query_results[0].results[0].id + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -249,13 +270,14 @@ async def test_delete_with_date_filter(milvus_datastore, document_chunk_one): query = QueryWithEmbedding( query="lorem", top_k=9, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 1 == len(query_results[0].results) assert "ghi_789" == query_results[0].results[0].id + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -273,13 +295,14 @@ async def test_delete_with_source_filter(milvus_datastore, document_chunk_one): query = QueryWithEmbedding( query="lorem", top_k=9, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 2 == len(query_results[0].results) assert "def_456" == query_results[0].results[0].id + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -296,12 +319,13 @@ async def test_delete_with_document_id_filter(milvus_datastore, document_chunk_o query = QueryWithEmbedding( query="lorem", top_k=9, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 0 == len(query_results[0].results) + milvus_datastore.col.drop() @pytest.mark.asyncio @@ -315,12 +339,13 @@ async def test_delete_with_document_id(milvus_datastore, document_chunk_one): query = QueryWithEmbedding( query="lorem", 
top_k=9, - embedding=[0] * OUTPUT_DIM, + embedding=sample_embedding(0), ) query_results = await milvus_datastore._query(queries=[query]) assert 1 == len(query_results) assert 0 == len(query_results[0].results) + milvus_datastore.col.drop() # if __name__ == '__main__': diff --git a/tests/datastore/providers/zilliz/test_zilliz_datastore.py b/tests/datastore/providers/zilliz/test_zilliz_datastore.py index d9d102530..f790797a8 100644 --- a/tests/datastore/providers/zilliz/test_zilliz_datastore.py +++ b/tests/datastore/providers/zilliz/test_zilliz_datastore.py @@ -4,327 +4,26 @@ # load_dotenv(dotenv_path=env_path, verbose=True) import pytest -from models.models import ( - DocumentChunkMetadata, - DocumentMetadataFilter, - DocumentChunk, - Query, - QueryWithEmbedding, - Source, -) + from datastore.providers.zilliz_datastore import ( - OUTPUT_DIM, ZillizDataStore, ) +from datastore.providers.milvus_datastore import ( + EMBEDDING_FIELD, +) + +# Note: Only do basic test here, the ZillizDataStore is derived from MilvusDataStore. @pytest.fixture def zilliz_datastore(): return ZillizDataStore() -@pytest.fixture -def document_chunk_one(): - doc_id = "zerp" - doc_chunks = [] - - ids = ["abc_123", "def_456", "ghi_789"] - texts = [ - "lorem ipsum dolor sit amet", - "consectetur adipiscing elit", - "sed do eiusmod tempor incididunt", - ] - sources = [Source.email, Source.file, Source.chat] - source_ids = ["foo", "bar", "baz"] - urls = ["foo.com", "bar.net", "baz.org"] - created_ats = [ - "1929-10-28T09:30:00-05:00", - "2009-01-03T16:39:57-08:00", - "2021-01-21T10:00:00-02:00", - ] - authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3)] - - for i in range(3): - chunk = DocumentChunk( - id=ids[i], - text=texts[i], - metadata=DocumentChunkMetadata( - document_id=doc_id, - source=sources[i], - source_id=source_ids[i], - url=urls[i], - created_at=created_ats[i], - author=authors[i], - ), - embedding=embeddings[i], # type: ignore - ) - - doc_chunks.append(chunk) - - return {doc_id: doc_chunks} - - -@pytest.fixture -def document_chunk_two(): - doc_id_1 = "zerp" - doc_chunks_1 = [] - - ids = ["abc_123", "def_456", "ghi_789"] - texts = [ - "1lorem ipsum dolor sit amet", - "2consectetur adipiscing elit", - "3sed do eiusmod tempor incididunt", - ] - sources = [Source.email, Source.file, Source.chat] - source_ids = ["foo", "bar", "baz"] - urls = ["foo.com", "bar.net", "baz.org"] - created_ats = [ - "1929-10-28T09:30:00-05:00", - "2009-01-03T16:39:57-08:00", - "3021-01-21T10:00:00-02:00", - ] - authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3)] - - for i in range(3): - chunk = DocumentChunk( - id=ids[i], - text=texts[i], - metadata=DocumentChunkMetadata( - document_id=doc_id_1, - source=sources[i], - source_id=source_ids[i], - url=urls[i], - created_at=created_ats[i], - author=authors[i], - ), - embedding=embeddings[i], # type: ignore - ) - - doc_chunks_1.append(chunk) - - doc_id_2 = "merp" - doc_chunks_2 = [] - - ids = ["jkl_123", "lmn_456", "opq_789"] - texts = [ - "3sdsc efac feas sit qweas", - "4wert sdfas fdsc", - "52dsc fdsf eiusmod asdasd incididunt", - ] - sources = [Source.email, Source.file, Source.chat] - source_ids = ["foo", "bar", "baz"] - urls = ["foo.com", "bar.net", "baz.org"] - created_ats = [ - "4929-10-28T09:30:00-05:00", - "5009-01-03T16:39:57-08:00", - "6021-01-21T10:00:00-02:00", - ] - authors = ["Max Mustermann", "John Doe", "Jane Doe"] - embeddings = [[x] * OUTPUT_DIM for x in range(3, 6)] - 
- for i in range(3): - chunk = DocumentChunk( - id=ids[i], - text=texts[i], - metadata=DocumentChunkMetadata( - document_id=doc_id_2, - source=sources[i], - source_id=source_ids[i], - url=urls[i], - created_at=created_ats[i], - author=authors[i], - ), - embedding=embeddings[i], # type: ignore - ) - - doc_chunks_2.append(chunk) - - return {doc_id_1: doc_chunks_1, doc_id_2: doc_chunks_2} - - -@pytest.mark.asyncio -async def test_upsert(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - assert 3 == zilliz_datastore.col.num_entities - - -@pytest.mark.asyncio -async def test_reload(zilliz_datastore, document_chunk_one, document_chunk_two): - await zilliz_datastore.delete(delete_all=True) - - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - assert 3 == zilliz_datastore.col.num_entities - new_store = ZillizDataStore() - another_in = {i: document_chunk_two[i] for i in document_chunk_two if i != res[0]} - res = await new_store._upsert(another_in) - new_store.col.flush() - assert 6 == new_store.col.num_entities - query = QueryWithEmbedding( - query="lorem", - top_k=10, - embedding=[0.5] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - assert 1 == len(query_results) - - -@pytest.mark.asyncio -async def test_upsert_and_query_all(zilliz_datastore, document_chunk_two): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_two) - assert res == list(document_chunk_two.keys()) - zilliz_datastore.col.flush() - - # Num entities currently doesn't track deletes - query = QueryWithEmbedding( - query="lorem", - top_k=9, - embedding=[0.5] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 6 == len(query_results[0].results) - - -@pytest.mark.asyncio -async def test_query_accuracy(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - query = QueryWithEmbedding( - query="lorem", - top_k=1, - embedding=[0] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 1 == len(query_results[0].results) - assert 0 == query_results[0].results[0].score - assert "abc_123" == query_results[0].results[0].id - - -@pytest.mark.asyncio -async def test_query_filter(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - query = QueryWithEmbedding( - query="lorem", - top_k=1, - embedding=[0] * OUTPUT_DIM, - filter=DocumentMetadataFilter( - start_date="2000-01-03T16:39:57-08:00", end_date="2010-01-03T16:39:57-08:00" - ), - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 1 == len(query_results[0].results) - assert 0 != query_results[0].results[0].score - assert "def_456" == query_results[0].results[0].id - - -@pytest.mark.asyncio -async def test_delete_with_date_filter(zilliz_datastore, document_chunk_one): - await 
zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - await zilliz_datastore.delete( - filter=DocumentMetadataFilter( - end_date="2009-01-03T16:39:57-08:00", - ) - ) - - query = QueryWithEmbedding( - query="lorem", - top_k=9, - embedding=[0] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 1 == len(query_results[0].results) - assert "ghi_789" == query_results[0].results[0].id - - -@pytest.mark.asyncio -async def test_delete_with_source_filter(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - await zilliz_datastore.delete( - filter=DocumentMetadataFilter( - source=Source.email, - ) - ) - - query = QueryWithEmbedding( - query="lorem", - top_k=9, - embedding=[0] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 2 == len(query_results[0].results) - assert "def_456" == query_results[0].results[0].id - - @pytest.mark.asyncio -async def test_delete_with_document_id_filter(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - await zilliz_datastore.delete( - filter=DocumentMetadataFilter( - document_id=res[0], - ) - ) - query = QueryWithEmbedding( - query="lorem", - top_k=9, - embedding=[0] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 0 == len(query_results[0].results) - - -@pytest.mark.asyncio -async def test_delete_with_document_id(zilliz_datastore, document_chunk_one): - await zilliz_datastore.delete(delete_all=True) - res = await zilliz_datastore._upsert(document_chunk_one) - assert res == list(document_chunk_one.keys()) - zilliz_datastore.col.flush() - await zilliz_datastore.delete([res[0]]) - - query = QueryWithEmbedding( - query="lorem", - top_k=9, - embedding=[0] * OUTPUT_DIM, - ) - query_results = await zilliz_datastore._query(queries=[query]) - - assert 1 == len(query_results) - assert 0 == len(query_results[0].results) - - -# if __name__ == '__main__': -# import sys -# import pytest -# pytest.main(sys.argv) +async def test_zilliz(zilliz_datastore): + assert True == zilliz_datastore.col.has_index() + index_list = [x.to_dict() for x in zilliz_datastore.col.indexes] + for index in index_list: + if index['index_name'] == EMBEDDING_FIELD: + assert 'AUTOINDEX' == index['index_param']['index_type'] \ No newline at end of file
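
Usage note for the options introduced above: MILVUS_INDEX_PARAMS and MILVUS_SEARCH_PARAMS are parsed with json.loads, so they must be JSON strings, and all three variables are read when the provider module is imported. A minimal sketch of configuring them before constructing the datastore — the index/search values shown here (IVF_FLAT, nlist, nprobe, "Strong") are illustrative assumptions, not defaults mandated by this patch:

    import os

    # Set these before importing the provider module, because the module reads
    # them with os.environ.get() at import time.
    os.environ["MILVUS_CONSISTENCY_LEVEL"] = "Strong"  # e.g. for tests; default is "Bounded"
    os.environ["MILVUS_INDEX_PARAMS"] = '{"metric_type": "IP", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}'
    os.environ["MILVUS_SEARCH_PARAMS"] = '{"metric_type": "IP", "params": {"nprobe": 16}}'

    from datastore.providers.milvus_datastore import MilvusDataStore

    # With the variables unset, the datastore falls back to an HNSW index with
    # the IP metric (or AUTOINDEX on Zilliz Cloud) and matching search params.
    store = MilvusDataStore()

The environment value, when present, overrides the consistency_level constructor argument, which is why the tests can force "Strong" consistency while production deployments keep the faster "Bounded" default.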