Support for different vector stores (#56)
* Outline for a minimal FastAPI interface to range of models

* start sketching out API based on existing code / model

* fastapi[standard] in env, dump of catalog data locally

* kludge interface, full of TODOs to the new 3 class model

* Return a class label from the Resnet18 model

* Return both classification and embeddings from Resnet18 model, roughly

* load models as globals when worker starts, see comments

* add a basic API test and a wee bit of error handling

* pyproject dependencies needed for the pipeline

* limit CI tests to run only for code changes

* should have set this a lot sooner!

* Add an abstract interface to vectorstores, breaking the tests

* fix up the existing chromadb tests

* clean up some points of vector store reuse

* extend the test coverage, have the app use new interface

* remember to commit the new config.py for the app

* fix interface in the scripts (we're not using much)

* remove workflow_call from test CI config

* slowly flesh out the sqlite-vec storage option

* fill out the sqlite implementation

* Revert "remove workflow_call from test CI config"

This reverts commit d330422.

* Reapply "remove workflow_call from test CI config"

This reverts commit cfe1c25.

* put `workflow_call` back and try to limit caller's paths

* YAML whitespace glitch?

* test queries in the chroma wrapper, tweak output

* deserialise vectors packed as bytes back to floats, test

* remove the caller pipeline, just complication

* expand the base class, stub interfaces for different stores

* test shared behaviour of different backends, prune print statements

* give the workflow a nudge without paths, regret the change now

* paths need glob expansion, "naturally"

* sqlite3 is in python core!

* whitespace change, nudge the workflow

* limit N syntax is sqlite3 version specific...

* .[lint] still installs everything else - direct from pypi instead

* (brittle) test to check whether we have weights downloaded

* optional embeddings length on init database; explicit commit()

* generalised config options for different backends
metazool authored Dec 11, 2024
1 parent 3c4c09d commit ba7e710
Showing 19 changed files with 459 additions and 125 deletions.
12 changes: 0 additions & 12 deletions .github/workflows/pipeline.yml

This file was deleted.

24 changes: 11 additions & 13 deletions .github/workflows/test_python.yml
@@ -6,15 +6,9 @@ name: Python Install and Tests
on:
push:
paths:
- 'src'
- 'tests'
workflow_call:
inputs:
coverage_threshold:
description: 'Minimum required test coverage percentage'
required: false
type: number
default: 50
- 'src/**'
- 'tests/**'
pull_request:

permissions:
contents: read
@@ -28,15 +22,19 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just the lint dependencies first
run: python -m pip install flake8 isort ruff
- name: Lint with ruff
run: ruff check
- name: Check format with ruff
run: ruff format --check

- name: Install dependencies
run: |
sudo apt install libimage-exiftool-perl -y
python -m pip install --upgrade pip
python -m pip install .[lint,test,pipeline]
- name: Lint with ruff
run: ruff check
- name: Check format with ruff
run: ruff format --check
- name: Test with pytest
run: |
if [ "${{ inputs.coverage_threshold }}" -gt 0 ]; then
38 changes: 38 additions & 0 deletions VECTOR_STORES.md
@@ -0,0 +1,38 @@
# Vector stores

Investigation of alternative vector stores for image model embeddings.

## ChromaDB

* "Simplest useful thing", default in the LangChain examples for LLM rapid prototyping
* Idiosyncratic, not standards-oriented
* Evolving quickly (a couple of backwards-incompatible API changes since we started with it)

## SQLite-vec

* Lightweight, with helpful examples; quick to start with?
* Single process
* "_expect breaking changes!_"

https://til.simonwillison.net/sqlite/sqlite-vec

https://github.com/asg017/sqlite-vec

https://github.com/asg017/sqlite-vec/releases

```
pip install sqlite-utils
sqlite-utils install sqlite-utils-sqlite-vec
```

The main use is in the `streamlit` app, which is _really_ tied to the internal logic of `chromadb` :/

The queries we need are (sketched below):

* get all identifiers (need `LIMIT` for large collection) - URLs were used directly as IDs
* get embeddings vector for one ID
* get N closest results to one set of embeddings by cosine similarity
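
As a minimal sketch (not the project's actual query code), these three might look like the following against the `vec0` schema in `db_config.py`, assuming a local `embeddings.db` file and the `sqlite-vec` Python bindings; using `vec_distance_cosine` sidesteps the version-specific `LIMIT` / `k =` syntax for KNN `MATCH` queries noted in the commit messages:

```
import sqlite3
import struct

import sqlite_vec
from sqlite_vec import serialize_float32

db = sqlite3.connect("embeddings.db")  # hypothetical path
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# 1. All identifiers - URLs double as IDs, LIMIT guards against large collections
urls = [row[0] for row in db.execute("SELECT url FROM embeddings LIMIT 1000")]

# 2. Embeddings vector for one ID - stored packed as float32 bytes, unpack back to floats
packed = db.execute(
    "SELECT embedding FROM embeddings WHERE url = ?", (urls[0],)
).fetchone()[0]
query = list(struct.unpack(f"{len(packed) // 4}f", packed))

# 3. N closest results to one set of embeddings, by cosine distance
closest = db.execute(
    """SELECT url, vec_distance_cosine(embedding, ?) AS distance
       FROM embeddings ORDER BY distance LIMIT 5""",
    (serialize_float32(query),),
).fetchall()
```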




1 change: 1 addition & 0 deletions environment.yml
@@ -25,6 +25,7 @@ dependencies:
- python-dotenv
- scikit-learn
- scikit-image
- sqlite-vec
- sqlalchemy==1.4.54 # see https://github.com/spotify/luigi/issues/3227
- xarray
- pip
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"requests",
"scikit-image",
"scikit-learn",
"sqlite-vec",
"streamlit",
"torch",
"torchvision",
9 changes: 3 additions & 6 deletions scripts/image_embeddings.py
@@ -15,7 +15,6 @@
logging.basicConfig(level=logging.INFO)
load_dotenv()


if __name__ == "__main__":

# Limited to the Lancaster FlowCam dataset for now:
@@ -25,7 +24,7 @@
file_index = f"{os.environ.get('AWS_URL_ENDPOINT')}/{catalog}"
df = pd.read_csv(file_index)

collection = vector_store(image_bucket)
collection = vector_store("sqlite", image_bucket)

model = load_model(strip_final_layer=True)

@@ -51,10 +50,8 @@ def store_embeddings(row):
embeddings = flat_embeddings(model(image_data))

collection.add(
documents=[row.Filename],
embeddings=[embeddings],
ids=[row.Filename], # must be unique
# Note - optional arg name is "metadatas" (we don't have any)
url=row.Filename,
embeddings=embeddings,
)

for _, row in df.iterrows():
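
For orientation, a sketch of the shape of the backend-agnostic interface this commit introduces, inferred from the `vector_store("sqlite", ...)` factory call and `collection.add(url=..., embeddings=...)` above; the class and method names here are illustrative, not the actual API in `cyto_ml.data.vectorstore`:

```
from abc import ABC, abstractmethod
from typing import List, Optional


class VectorStore(ABC):
    """Shared behaviour the chromadb and sqlite-vec backends both implement."""

    @abstractmethod
    def add(self, url: str, embeddings: List[float], classification: Optional[str] = None) -> None:
        """Store one embedding, keyed by image URL (URLs double as IDs)."""

    @abstractmethod
    def get(self, url: str) -> List[float]:
        """Return the embeddings vector for a single identifier."""

    @abstractmethod
    def closest(self, embeddings: List[float], n_results: int = 25) -> List[str]:
        """Return the identifiers of the N nearest stored vectors."""

    @abstractmethod
    def ids(self, limit: int = 1000) -> List[str]:
        """Return stored identifiers, limited for large collections."""
```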
40 changes: 40 additions & 0 deletions scripts/image_embeddings_api.py
@@ -0,0 +1,40 @@
"""Extract and store image embeddings from a collection in s3,
using an API that calls one or more off-the-shelf pre-trained models"""

import os
import logging
import yaml
from dotenv import load_dotenv
from cyto_ml.data.vectorstore import vector_store
import pandas as pd
import requests

logging.basicConfig(level=logging.INFO)
load_dotenv()

ENDPOINT = "http://localhost:8000/resnet18/"
PARAMS = os.path.join(os.path.abspath(os.path.dirname(__file__)), "params.yaml")

if __name__ == "__main__":

# Limited to the Lancaster FlowCam dataset for now:
image_bucket = yaml.safe_load(open(PARAMS))["collection"]
catalog = f"{image_bucket}/catalog.csv"

file_index = f"{os.environ.get('AWS_URL_ENDPOINT')}/{catalog}"
df = pd.read_csv(file_index)

# TODO - optional embedding length param at this point, it's not ideal
collection = vector_store("sqlite", image_bucket, embedding_len=512)

def store_embeddings(url):
response = requests.post(ENDPOINT, data={"url": url}).json()
if "embeddings" not in response:
logging.error(response)
raise ValueError(f"no embeddings in API response for {url}")

response["url"] = url
collection.add(**response)

for _, row in df.iterrows():
store_embeddings(row.item())
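
The script above assumes a running API (see the commit notes on the minimal FastAPI interface and loading models as globals when the worker starts). A rough sketch of such an endpoint follows, using stock ImageNet ResNet18 weights as a stand-in for the project's own model; the `/resnet18/` path and the `classification` / `embeddings` response keys match the script, everything else is an assumption:

```
import io

import requests
import torch
from fastapi import FastAPI, Form
from PIL import Image
from torchvision.models import ResNet18_Weights, resnet18

app = FastAPI()

# Load models once when the worker starts; stripping the final fully-connected
# layer leaves a 512-dimensional embedding (hence embedding_len=512 above).
weights = ResNet18_Weights.DEFAULT
classifier = resnet18(weights=weights).eval()
embedder = torch.nn.Sequential(*list(classifier.children())[:-1]).eval()
preprocess = weights.transforms()


@app.post("/resnet18/")
def resnet18_features(url: str = Form(...)) -> dict:
    # Form(...) matches the requests.post(..., data={"url": url}) call above
    image = Image.open(io.BytesIO(requests.get(url, timeout=30).content)).convert("RGB")
    batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        embeddings = embedder(batch).flatten().tolist()
        label = weights.meta["categories"][int(classifier(batch).argmax())]
    return {"classification": label, "embeddings": embeddings}
```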
2 changes: 1 addition & 1 deletion scripts/image_metadata.py
@@ -16,7 +16,7 @@
s3 = boto3_client()

catalog_csv = metadata.to_csv(index=False)
with open('catalog.csv', 'w') as out:
with open("catalog.csv", "w") as out:
out.write(catalog_csv)

s3.put_object(Bucket=image_bucket, Key="catalog.csv", Body=catalog_csv)
21 changes: 12 additions & 9 deletions scripts/run_luigi_pipeline.py
@@ -2,12 +2,15 @@
from cyto_ml.pipeline.pipeline_decollage import FlowCamPipeline


if __name__ == '__main__':
luigi.build([
FlowCamPipeline(
directory="./tests/fixtures/MicrobialMethane_MESO_Tank10_54.0143_-2.7770_04052023_1",
output_directory="./data/images_decollage",
experiment_name="test_experiment",
s3_bucket="test-upload-alba"
)
], local_scheduler=False)
if __name__ == "__main__":
luigi.build(
[
FlowCamPipeline(
directory="./tests/fixtures/MicrobialMethane_MESO_Tank10_54.0143_-2.7770_04052023_1",
output_directory="./data/images_decollage",
experiment_name="test_experiment",
s3_bucket="test-upload-alba",
)
],
local_scheduler=False,
)
11 changes: 11 additions & 0 deletions src/cyto_ml/data/db_config.py
@@ -0,0 +1,11 @@
# TODO manage this better elsewhere, once we settle on a storage option
SQLITE_SCHEMA = """
create virtual table embeddings using vec0(
id integer primary key,
url text not null,
classification text not null,
embedding float[{}]);
"""

# Options passed as keyword arguments when setting a db connection
OPTIONS = {"sqlite": {"embedding_len": 512, "check_same_thread": False}, "chromadb": {}}
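
A small sketch of how the schema template and per-backend options might be combined when initialising the sqlite-vec store; the database path is hypothetical and the real wiring lives in `cyto_ml.data.vectorstore`:

```
import sqlite3

import sqlite_vec

from cyto_ml.data.db_config import OPTIONS, SQLITE_SCHEMA

opts = OPTIONS["sqlite"]
db = sqlite3.connect("embeddings.db", check_same_thread=opts["check_same_thread"])
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# Fill in the optional embeddings length, then create the vec0 virtual table
db.execute(SQLITE_SCHEMA.format(opts["embedding_len"]))
db.commit()  # the commit history notes the explicit commit()
```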
1 change: 0 additions & 1 deletion src/cyto_ml/data/decollage.py
@@ -103,7 +103,6 @@ def read_metadata(self) -> None:
self.metadata = {}

files = glob.glob(f"{self.directory}/*.lst")
print(files)

if len(files) == 0:
raise FileNotFoundError("no lst file in this directory")