image_embeddings.py
"""Extract and store image embeddings from a collection in s3,
using an off-the-shelf pre-trained model"""
import os
import logging
from tqdm import tqdm
import yaml
from dotenv import load_dotenv
from cyto_ml.models.utils import flat_embeddings, resnet18
from cyto_ml.data.image import load_image_from_url
from cyto_ml.data.vectorstore import vector_store
import pandas as pd
logging.basicConfig(level=logging.info)
load_dotenv()
STATE_FILE='../models/ResNet_18_3classes_RGB.pth'
if __name__ == "__main__":
    # Limited to the Lancaster FlowCam dataset for now:
    image_bucket = yaml.safe_load(open("params.yaml"))["collection"]
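    # params.yaml is expected to provide the collection (s3 bucket) name, e.g.
    # (hypothetical value, substitute your own collection):
    #
    #   collection: flowcam-lancaster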
    file_index = f"{image_bucket}.csv"

    # We have a static file index, written by image_index.py
    df = pd.read_csv(file_index)

    # Keep a sqlite db per collection. Plan to sync these to s3 with DVC.
    db_dir = "../data"
    if not os.path.exists(db_dir):
        os.mkdir(db_dir)

    collection = vector_store("sqlite", f"{db_dir}/{image_bucket}.db")

    # The Turing Institute's lightweight 3-class model needs to be downloaded manually.
    # Please see https://github.com/alan-turing-institute/ViT-LASNet/issues/2
    model = resnet18(num_classes=3, filename=STATE_FILE, strip_final_layer=True)
    def store_embeddings(url):
        try:
            image_data = load_image_from_url(url)
        except ValueError as err:
            # TODO diagnose and fix this happening, in rare circumstances:
            # (would be nice to know why, rather than just buffer the image and add code)
            #   File "python3.9/site-packages/PIL/PcdImagePlugin.py", line 34, in _open
            #     self.fp.seek(2048)
            #   File "python3.9/site-packages/fsspec/implementations/http.py", line 745, in seek
            #     raise ValueError("Cannot seek streaming HTTP file")
            # Is this still reproducible? - JW
            logging.info(err)
            logging.info(url)
            return
        except OSError as err:
            logging.info(err)
            # Log the url rather than relying on the loop variable `row`,
            # which is not defined inside this function.
            logging.info(url)
            return

        embeddings = flat_embeddings(model(image_data))
        collection.add(
            url=url,
            embeddings=embeddings,
        )
    for _, row in tqdm(df.iterrows()):
        image_url = f"{os.environ['AWS_URL_ENDPOINT']}/{image_bucket}/{row[0]}"
        store_embeddings(image_url)
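
A minimal sketch of sanity-checking the output, assuming only that vector_store writes an ordinary sqlite file (the table names and path below are placeholders; the actual schema is whatever cyto_ml creates):

import sqlite3

# Hypothetical path matching the script above; substitute your collection name.
db_path = "../data/flowcam-lancaster.db"

con = sqlite3.connect(db_path)
# Enumerate the tables and their row counts without assuming a schema.
tables = [r[0] for r in con.execute("SELECT name FROM sqlite_master WHERE type='table'")]
for name in tables:
    count = con.execute(f'SELECT COUNT(*) FROM "{name}"').fetchone()[0]
    print(f"{name}: {count} rows")
con.close()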