diff --git a/datastore/providers/redis_datastore.py b/datastore/providers/redis_datastore.py index 077e9b5f6..1ca1a0ee8 100644 --- a/datastore/providers/redis_datastore.py +++ b/datastore/providers/redis_datastore.py @@ -45,25 +45,6 @@ {"name": "ReJSON", "ver": 20404} ] REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]") -REDIS_SEARCH_SCHEMA = { - "document_id": TagField("$.document_id", as_name="document_id"), - "metadata": { - # "source_id": TagField("$.metadata.source_id", as_name="source_id"), - "source": TagField("$.metadata.source", as_name="source"), - # "author": TextField("$.metadata.author", as_name="author"), - # "created_at": NumericField("$.metadata.created_at", as_name="created_at"), - }, - "embedding": VectorField( - "$.embedding", - REDIS_INDEX_TYPE, - { - "TYPE": "FLOAT64", - "DIM": VECTOR_DIMENSION, - "DISTANCE_METRIC": REDIS_DISTANCE_METRIC, - }, - as_name="embedding", - ), -} # Helper functions def unpack_schema(d: dict): @@ -82,22 +63,23 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]): error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. " \ "Please refer to Redis Stack docs: https://redis.io/docs/stack/" logging.error(error_message) - raise ValueError(error_message) + raise AttributeError(error_message) class RedisDataStore(DataStore): - def __init__(self, client: redis.Redis): + def __init__(self, client: redis.Redis, redisearch_schema): self.client = client + self._schema = redisearch_schema # Init default metadata with sentinel values in case the document written has no metadata self._default_metadata = { - field: "_null_" for field in REDIS_SEARCH_SCHEMA["metadata"] + field: "_null_" for field in redisearch_schema["metadata"] } ### Redis Helper Methods ### @classmethod - async def init(cls): + async def init(cls, **kwargs): """ Setup the index if it does not exist. """ @@ -112,7 +94,27 @@ async def init(cls): raise e await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES) - + + dim = kwargs.get("dim", VECTOR_DIMENSION) + redisearch_schema = { + "document_id": TagField("$.document_id", as_name="document_id"), + "metadata": { + "source_id": TagField("$.metadata.source_id", as_name="source_id"), + "source": TagField("$.metadata.source", as_name="source"), + "author": TextField("$.metadata.author", as_name="author"), + "created_at": NumericField("$.metadata.created_at", as_name="created_at"), + }, + "embedding": VectorField( + "$.embedding", + REDIS_INDEX_TYPE, + { + "TYPE": "FLOAT64", + "DIM": dim, + "DISTANCE_METRIC": REDIS_DISTANCE_METRIC, + }, + as_name="embedding", + ), + } try: # Check for existence of RediSearch Index await client.ft(REDIS_INDEX_NAME).info() @@ -123,11 +125,12 @@ async def init(cls): definition = IndexDefinition( prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON ) - fields = list(unpack_schema(REDIS_SEARCH_SCHEMA)) + fields = list(unpack_schema(redisearch_schema)) + logging.info(f"Creating index with fields: {fields}") await client.ft(REDIS_INDEX_NAME).create_index( fields=fields, definition=definition ) - return cls(client) + return cls(client, redisearch_schema) @staticmethod def _redis_key(document_id: str, chunk_id: str) -> str: @@ -217,20 +220,21 @@ def _typ_to_str(typ, field, value) -> str: # type: ignore # Build filter if query.filter: + redisearch_schema = self._schema for field, value in query.filter.__dict__.items(): if not value: continue - if field in REDIS_SEARCH_SCHEMA: - filter_str += _typ_to_str(REDIS_SEARCH_SCHEMA[field], field, value) - elif field in REDIS_SEARCH_SCHEMA["metadata"]: + if field in redisearch_schema: + filter_str += _typ_to_str(redisearch_schema[field], field, value) + elif field in redisearch_schema["metadata"]: if field == "source": # handle the enum value = value.value filter_str += _typ_to_str( - REDIS_SEARCH_SCHEMA["metadata"][field], field, value + redisearch_schema["metadata"][field], field, value ) elif field in ["start_date", "end_date"]: filter_str += _typ_to_str( - REDIS_SEARCH_SCHEMA["metadata"]["created_at"], field, value + redisearch_schema["metadata"]["created_at"], field, value ) # Postprocess filter string diff --git a/docs/providers/redis/setup.md b/docs/providers/redis/setup.md index 894e00f79..37f993941 100644 --- a/docs/providers/redis/setup.md +++ b/docs/providers/redis/setup.md @@ -21,3 +21,17 @@ | `REDIS_DOC_PREFIX` | Optional | Redis key prefix for the index | `doc` | | `REDIS_DISTANCE_METRIC` | Optional | Vector similarity distance metric | `COSINE` | | `REDIS_INDEX_TYPE` | Optional | [Vector index algorithm type](https://redis.io/docs/stack/search/reference/vectors/#creation-attributes-per-algorithm) | `FLAT` | + + +## Redis Datastore development & testing +In order to test your changes to the Redis Datastore, you can run the following commands: + +```bash +# Run the Redis stack docker image +docker run -it --rm -p 6379:6379 redis/redis-stack-server:latest +``` + +```bash +# Run the Redis datastore tests +poetry run pytest -s ./tests/datastore/providers/redis/test_redis_datastore.py +``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fc7ca5fb2..8d2fc24e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,3 +41,9 @@ pytest-asyncio = "^0.20.3" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +pythonpath = [ + "." +] +asyncio_mode="auto" diff --git a/tests/datastore/providers/redis/test_redis_datastore.py b/tests/datastore/providers/redis/test_redis_datastore.py new file mode 100644 index 000000000..b91d9211c --- /dev/null +++ b/tests/datastore/providers/redis/test_redis_datastore.py @@ -0,0 +1,45 @@ +from datastore.providers.redis_datastore import RedisDataStore +import datastore.providers.redis_datastore as static_redis +from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source +import pytest +import redis.asyncio as redis +import numpy as np + +@pytest.fixture +async def redis_datastore(): + return await RedisDataStore.init(dim=5) + + +def create_embedding(i, dim): + vec = np.array([0.1] * dim).astype(np.float64).tolist() + vec[dim-1] = i+1/10 + return vec + +def create_document_chunk(i, dim): + return DocumentChunk( + id=f"first-doc_{i}", + text=f"Lorem ipsum {i}", + embedding=create_embedding(i, dim), + metadata=DocumentChunkMetadata( + source=Source.file, created_at="1970-01-01", document_id=f"doc-{i}" + ), + ) + +def create_document_chunks(n, dim): + docs = [create_document_chunk(i, dim) for i in range(n)] + return {"docs": docs} + +@pytest.mark.asyncio +async def test_redis_upsert_query(redis_datastore): + docs = create_document_chunks(10, 5) + await redis_datastore._upsert(docs) + query = QueryWithEmbedding( + query="Lorem ipsum 0", + top_k=5, + embedding= create_embedding(0, 5), + ) + query_results = await redis_datastore._query(queries=[query]) + assert 1 == len(query_results) + for i in range(5): + assert f"Lorem ipsum {i}" == query_results[0].results[i].text + assert f"doc-{i}" == query_results[0].results[i].id