Skip to content

Commit f1f0e69

Browse files
committed
feat: Add support for imagebind embeddings with MSRVTT dataset
1 parent 62da752 commit f1f0e69

File tree

2 files changed

+29
-36
lines changed

2 files changed

+29
-36
lines changed

config/main.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,5 @@ asif: # Hyperparameters of asif baseline
169169
# gtr: 768
170170
# dino: 1536
171171
# clip: 1280 (1024 for msrvtt)
172-
# liploc: 256
172+
# liploc: 256
173+
# imagebind: 1024

mmda/get_embeddings.py

+27-35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Get feature embeddings for the datasets."""
22

3-
# ruff: noqa: ERA001
3+
# ruff: noqa: ERA001, PLR2004
44
import os
55
import pickle
66
from pathlib import Path
@@ -39,8 +39,8 @@
3939

4040

4141
def get_video_emb(
42-
cfg_dataset: DictConfig, video_dict: dict, use_kaggle: bool = False
43-
) -> dict[str, np.ndarray]:
42+
cfg_dataset: DictConfig, video_dict: dict
43+
) -> tuple[list[str], list[str], list[str]]:
4444
"""Get video embeddings for the videos in the video_dict.
4545
4646
Args:
@@ -49,35 +49,23 @@ def get_video_emb(
4949
use_kaggle: whether to use the kaggle dataset  # NOTE(review): stale — this commit removes the `use_kaggle` parameter from get_video_emb's signature and call site, so this Args entry should have been deleted too
5050
5151
Returns:
52-
video embeddings. dict: video_id -> video_embedding (if use_kaggle)
53-
img_paths. list: list of image paths (if not use_kaggle)
52+
id_order: list of video ids
53+
first_img_paths: list of first image paths
54+
last_img_paths: list of last image paths
5455
"""
55-
# skip image embeddings (CLIP is already done from the dataset)
56-
# load the existing embeddings
57-
if use_kaggle:
58-
video_emb = {}
59-
for video_ids in tqdm(video_dict, desc="Loading video embeddings"):
60-
video_np_path = Path(
61-
cfg_dataset.paths.dataset_path,
62-
f"clip-features-vit-h14/{video_ids}.npy",
63-
)
64-
# only sample the first and last frame
65-
video_np = np.load(video_np_path)[[0, -1], :].reshape(1, -1)
66-
video_emb[video_ids] = video_np
67-
return video_emb
6856
id_order = []
6957
first_img_paths = []
7058
last_img_paths = []
71-
for video_ids in tqdm(sorted(video_dict), desc="loading keyframe paths"):
59+
for video_id in tqdm(sorted(video_dict), desc="loading keyframe paths"):
7260
# video_ids from 7010 to 7990
73-
img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_ids)
61+
img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_id)
7462
num_frames = len(os.listdir(img_dir))
7563
for frame_id in range(0, num_frames, 2):
7664
if frame_id + 1 >= num_frames:
7765
break
7866
first_img_path = img_dir / f"{frame_id:04d}.jpg"
7967
last_img_path = img_dir / f"{frame_id + 1:04d}.jpg"
80-
id_order.append(video_ids)
68+
id_order.append(video_id)
8169
first_img_paths.append(str(first_img_path))
8270
last_img_paths.append(str(last_img_path))
8371
return id_order, first_img_paths, last_img_paths
@@ -135,7 +123,7 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
135123
_, captions, video_info_sen_order, video_dict = load_msrvtt(cfg_dataset)
136124

137125
id_order, first_img_paths, last_img_paths = get_video_emb(
138-
cfg_dataset, video_dict, use_kaggle=False
126+
cfg_dataset, video_dict
139127
)
140128

141129
# get audio embeddings
@@ -190,24 +178,20 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
190178
audio_np = []
191179
img_np = []
192180
print(len(id_order))
193-
for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind"):
181+
for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind inference"):
194182
audios = audio_paths[i : i + BATCH_SIZE]
195183
first_images = first_img_paths[i : i + BATCH_SIZE]
196184
last_images = last_img_paths[i : i + BATCH_SIZE]
197-
# audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
198-
# first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
199-
first_embs, audio_embs = imagebind_class.inference_image_audio(
200-
first_images, audios
201-
)
202-
first_embs = first_embs.cpu().numpy()
203-
audio_embs = audio_embs.cpu().numpy()
185+
audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
186+
first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
204187
last_embs = imagebind_class.inference_image(last_images).cpu().numpy()
205188
img_embs = np.concatenate([first_embs, last_embs], axis=1)
206189
audio_np.append(audio_embs)
207190
img_np.append(img_embs)
208-
# print(img_embs.shape, audio_embs.shape)
209-
audio_np = np.array(audio_np)
210-
img_np = np.array(img_np)
191+
assert img_embs.shape[1] == 2048, f"img.shape: {img_embs.shape}, {i}"
192+
assert audio_embs.shape[1] == 1024, f"audio.shape: {audio_embs.shape}, {i}"
193+
audio_np = np.concatenate(audio_np, axis=0)
194+
img_np = np.concatenate(img_np, axis=0)
211195
with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_emb_imagebind.pkl").open(
212196
"wb"
213197
) as f:
@@ -218,7 +202,7 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
218202
) as f:
219203
pickle.dump(img_np, f)
220204
print("imagebind embeddings saved")
221-
return
205+
# return
222206

223207
shape = video_info_sen_order[0]["audio_np"].shape
224208
audio_np = [
@@ -237,7 +221,15 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
237221
print("CLAP embeddings saved")
238222

239223
# get text embeddings
240-
text_emb = imagebind_class.inference_text(captions)
224+
imagebind_class = ImageBindInference()
225+
text_emb = []
226+
for i in tqdm(range(0, len(captions), BATCH_SIZE), desc="imagebind txt"):
227+
text_emb.append(
228+
imagebind_class.inference_text(captions[i : i + BATCH_SIZE])
229+
.cpu()
230+
.numpy()
231+
)
232+
text_emb = np.concatenate(text_emb, axis=0)
241233
with Path(cfg_dataset.paths.save_path, "MSRVTT_text_emb_imagebind.pkl").open(
242234
"wb"
243235
) as f:

0 commit comments

Comments (0)