Skip to content

Commit

Permalink
Redis improvements and bug fixes (openai#179)
Browse files: browse the repository at this point in the history
* fixes openai#132

* enumify datastore providers and fix openai#154

* add providers enum

* fixes default value for numerics and fixes openai#51

* add more tests and cleanup

* revert enum approach
  • Loading branch information
tylerhutcherson authored May 9, 2023
1 parent 3ed6df1 commit 919e543
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
5 changes: 4 additions & 1 deletion datastore/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ async def get_datastore() -> DataStore:

return QdrantDataStore()
case _:
raise ValueError(f"Unsupported vector database: {datastore}")
raise ValueError(
f"Unsupported vector database: {datastore}. "
f"Try one of the following: llama, pinecone, weaviate, milvus, zilliz, redis, or qdrant"
)
Empty file added datastore/providers/__init__.py
Empty file.
12 changes: 5 additions & 7 deletions datastore/providers/redis_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def unpack_schema(d: dict):
yield v

async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):

installed_modules = (await client.info()).get("modules", [])
installed_modules = {module["name"]: module for module in installed_modules}
for module in modules:
Expand All @@ -66,14 +65,13 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
raise AttributeError(error_message)



class RedisDataStore(DataStore):
def __init__(self, client: redis.Redis, redisearch_schema):
def __init__(self, client: redis.Redis, redisearch_schema: dict):
self.client = client
self._schema = redisearch_schema
# Init default metadata with sentinel values in case the document written has no metadata
self._default_metadata = {
field: "_null_" for field in redisearch_schema["metadata"]
field: (0 if field == "created_at" else "_null_") for field in redisearch_schema["metadata"]
}

### Redis Helper Methods ###
Expand All @@ -94,11 +92,11 @@ async def init(cls, **kwargs):
raise e

await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)

dim = kwargs.get("dim", VECTOR_DIMENSION)
redisearch_schema = {
"document_id": TagField("$.document_id", as_name="document_id"),
"metadata": {
"document_id": TagField("$.metadata.document_id", as_name="document_id"),
"source_id": TagField("$.metadata.source_id", as_name="source_id"),
"source": TagField("$.metadata.source", as_name="source"),
"author": TextField("$.metadata.author", as_name="author"),
Expand Down Expand Up @@ -300,7 +298,7 @@ async def _query(
results: List[QueryResult] = []

# Gather query results in a pipeline
logging.info(f"Gathering {len(queries)} query results", flush=True)
logging.info(f"Gathering {len(queries)} query results")
for query in queries:

logging.info(f"Query: {query.query}")
Expand Down
31 changes: 25 additions & 6 deletions tests/datastore/providers/redis/test_redis_datastore.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from datastore.providers.redis_datastore import RedisDataStore
import datastore.providers.redis_datastore as static_redis
from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source
from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source, DocumentMetadataFilter
import pytest
import redis.asyncio as redis
import numpy as np

NUM_TEST_DOCS = 10

@pytest.fixture
async def redis_datastore():
    """Provide a RedisDataStore built via ``RedisDataStore.init`` with ``dim=5``."""
    datastore = await RedisDataStore.init(dim=5)
    return datastore


def create_embedding(i, dim):
vec = np.array([0.1] * dim).astype(np.float64).tolist()
vec[dim-1] = i+1/10
Expand All @@ -21,7 +21,7 @@ def create_document_chunk(i, dim):
text=f"Lorem ipsum {i}",
embedding=create_embedding(i, dim),
metadata=DocumentChunkMetadata(
source=Source.file, created_at="1970-01-01", document_id=f"doc-{i}"
source=Source.file, created_at="1970-01-01", document_id="docs"
),
)

Expand All @@ -31,7 +31,7 @@ def create_document_chunks(n, dim):

@pytest.mark.asyncio
async def test_redis_upsert_query(redis_datastore):
docs = create_document_chunks(10, 5)
docs = create_document_chunks(NUM_TEST_DOCS, 5)
await redis_datastore._upsert(docs)
query = QueryWithEmbedding(
query="Lorem ipsum 0",
Expand All @@ -42,4 +42,23 @@ async def test_redis_upsert_query(redis_datastore):
assert 1 == len(query_results)
for i in range(5):
assert f"Lorem ipsum {i}" == query_results[0].results[i].text
assert f"doc-{i}" == query_results[0].results[i].id
assert "docs" == query_results[0].results[i].id

@pytest.mark.asyncio
async def test_redis_filter_query(redis_datastore):
    """Query with a ``document_id`` metadata filter and check the top hit's id.

    NOTE(review): nothing is upserted in this test, so it presumably relies on
    documents written by an earlier test in this module -- confirm that the
    test ordering / shared Redis state is intentional.
    """
    query = QueryWithEmbedding(
        query="Lorem ipsum 0",
        filter=DocumentMetadataFilter(document_id="docs"),
        top_k=5,
        embedding=create_embedding(0, 5),
    )
    query_results = await redis_datastore._query(queries=[query])
    # A single query was submitted, so a single result entry is expected.
    assert 1 == len(query_results)
    # Fail with a clear assertion (not an IndexError) if the filter matched nothing.
    assert query_results[0].results, "filtered query returned no results"
    assert "docs" == query_results[0].results[0].id


@pytest.mark.asyncio
async def test_redis_delete_docs(redis_datastore):
    """Deleting by the id ``"docs"`` must return a truthy result."""
    assert await redis_datastore.delete(ids=["docs"])

0 comments on commit 919e543

Please sign in to comment.