diff --git a/datastore/factory.py b/datastore/factory.py index 5a7e13c19..48cd7c40f 100644 --- a/datastore/factory.py +++ b/datastore/factory.py @@ -36,4 +36,7 @@ async def get_datastore() -> DataStore: return QdrantDataStore() case _: - raise ValueError(f"Unsupported vector database: {datastore}") + raise ValueError( + f"Unsupported vector database: {datastore}. " + f"Try one of the following: llama, pinecone, weaviate, milvus, zilliz, redis, or qdrant" + ) \ No newline at end of file diff --git a/datastore/providers/__init__.py b/datastore/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datastore/providers/redis_datastore.py b/datastore/providers/redis_datastore.py index 1ca1a0ee8..5e28e2f25 100644 --- a/datastore/providers/redis_datastore.py +++ b/datastore/providers/redis_datastore.py @@ -55,7 +55,6 @@ def unpack_schema(d: dict): yield v async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]): - installed_modules = (await client.info()).get("modules", []) installed_modules = {module["name"]: module for module in installed_modules} for module in modules: @@ -66,14 +65,13 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]): raise AttributeError(error_message) - class RedisDataStore(DataStore): - def __init__(self, client: redis.Redis, redisearch_schema): + def __init__(self, client: redis.Redis, redisearch_schema: dict): self.client = client self._schema = redisearch_schema # Init default metadata with sentinel values in case the document written has no metadata self._default_metadata = { - field: "_null_" for field in redisearch_schema["metadata"] + field: (0 if field == "created_at" else "_null_") for field in redisearch_schema["metadata"] } ### Redis Helper Methods ### @@ -94,11 +92,11 @@ async def init(cls, **kwargs): raise e await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES) - + dim = kwargs.get("dim", VECTOR_DIMENSION) redisearch_schema = { - "document_id": TagField("$.document_id", as_name="document_id"), "metadata": { + "document_id": TagField("$.metadata.document_id", as_name="document_id"), "source_id": TagField("$.metadata.source_id", as_name="source_id"), "source": TagField("$.metadata.source", as_name="source"), "author": TextField("$.metadata.author", as_name="author"), @@ -300,7 +298,7 @@ async def _query( results: List[QueryResult] = [] # Gather query results in a pipeline - logging.info(f"Gathering {len(queries)} query results", flush=True) + logging.info(f"Gathering {len(queries)} query results") for query in queries: logging.info(f"Query: {query.query}") diff --git a/tests/datastore/providers/redis/test_redis_datastore.py b/tests/datastore/providers/redis/test_redis_datastore.py index b91d9211c..6d899881c 100644 --- a/tests/datastore/providers/redis/test_redis_datastore.py +++ b/tests/datastore/providers/redis/test_redis_datastore.py @@ -1,15 +1,15 @@ from datastore.providers.redis_datastore import RedisDataStore -import datastore.providers.redis_datastore as static_redis -from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source +from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source, DocumentMetadataFilter import pytest import redis.asyncio as redis import numpy as np +NUM_TEST_DOCS = 10 + @pytest.fixture async def redis_datastore(): return await RedisDataStore.init(dim=5) - def create_embedding(i, dim): vec = np.array([0.1] * dim).astype(np.float64).tolist() vec[dim-1] = i+1/10 @@ -21,7 +21,7 @@ def create_document_chunk(i, dim): text=f"Lorem ipsum {i}", embedding=create_embedding(i, dim), metadata=DocumentChunkMetadata( - source=Source.file, created_at="1970-01-01", document_id=f"doc-{i}" + source=Source.file, created_at="1970-01-01", document_id="docs" ), ) @@ -31,7 +31,7 @@ def create_document_chunks(n, dim): @pytest.mark.asyncio async def test_redis_upsert_query(redis_datastore): - docs = create_document_chunks(10, 5) + docs = create_document_chunks(NUM_TEST_DOCS, 5) await redis_datastore._upsert(docs) query = QueryWithEmbedding( query="Lorem ipsum 0", @@ -42,4 +42,23 @@ async def test_redis_upsert_query(redis_datastore): assert 1 == len(query_results) for i in range(5): assert f"Lorem ipsum {i}" == query_results[0].results[i].text - assert f"doc-{i}" == query_results[0].results[i].id + assert "docs" == query_results[0].results[i].id + +@pytest.mark.asyncio +async def test_redis_filter_query(redis_datastore): + query = QueryWithEmbedding( + query="Lorem ipsum 0", + filter=DocumentMetadataFilter(document_id="docs"), + top_k=5, + embedding= create_embedding(0, 5), + ) + query_results = await redis_datastore._query(queries=[query]) + print(query_results) + assert 1 == len(query_results) + assert "docs" == query_results[0].results[0].id + + +@pytest.mark.asyncio +async def test_redis_delete_docs(redis_datastore): + res = await redis_datastore.delete(ids=["docs"]) + assert res