Skip to content

Commit

Permalink
first redis test (openai#122)
Browse files Browse the repository at this point in the history
* first redis test

* fixed test

* pr comment

* added docs
  • Loading branch information
DvirDukhan authored Apr 11, 2023
1 parent 15b1169 commit 71650a5
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 31 deletions.
66 changes: 35 additions & 31 deletions datastore/providers/redis_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,6 @@
{"name": "ReJSON", "ver": 20404}
]
REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]")
REDIS_SEARCH_SCHEMA = {
"document_id": TagField("$.document_id", as_name="document_id"),
"metadata": {
# "source_id": TagField("$.metadata.source_id", as_name="source_id"),
"source": TagField("$.metadata.source", as_name="source"),
# "author": TextField("$.metadata.author", as_name="author"),
# "created_at": NumericField("$.metadata.created_at", as_name="created_at"),
},
"embedding": VectorField(
"$.embedding",
REDIS_INDEX_TYPE,
{
"TYPE": "FLOAT64",
"DIM": VECTOR_DIMENSION,
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
},
as_name="embedding",
),
}

# Helper functions
def unpack_schema(d: dict):
Expand All @@ -82,22 +63,23 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. " \
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
logging.error(error_message)
raise ValueError(error_message)
raise AttributeError(error_message)



class RedisDataStore(DataStore):
def __init__(self, client: redis.Redis):
def __init__(self, client: redis.Redis, redisearch_schema):
self.client = client
self._schema = redisearch_schema
# Init default metadata with sentinel values in case the document written has no metadata
self._default_metadata = {
field: "_null_" for field in REDIS_SEARCH_SCHEMA["metadata"]
field: "_null_" for field in redisearch_schema["metadata"]
}

### Redis Helper Methods ###

@classmethod
async def init(cls):
async def init(cls, **kwargs):
"""
Setup the index if it does not exist.
"""
Expand All @@ -112,7 +94,27 @@ async def init(cls):
raise e

await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)


dim = kwargs.get("dim", VECTOR_DIMENSION)
redisearch_schema = {
"document_id": TagField("$.document_id", as_name="document_id"),
"metadata": {
"source_id": TagField("$.metadata.source_id", as_name="source_id"),
"source": TagField("$.metadata.source", as_name="source"),
"author": TextField("$.metadata.author", as_name="author"),
"created_at": NumericField("$.metadata.created_at", as_name="created_at"),
},
"embedding": VectorField(
"$.embedding",
REDIS_INDEX_TYPE,
{
"TYPE": "FLOAT64",
"DIM": dim,
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
},
as_name="embedding",
),
}
try:
# Check for existence of RediSearch Index
await client.ft(REDIS_INDEX_NAME).info()
Expand All @@ -123,11 +125,12 @@ async def init(cls):
definition = IndexDefinition(
prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON
)
fields = list(unpack_schema(REDIS_SEARCH_SCHEMA))
fields = list(unpack_schema(redisearch_schema))
logging.info(f"Creating index with fields: {fields}")
await client.ft(REDIS_INDEX_NAME).create_index(
fields=fields, definition=definition
)
return cls(client)
return cls(client, redisearch_schema)

@staticmethod
def _redis_key(document_id: str, chunk_id: str) -> str:
Expand Down Expand Up @@ -217,20 +220,21 @@ def _typ_to_str(typ, field, value) -> str: # type: ignore

# Build filter
if query.filter:
redisearch_schema = self._schema
for field, value in query.filter.__dict__.items():
if not value:
continue
if field in REDIS_SEARCH_SCHEMA:
filter_str += _typ_to_str(REDIS_SEARCH_SCHEMA[field], field, value)
elif field in REDIS_SEARCH_SCHEMA["metadata"]:
if field in redisearch_schema:
filter_str += _typ_to_str(redisearch_schema[field], field, value)
elif field in redisearch_schema["metadata"]:
if field == "source": # handle the enum
value = value.value
filter_str += _typ_to_str(
REDIS_SEARCH_SCHEMA["metadata"][field], field, value
redisearch_schema["metadata"][field], field, value
)
elif field in ["start_date", "end_date"]:
filter_str += _typ_to_str(
REDIS_SEARCH_SCHEMA["metadata"]["created_at"], field, value
redisearch_schema["metadata"]["created_at"], field, value
)

