Skip to content

Commit

Permalink
simplify the embeddings script. save sqlite-vec db
Browse files Browse the repository at this point in the history
  • Loading branch information
metazool committed Feb 13, 2025
1 parent 6fa9017 commit 91f255f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 46 deletions.
33 changes: 21 additions & 12 deletions scripts/image_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,42 @@

import os
import logging
from tqdm import tqdm
import yaml
from dotenv import load_dotenv
from cyto_ml.models.utils import flat_embeddings
from cyto_ml.models.utils import flat_embeddings, resnet18
from cyto_ml.data.image import load_image_from_url

from resnet50_cefas import load_model
from cyto_ml.data.vectorstore import vector_store
import pandas as pd

# BUG FIX: the original passed logging.info (the module-level *function*),
# not logging.INFO (the int level constant). logging.basicConfig raises
# "TypeError: Level not an integer or a valid string" when given a callable.
logging.basicConfig(level=logging.INFO)
load_dotenv()

# Path to the manually-downloaded 3-class ResNet-18 weights
# (see https://github.com/alan-turing-institute/ViT-LASNet/issues/2).
# PEP 8: constants are UPPER_SNAKE_CASE with spaces around "=".
STATE_FILE = "../models/ResNet_18_3classes_RGB.pth"

if __name__ == "__main__":

# Limited to the Lancaster FlowCam dataset for now:
image_bucket = yaml.safe_load(open("params.yaml"))["collection"]
catalog = f"{image_bucket}/catalog.csv"
file_index = f"{image_bucket}.csv"

file_index = f"{os.environ.get('AWS_URL_ENDPOINT')}/{catalog}"
# We have a static file index, written by image_index.py
df = pd.read_csv(file_index)

collection = vector_store("sqlite", image_bucket)
# Keep one SQLite db per collection. Plan to sync these to S3 with DVC
db_dir = '../data'
if not os.path.exists(db_dir):
os.mkdir(db_dir)
collection = vector_store("sqlite", f"{db_dir}/{image_bucket}.db")

model = load_model(strip_final_layer=True)
# The Turing Institute's lightweight 3-class model needs to be downloaded manually.
# Please see https://github.com/alan-turing-institute/ViT-LASNet/issues/2
model = resnet18(num_classes=3, filename=STATE_FILE, strip_final_layer=True)

def store_embeddings(row):
def store_embeddings(url):
try:
image_data = load_image_from_url(row.Filename)
image_data = load_image_from_url(url)
except ValueError as err:
# TODO diagnose and fix for this happening, in rare circumstances:
# (would be nice to know rather than just buffer the image and add code)
Expand All @@ -40,7 +48,7 @@ def store_embeddings(row):
# raise ValueError("Cannot seek streaming HTTP file")
# Is this still reproducible? - JW
logging.info(err)
logging.info(row.Filename)
logging.info(url)
return
except OSError as err:
logging.info(err)
Expand All @@ -50,9 +58,10 @@ def store_embeddings(row):
embeddings = flat_embeddings(model(image_data))

collection.add(
url=row.Filename,
url=url,
embeddings=embeddings,
)

for _, row in df.iterrows():
store_embeddings(row)
for _, row in tqdm(df.iterrows()):
image_url = f"{os.environ['AWS_URL_ENDPOINT']}/{image_bucket}/{row[0]}"
store_embeddings(image_url)
32 changes: 0 additions & 32 deletions scripts/image_metadata.py

This file was deleted.

2 changes: 1 addition & 1 deletion scripts/params.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cluster:
n_clusters: 5

collection: untagged-images-lana
collection: untagged-images-wala
2 changes: 1 addition & 1 deletion src/cyto_ml/data/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def ids(self) -> List[str]:


class SQLiteVecStore(VectorStore):
def __init__(self, db_name: str, embedding_len: Optional[int] = 2048, check_same_thread: bool = True):
def __init__(self, db_name: str, embedding_len: Optional[int] = 512, check_same_thread: bool = True):
self._check_same_thread = check_same_thread
self.embedding_len = embedding_len
self.load_ext(db_name)
Expand Down

0 comments on commit 91f255f

Please sign in to comment.