diff --git a/README.md b/README.md index 0929e0f62..1bf02f6f3 100644 --- a/README.md +++ b/README.md @@ -88,16 +88,9 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin: export PINECONE_INDEX= # Weaviate - export WEAVIATE_HOST= - export WEAVIATE_PORT= + export WEAVIATE_URL= + export WEAVIATE_API_KEY= export WEAVIATE_CLASS= - export WEAVIATE_USERNAME= - export WEAVIATE_PASSWORD= - export WEAVIATE_SCOPES= - export WEAVIATE_BATCH_SIZE= - export WEAVIATE_BATCH_DYNAMIC= - export WEAVIATE_BATCH_TIMEOUT_RETRIES= - export WEAVIATE_BATCH_NUM_WORKERS= # Zilliz export ZILLIZ_COLLECTION= diff --git a/datastore/providers/weaviate_datastore.py b/datastore/providers/weaviate_datastore.py index 136eb614a..9202835e5 100644 --- a/datastore/providers/weaviate_datastore.py +++ b/datastore/providers/weaviate_datastore.py @@ -1,31 +1,26 @@ -# TODO import asyncio -from typing import Dict, List, Optional -from loguru import logger -from weaviate import Client -import weaviate import os +import re import uuid +from typing import Dict, List, Optional +import weaviate +from loguru import logger +from weaviate import Client from weaviate.util import generate_uuid5 from datastore.datastore import DataStore from models.models import ( DocumentChunk, DocumentChunkMetadata, + DocumentChunkWithScore, DocumentMetadataFilter, QueryResult, QueryWithEmbedding, - DocumentChunkWithScore, Source, ) - -WEAVIATE_HOST = os.environ.get("WEAVIATE_HOST", "http://127.0.0.1") -WEAVIATE_PORT = os.environ.get("WEAVIATE_PORT", "8080") -WEAVIATE_USERNAME = os.environ.get("WEAVIATE_USERNAME", None) -WEAVIATE_PASSWORD = os.environ.get("WEAVIATE_PASSWORD", None) -WEAVIATE_SCOPES = os.environ.get("WEAVIATE_SCOPES", "offline_access") +WEAVIATE_URL_DEFAULT = "http://localhost:8080" WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "OpenAIDocument") WEAVIATE_BATCH_SIZE = int(os.environ.get("WEAVIATE_BATCH_SIZE", 20)) @@ -109,7 +104,7 @@ def handle_errors(self, results: Optional[List[dict]]) -> 
List[str]: def __init__(self): auth_credentials = self._build_auth_credentials() - url = f"{WEAVIATE_HOST}:{WEAVIATE_PORT}" + url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT) logger.debug( f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}" @@ -140,10 +135,14 @@ def __init__(self): @staticmethod def _build_auth_credentials(): - if WEAVIATE_USERNAME and WEAVIATE_PASSWORD: - return weaviate.auth.AuthClientPassword( - WEAVIATE_USERNAME, WEAVIATE_PASSWORD, WEAVIATE_SCOPES - ) + url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT) + + if WeaviateDataStore._is_wcs_domain(url): + api_key = os.environ.get("WEAVIATE_API_KEY") + if api_key is not None: + return weaviate.auth.AuthApiKey(api_key=api_key) + else: + raise ValueError("WEAVIATE_API_KEY environment variable is not set") else: return None @@ -370,3 +369,17 @@ def _is_valid_weaviate_id(candidate_id: str) -> bool: return True except ValueError: return False + + @staticmethod + def _is_wcs_domain(url: str) -> bool: + """ + Check if the given URL ends with ".weaviate.cloud" or ".weaviate.network", optionally followed by a single trailing "/". + + Args: + url (str): The URL to check. + + Returns: + bool: True if the URL ends with one of those domains (with or without the trailing "/"), False otherwise. + """ + pattern = r"\.(weaviate\.cloud|weaviate\.network)(/)?$" + return bool(re.search(pattern, url)) diff --git a/docs/providers/weaviate/setup.md b/docs/providers/weaviate/setup.md index f01611277..31916329e 100644 --- a/docs/providers/weaviate/setup.md +++ b/docs/providers/weaviate/setup.md @@ -65,39 +65,15 @@ You need to set some environment variables to connect to your Weaviate instance.
| Name | Required | Description | Default | |------------------| -------- | ------------------------------------------------------------------ | ------------------ | -| `WEAVIATE_HOST` | Optional | Your Weaviate instance host address (see notes below) | `http://127.0.0.1` | -| `WEAVIATE_PORT` | Optional | Your Weaviate port number (use 443 for WCS) | 8080 | +| `WEAVIATE_URL` | Optional | Your Weaviate instance's URL/WCS endpoint | `http://localhost:8080` | | `WEAVIATE_CLASS` | Optional | Your chosen Weaviate class/collection name to store your documents | OpenAIDocument | -> For **WCS instances**, set `WEAVIATE_PORT` to 443 and `WEAVIATE_HOST` to `https://(wcs-instance-name).weaviate.network`. For example: `https://my-project.weaviate.network/`. - -> For **self-hosted instances**, if your instance is not at 127.0.0.1:8080, set `WEAVIATE_HOST` and `WEAVIATE_PORT` accordingly. For example: `WEAVIATE_HOST=http://localhost/` and `WEAVIATE_PORT=4040`. - **Weaviate Auth Environment Variables** -If you enabled OIDC authentication for your Weaviate instance (recommended for WCS instances), set the following environment variables. If you enabled anonymous access, skip this section. +If using WCS instances, set the following environment variables: | Name | Required | Description | | ------------------- | -------- | ------------------------------ | -| `WEAVIATE_USERNAME` | Yes | Your OIDC or WCS username | -| `WEAVIATE_PASSWORD` | Yes | Your OIDC or WCS password | -| `WEAVIATE_SCOPES` | Optional | Space-separated list of scopes | - -Learn more about [authentication in Weaviate](https://weaviate.io/developers/weaviate/configuration/authentication#overview) and the [Python client authentication](https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html). - -**Weaviate Batch Import Environment Variables** - -Weaviate uses a batching mechanism to perform operations in bulk. This makes importing and updating your data faster and more efficient.
You can adjust the batch settings with these optional environment variables: - -| Name | Required | Description | Default | -| -------------------------------- | -------- | ------------------------------------------------------------ | ------- | -| `WEAVIATE_BATCH_SIZE` | Optional | Number of insert/updates per batch operation | 20 | -| `WEAVIATE_BATCH_DYNAMIC` | Optional | Lets the batch process decide the batch size | False | -| `WEAVIATE_BATCH_TIMEOUT_RETRIES` | Optional | Number of retry-on-timeout attempts | 3 | -| `WEAVIATE_BATCH_NUM_WORKERS` | Optional | The max number of concurrent threads to run batch operations | 1 | - -> **Note:** The optimal `WEAVIATE_BATCH_SIZE` depends on the available resources (RAM, CPU). A higher value means faster bulk operations, but also higher demand for RAM and CPU. If you experience failures during the import process, reduce the batch size. - -> Setting `WEAVIATE_BATCH_SIZE` to `None` means no limit to the batch size. All insert or update operations would be sent to Weaviate in a single operation. This might be risky, as you lose control over the batch size. +| `WEAVIATE_API_KEY` | Yes | Your WCS API key | -Learn more about [batch configuration in Weaviate](https://weaviate.io/developers/weaviate/client-libraries/python#batch-configuration). +Learn more about accessing your [WCS API key](https://weaviate.io/developers/wcs/guides/authentication#access-api-keys).
\ No newline at end of file diff --git a/tests/datastore/providers/weaviate/test_weaviate_datastore.py b/tests/datastore/providers/weaviate/test_weaviate_datastore.py index 66c1db148..300eb4849 100644 --- a/tests/datastore/providers/weaviate/test_weaviate_datastore.py +++ b/tests/datastore/providers/weaviate/test_weaviate_datastore.py @@ -1,18 +1,20 @@ +import logging +import os + import pytest +import weaviate +from _pytest.logging import LogCaptureFixture from fastapi.testclient import TestClient +from loguru import logger from weaviate import Client -import weaviate -import os -from models.models import DocumentMetadataFilter, Source -from server.main import app + from datastore.providers.weaviate_datastore import ( SCHEMA, WeaviateDataStore, extract_schema_properties, ) -import logging -from loguru import logger -from _pytest.logging import LogCaptureFixture +from models.models import DocumentMetadataFilter, Source +from server.main import app BEARER_TOKEN = os.getenv("BEARER_TOKEN") @@ -99,30 +101,6 @@ def documents(): yield documents -@pytest.fixture -def mock_env_public_access(monkeypatch): - monkeypatch.setattr( - "datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", None - ) - monkeypatch.setattr( - "datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", None - ) - - -@pytest.fixture -def mock_env_resource_owner_password_flow(monkeypatch): - monkeypatch.setattr( - "datastore.providers.weaviate_datastore.WEAVIATE_SCOPES", - ["schema:read", "schema:write"], - ) - monkeypatch.setattr( - "datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", "admin" - ) - monkeypatch.setattr( - "datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", "abc123" - ) - - @pytest.fixture def caplog(caplog: LogCaptureFixture): handler_id = logger.add(caplog.handler, format="{message}") @@ -337,16 +315,38 @@ def test_delete(test_db, weaviate_client, caplog): assert not weaviate_client.data_object.get()["objects"] -def 
test_access_with_username_password(mock_env_resource_owner_password_flow): - auth_credentials = WeaviateDataStore._build_auth_credentials() - - assert isinstance(auth_credentials, weaviate.auth.AuthClientPassword) - - -def test_public_access(mock_env_public_access): - auth_credentials = WeaviateDataStore._build_auth_credentials() - - assert auth_credentials is None +def test_build_auth_credentials(monkeypatch): + # Test when WEAVIATE_URL ends with weaviate.network and WEAVIATE_API_KEY is set + with monkeypatch.context() as m: + m.setenv("WEAVIATE_URL", "https://example.weaviate.network") + m.setenv("WEAVIATE_API_KEY", "your_api_key") + auth_credentials = WeaviateDataStore._build_auth_credentials() + assert auth_credentials is not None + assert isinstance(auth_credentials, weaviate.auth.AuthApiKey) + assert auth_credentials.api_key == "your_api_key" + + # Test when WEAVIATE_URL ends with weaviate.network and WEAVIATE_API_KEY is not set + with monkeypatch.context() as m: + m.setenv("WEAVIATE_URL", "https://example.weaviate.network") + m.delenv("WEAVIATE_API_KEY", raising=False) + with pytest.raises( + ValueError, match="WEAVIATE_API_KEY environment variable is not set" + ): + WeaviateDataStore._build_auth_credentials() + + # Test when WEAVIATE_URL does not end with weaviate.network + with monkeypatch.context() as m: + m.setenv("WEAVIATE_URL", "https://example.notweaviate.network") + m.setenv("WEAVIATE_API_KEY", "your_api_key") + auth_credentials = WeaviateDataStore._build_auth_credentials() + assert auth_credentials is None + + # Test when WEAVIATE_URL is not set + with monkeypatch.context() as m: + m.delenv("WEAVIATE_URL", raising=False) + m.setenv("WEAVIATE_API_KEY", "your_api_key") + auth_credentials = WeaviateDataStore._build_auth_credentials() + assert auth_credentials is None def test_extract_schema_properties(): @@ -519,3 +519,20 @@ def build_upsert_payload(document): # but it is None right now because an # update function is out of scope assert 
weaviate_doc[0]["source"] is None + + +@pytest.mark.parametrize( + "url, expected_result", + [ + ("https://example.weaviate.network", True), + ("https://example.weaviate.network/", True), + ("https://example.weaviate.cloud", True), + ("https://example.weaviate.cloud/", True), + ("https://example.notweaviate.network", False), + ("https://weaviate.network.example.com", False), + ("https://example.weaviate.network/somepage", False), + ("", False), + ], +) +def test_is_wcs_domain(url, expected_result): + assert WeaviateDataStore._is_wcs_domain(url) == expected_result