Skip to content

Commit e71de51

Browse files
committed
add faiss
1 parent e0aa645 commit e71de51

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

mmda/utils/mstvtt_ds_class.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from multiprocessing import Pool
66
from pathlib import Path
77

8+
import faiss
89
import numpy as np
910
from omegaconf import DictConfig
1011
from tqdm import tqdm
@@ -80,9 +81,9 @@ def __init__(self, cfg: DictConfig) -> None:
8081
self.text2audio = "clap"
8182

8283
self.shape = (1, 2) # shape of the similarity matrix
83-
self.cali_size = 3_800
84+
self.cali_size = 800
8485
self.train_size = 53_000 # TODO: no training data is needed for MSRVTT
85-
self.test_size = 3_000
86+
self.test_size = 6_000
8687
self.query_step = 5
8788
self.img2txt_encoder = self.cfg_dataset.img_encoder
8889
self.audio2txt_encoder = self.cfg_dataset.audio_encoder
@@ -202,6 +203,12 @@ def preprocess_retrieval_data(self) -> None:
202203

203204
# check the length of the reference order
204205
assert len(self.ref_id_order) == self.audio2txt_emb["test"].shape[0]
206+
# build the faiss index for the test set
207+
red_video_ids = np.array(self.ref_id_order, dtype="int64") # Faiss requires int64 for IDs
208+
self.audio2txt_faiss = faiss.IndexFlatIP(self.audio2txt_emb["test"].shape[1])
209+
self.audio2txt_faiss.add_with_ids(self.audio2txt_emb["test"], red_video_ids)
210+
self.img2txt_faiss = faiss.IndexFlatIP(self.img2txt_emb["test"].shape[1])
211+
self.img2txt_faiss.add_with_ids(self.img2txt_emb["test"], red_video_ids)
205212

206213
def check_correct_retrieval(self, q_idx: int, r_idx: int) -> bool:
207214
"""Check if the retrieval is correct.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ timm = "*"
6969
albumentations = "*"
7070
kaggle = "*"
7171
moviepy = "*"
72+
faiss-gpu = "*"
7273
imagebind = {git = "https://github.com/facebookresearch/ImageBind"}
7374
# LLaVA = {git = "https://github.com/haotian-liu/LLaVA.git"} # contradicting with imagebind
7475

0 commit comments

Comments
 (0)