Skip to content

Commit

Permalink
Simplify weaviate auth (openai#223)
Browse files Browse the repository at this point in the history
* add function to check host

Signed-off-by: hsm207 <[email protected]>

* fix bug in test

Signed-off-by: hsm207 <[email protected]>

* simplify auth build logic

Signed-off-by: hsm207 <[email protected]>

* remove useless env vars

Signed-off-by: hsm207 <[email protected]>

* remove useless env var

Signed-off-by: hsm207 <[email protected]>

* add default url

Signed-off-by: hsm207 <[email protected]>

* remove todo

Signed-off-by: hsm207 <[email protected]>

* code cleanup

Signed-off-by: hsm207 <[email protected]>

* code cleanup

Signed-off-by: hsm207 <[email protected]>

* Update README

Signed-off-by: hsm207 <[email protected]>

* Remove batch config env vars

* fix regex to also check for WCS enterprise cluster

Signed-off-by: hsm207 <[email protected]>

---------

Signed-off-by: hsm207 <[email protected]>
  • Loading branch information
hsm207 authored May 5, 2023
1 parent 0ebb015 commit d2e0298
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 95 deletions.
11 changes: 2 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,9 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
export PINECONE_INDEX=<your_pinecone_index>
# Weaviate
export WEAVIATE_HOST=<your_weaviate_instance_url>
export WEAVIATE_PORT=<your_weaviate_port_443_for_WCS>
export WEAVIATE_URL=<your_weaviate_instance_url>
export WEAVIATE_API_KEY=<your_api_key_for_WCS>
export WEAVIATE_CLASS=<your_optional_weaviate_class>
export WEAVIATE_USERNAME=<your_weaviate_WCS_username>
export WEAVIATE_PASSWORD=<your_weaviate_WCS_password>
export WEAVIATE_SCOPES=<your_optional_weaviate_scopes>
export WEAVIATE_BATCH_SIZE=<optional_weaviate_batch_size>
export WEAVIATE_BATCH_DYNAMIC=<optional_weaviate_batch_dynamic>
export WEAVIATE_BATCH_TIMEOUT_RETRIES=<optional_weaviate_batch_timeout_retries>
export WEAVIATE_BATCH_NUM_WORKERS=<optional_weaviate_batch_num_workers>
# Zilliz
export ZILLIZ_COLLECTION=<your_zilliz_collection>
Expand Down
47 changes: 30 additions & 17 deletions datastore/providers/weaviate_datastore.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
# TODO
import asyncio
from typing import Dict, List, Optional
from loguru import logger
from weaviate import Client
import weaviate
import os
import re
import uuid
from typing import Dict, List, Optional

import weaviate
from loguru import logger
from weaviate import Client
from weaviate.util import generate_uuid5

from datastore.datastore import DataStore
from models.models import (
DocumentChunk,
DocumentChunkMetadata,
DocumentChunkWithScore,
DocumentMetadataFilter,
QueryResult,
QueryWithEmbedding,
DocumentChunkWithScore,
Source,
)


WEAVIATE_HOST = os.environ.get("WEAVIATE_HOST", "http://127.0.0.1")
WEAVIATE_PORT = os.environ.get("WEAVIATE_PORT", "8080")
WEAVIATE_USERNAME = os.environ.get("WEAVIATE_USERNAME", None)
WEAVIATE_PASSWORD = os.environ.get("WEAVIATE_PASSWORD", None)
WEAVIATE_SCOPES = os.environ.get("WEAVIATE_SCOPES", "offline_access")
WEAVIATE_URL_DEFAULT = "http://localhost:8080"
WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "OpenAIDocument")

WEAVIATE_BATCH_SIZE = int(os.environ.get("WEAVIATE_BATCH_SIZE", 20))
Expand Down Expand Up @@ -109,7 +104,7 @@ def handle_errors(self, results: Optional[List[dict]]) -> List[str]:
def __init__(self):
auth_credentials = self._build_auth_credentials()

url = f"{WEAVIATE_HOST}:{WEAVIATE_PORT}"
url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT)

logger.debug(
f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}"
Expand Down Expand Up @@ -140,10 +135,14 @@ def __init__(self):

@staticmethod
def _build_auth_credentials():
if WEAVIATE_USERNAME and WEAVIATE_PASSWORD:
return weaviate.auth.AuthClientPassword(
WEAVIATE_USERNAME, WEAVIATE_PASSWORD, WEAVIATE_SCOPES
)
url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT)