# Postprocess filter string
Expand Down
14 changes: 14 additions & 0 deletions docs/providers/redis/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,17 @@
| `REDIS_DOC_PREFIX` | Optional | Redis key prefix for the index | `doc` |
| `REDIS_DISTANCE_METRIC` | Optional | Vector similarity distance metric | `COSINE` |
| `REDIS_INDEX_TYPE` | Optional | [Vector index algorithm type](https://redis.io/docs/stack/search/reference/vectors/#creation-attributes-per-algorithm) | `FLAT` |


## Redis Datastore development & testing
To test your changes to the Redis datastore, first start a local Redis Stack server, then run the datastore tests against it:

```bash
# Run the Redis Stack Docker image (publishes Redis on localhost:6379)
docker run -it --rm -p 6379:6379 redis/redis-stack-server:latest
```

```bash
# Run the Redis datastore tests
poetry run pytest -s ./tests/datastore/providers/redis/test_redis_datastore.py
```
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,9 @@ pytest-asyncio = "^0.20.3"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
pythonpath = [
"."
]
asyncio_mode="auto"
45 changes: 45 additions & 0 deletions tests/datastore/providers/redis/test_redis_datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from datastore.providers.redis_datastore import RedisDataStore
import datastore.providers.redis_datastore as static_redis
from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source
import pytest
import redis.asyncio as redis
import numpy as np

@pytest.fixture
async def redis_datastore():
    """Provide a RedisDataStore initialised with ``dim=5``.

    ``RedisDataStore.init`` is awaited here, so the fixture presumably needs
    a reachable Redis server with the required modules -- confirm against
    the provider's setup docs.
    """
    datastore = await RedisDataStore.init(dim=5)
    return datastore


def create_embedding(i, dim):
    """Return a ``dim``-long list of floats: all 0.1 except the last entry.

    The last component is ``i + 1 / 10``. Operator precedence makes this
    ``i + 0.1`` (not ``(i + 1) / 10``), so it differs for every ``i``.
    """
    embedding = np.asarray([0.1] * dim, dtype=np.float64).tolist()
    last = dim - 1
    embedding[last] = i + 1 / 10
    return embedding

def create_document_chunk(i, dim):
    """Build the ``i``-th sample DocumentChunk with a ``dim``-sized embedding.

    The chunk id, text and metadata ``document_id`` all embed ``i`` so that
    query results can be matched back to the chunk that produced them.
    """
    metadata = DocumentChunkMetadata(
        source=Source.file,
        created_at="1970-01-01",
        document_id=f"doc-{i}",
    )
    chunk = DocumentChunk(
        id=f"first-doc_{i}",
        text=f"Lorem ipsum {i}",
        embedding=create_embedding(i, dim),
        metadata=metadata,
    )
    return chunk

def create_document_chunks(n, dim):
    """Return ``{"docs": [...]}`` holding ``n`` sample chunks (indices 0..n-1)."""
    chunks = []
    for index in range(n):
        chunks.append(create_document_chunk(index, dim))
    return {"docs": chunks}

@pytest.mark.asyncio
async def test_redis_upsert_query(redis_datastore):
    """Upsert ten sample chunks, then query with chunk 0's embedding.

    Expects exactly one result set whose first five hits are chunks 0..4
    in index order.
    """
    dim = 5
    chunks_by_doc = create_document_chunks(10, dim)
    await redis_datastore._upsert(chunks_by_doc)

    query = QueryWithEmbedding(
        query="Lorem ipsum 0",
        top_k=5,
        embedding=create_embedding(0, dim),
    )
    query_results = await redis_datastore._query(queries=[query])

    assert 1 == len(query_results)
    hits = query_results[0].results
    for i in range(5):
        hit = hits[i]
        assert f"Lorem ipsum {i}" == hit.text
        assert f"doc-{i}" == hit.id

0 comments on commit 71650a5

Please sign in to comment.