if WeaviateDataStore._is_wcs_domain(url):
api_key = os.environ.get("WEAVIATE_API_KEY")
if api_key is not None:
return weaviate.auth.AuthApiKey(api_key=api_key)
else:
raise ValueError("WEAVIATE_API_KEY environment variable is not set")
else:
return None

Expand Down Expand Up @@ -370,3 +369,17 @@ def _is_valid_weaviate_id(candidate_id: str) -> bool:
return True
except ValueError:
return False

@staticmethod
def _is_wcs_domain(url: str) -> bool:
    """
    Check whether the given URL points at a Weaviate Cloud Services (WCS) domain.

    A URL qualifies when it ends with ".weaviate.network" or ".weaviate.cloud",
    optionally followed by a single trailing "/". Any further path (for
    example ".weaviate.network/somepage") does not qualify.

    The comparison is case-insensitive because the host part of a URL is
    case-insensitive, so "https://Example.Weaviate.Network" is recognised too.

    Args:
        url (str): The URL to check.

    Returns:
        bool: True if the URL ends with one of the WCS domains, False otherwise.
    """
    pattern = r"\.(weaviate\.cloud|weaviate\.network)(/)?$"
    # IGNORECASE: DNS host names compare without regard to case.
    return bool(re.search(pattern, url, re.IGNORECASE))
32 changes: 4 additions & 28 deletions docs/providers/weaviate/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,39 +65,15 @@ You need to set some environment variables to connect to your Weaviate instance.
| Name | Required | Description | Default |
|------------------| -------- | ------------------------------------------------------------------ | ------------------ |
| `WEAVIATE_HOST` | Optional | Your Weaviate instance host address (see notes below) | `http://127.0.0.1` |
| `WEAVIATE_PORT` | Optional | Your Weaviate port number (use 443 for WCS) | 8080 |
| `WEAVIATE_URL` | Optional | Your Weaviate instance's URL/WCS endpoint | `http://localhost:8080` |
| `WEAVIATE_CLASS` | Optional | Your chosen Weaviate class/collection name to store your documents | OpenAIDocument |
> For **WCS instances**, set `WEAVIATE_PORT` to 443 and `WEAVIATE_HOST` to `https://(wcs-instance-name).weaviate.network`. For example: `https://my-project.weaviate.network/`.
> For **self-hosted instances**, if your instance is not at 127.0.0.1:8080, set `WEAVIATE_HOST` and `WEAVIATE_PORT` accordingly. For example: `WEAVIATE_HOST=http://localhost/` and `WEAVIATE_PORT=4040`.
**Weaviate Auth Environment Variables**
If you enabled OIDC authentication for your Weaviate instance (recommended for WCS instances), set the following environment variables. If you enabled anonymous access, skip this section.
If using WCS instances, set the following environment variables:
| Name | Required | Description |
| ------------------- | -------- | ------------------------------ |
| `WEAVIATE_USERNAME` | Yes | Your OIDC or WCS username |
| `WEAVIATE_PASSWORD` | Yes | Your OIDC or WCS password |
| `WEAVIATE_SCOPES` | Optional | Space-separated list of scopes |
Learn more about [authentication in Weaviate](https://weaviate.io/developers/weaviate/configuration/authentication#overview) and the [Python client authentication](https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html).
**Weaviate Batch Import Environment Variables**
Weaviate uses a batching mechanism to perform operations in bulk. This makes importing and updating your data faster and more efficient. You can adjust the batch settings with these optional environment variables:
| Name | Required | Description | Default |
| -------------------------------- | -------- | ------------------------------------------------------------ | ------- |
| `WEAVIATE_BATCH_SIZE` | Optional | Number of insert/updates per batch operation | 20 |
| `WEAVIATE_BATCH_DYNAMIC` | Optional | Lets the batch process decide the batch size | False |
| `WEAVIATE_BATCH_TIMEOUT_RETRIES` | Optional | Number of retry-on-timeout attempts | 3 |
| `WEAVIATE_BATCH_NUM_WORKERS` | Optional | The max number of concurrent threads to run batch operations | 1 |
> **Note:** The optimal `WEAVIATE_BATCH_SIZE` depends on the available resources (RAM, CPU). A higher value means faster bulk operations, but also higher demand for RAM and CPU. If you experience failures during the import process, reduce the batch size.
> Setting `WEAVIATE_BATCH_SIZE` to `None` means no limit to the batch size. All insert or update operations would be sent to Weaviate in a single operation. This might be risky, as you lose control over the batch size.
| `WEAVIATE_API_KEY` | Yes | Your WCS API key |
Learn more about [batch configuration in Weaviate](https://weaviate.io/developers/weaviate/client-libraries/python#batch-configuration).
Learn more about accessing your [WCS API key](https://weaviate.io/developers/wcs/guides/authentication#access-api-keys).
99 changes: 58 additions & 41 deletions tests/datastore/providers/weaviate/test_weaviate_datastore.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import logging
import os

import pytest
import weaviate
from _pytest.logging import LogCaptureFixture
from fastapi.testclient import TestClient
from loguru import logger
from weaviate import Client
import weaviate
import os
from models.models import DocumentMetadataFilter, Source
from server.main import app

from datastore.providers.weaviate_datastore import (
SCHEMA,
WeaviateDataStore,
extract_schema_properties,
)
import logging
from loguru import logger
from _pytest.logging import LogCaptureFixture
from models.models import DocumentMetadataFilter, Source
from server.main import app

BEARER_TOKEN = os.getenv("BEARER_TOKEN")

Expand Down Expand Up @@ -99,30 +101,6 @@ def documents():
yield documents


@pytest.fixture
def mock_env_public_access(monkeypatch):
monkeypatch.setattr(
"datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", None
)
monkeypatch.setattr(
"datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", None
)


@pytest.fixture
def mock_env_resource_owner_password_flow(monkeypatch):
monkeypatch.setattr(
"datastore.providers.weaviate_datastore.WEAVIATE_SCOPES",
["schema:read", "schema:write"],
)
monkeypatch.setattr(
"datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", "admin"
)
monkeypatch.setattr(
"datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", "abc123"
)


@pytest.fixture
def caplog(caplog: LogCaptureFixture):
handler_id = logger.add(caplog.handler, format="{message}")
Expand Down Expand Up @@ -337,16 +315,38 @@ def test_delete(test_db, weaviate_client, caplog):
assert not weaviate_client.data_object.get()["objects"]


def test_access_with_username_password(mock_env_resource_owner_password_flow):
auth_credentials = WeaviateDataStore._build_auth_credentials()

assert isinstance(auth_credentials, weaviate.auth.AuthClientPassword)


def test_public_access(mock_env_public_access):
auth_credentials = WeaviateDataStore._build_auth_credentials()

assert auth_credentials is None
def test_build_auth_credentials(monkeypatch):
    """Cover each branch of ``WeaviateDataStore._build_auth_credentials``.

    Every scenario runs inside its own ``monkeypatch.context()`` so the
    environment changes are undone before the next scenario starts.
    """
    wcs_url = "https://example.weaviate.network"
    non_wcs_url = "https://example.notweaviate.network"
    api_key = "your_api_key"

    # WCS URL with an API key: API-key credentials carrying that key.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", wcs_url)
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is not None
        assert isinstance(credentials, weaviate.auth.AuthApiKey)
        assert credentials.api_key == api_key

    # WCS URL without an API key: building credentials must fail loudly.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", wcs_url)
        env.delenv("WEAVIATE_API_KEY", raising=False)
        expected_error = pytest.raises(
            ValueError, match="WEAVIATE_API_KEY environment variable is not set"
        )
        with expected_error:
            WeaviateDataStore._build_auth_credentials()

    # URL outside the WCS domains: no credentials, even though a key is set.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", non_wcs_url)
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is None

    # WEAVIATE_URL unset: the default URL is used and no credentials are built.
    with monkeypatch.context() as env:
        env.delenv("WEAVIATE_URL", raising=False)
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is None


def test_extract_schema_properties():
Expand Down Expand Up @@ -519,3 +519,20 @@ def build_upsert_payload(document):
# but it is None right now because an
# update function is out of scope
assert weaviate_doc[0]["source"] is None


@pytest.mark.parametrize(
    "url, expected_result",
    [
        # WCS domains, with and without a single trailing slash.
        ("https://example.weaviate.network", True),
        ("https://example.weaviate.network/", True),
        ("https://example.weaviate.cloud", True),
        ("https://example.weaviate.cloud/", True),
        # Look-alike hosts and WCS names that are not the final suffix.
        ("https://example.notweaviate.network", False),
        ("https://weaviate.network.example.com", False),
        # A path beyond the trailing slash, and the empty string.
        ("https://example.weaviate.network/somepage", False),
        ("", False),
    ],
)
def test_is_wcs_domain(url, expected_result):
    """Check ``WeaviateDataStore._is_wcs_domain`` against known URLs."""
    actual_result = WeaviateDataStore._is_wcs_domain(url)
    assert actual_result == expected_result

0 comments on commit d2e0298

Please sign in to comment.