Skip to content

Commit 62da752

Browse files
committed
Still getting video embeddings
1 parent fd1520a commit 62da752

File tree

6 files changed

+251
-82
lines changed

6 files changed

+251
-82
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,4 @@ plots/*
171171
# lock files
172172
*.lock
173173
.checkpoints/
174-
ImageBind/
174+
.assets/

mmda/get_embeddings.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pathlib import Path
77

88
import numpy as np
9+
import torch
10+
import torchaudio
911
from omegaconf import DictConfig
1012
from tqdm import tqdm
1113

@@ -33,7 +35,7 @@
3335
)
3436
from mmda.utils.imagebind_utils import ImageBindInference
3537

36-
BATCH_SIZE = 128
38+
BATCH_SIZE = 256
3739

3840

3941
def get_video_emb(
@@ -70,7 +72,7 @@ def get_video_emb(
7072
# video_ids from 7010 to 7990
7173
img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_ids)
7274
num_frames = len(os.listdir(img_dir))
73-
for frame_id in range(num_frames, 2):
75+
for frame_id in range(0, num_frames, 2):
7476
if frame_id + 1 >= num_frames:
7577
break
7678
first_img_path = img_dir / f"{frame_id:04d}.jpg"
@@ -135,42 +137,75 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
135137
id_order, first_img_paths, last_img_paths = get_video_emb(
136138
cfg_dataset, video_dict, use_kaggle=False
137139
)
138-
first_img_emb = clip_imgs(first_img_paths, BATCH_SIZE)
139-
last_img_emb = clip_imgs(last_img_paths, BATCH_SIZE)
140-
video_id_emb = np.concatenate([first_img_emb, last_img_emb], axis=1)
141-
with Path(cfg_dataset.paths.save_path, "MSRVTT_video_emb_clip.pkl").open(
142-
"wb"
143-
) as f:
144-
pickle.dump(video_id_emb, f)
145-
print("CLIP embeddings saved")
146-
with Path(cfg_dataset.paths.save_path, "MSRVTT_ref_video_ids.pkl").open(
147-
"wb"
148-
) as f:
149-
pickle.dump(id_order, f)
150140

151141
# get audio embeddings
152-
audio_paths = []
153-
for video_id in id_order:
154-
audio_path = str(
155-
Path(cfg_dataset.paths.dataset_path, f"TestVideo/{video_id}.wav")
156-
)
157-
audio_paths.append(audio_path)
142+
if not (
143+
Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").exists()
144+
and Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").exists()
145+
and Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").exists()
146+
):
147+
audio_paths, null_audio = [], []
148+
for video_id in tqdm(id_order, desc="process id_order"):
149+
audio_path = Path(
150+
cfg_dataset.paths.dataset_path, f"TestVideo/{video_id}.wav"
151+
)
152+
if (
153+
not audio_path.exists()
154+
or torch.sum(torchaudio.load(str(audio_path))[0]) == 0
155+
):
156+
null_audio.append(True)
157+
# just a placeholder for wav path
158+
audio_paths.append(".assets/bird_audio.wav")
159+
else:
160+
null_audio.append(False)
161+
audio_paths.append(str(audio_path))
162+
with Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
163+
"wb"
164+
) as f:
165+
pickle.dump(id_order, f)
166+
with Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
167+
"wb"
168+
) as f:
169+
pickle.dump(null_audio, f)
170+
with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").open(
171+
"wb"
172+
) as f:
173+
pickle.dump(audio_paths, f)
174+
else:
175+
with Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
176+
"rb"
177+
) as f:
178+
id_order = pickle.load(f) # noqa: S301
179+
with Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
180+
"rb"
181+
) as f:
182+
null_audio = pickle.load(f) # noqa: S301
183+
with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").open(
184+
"rb"
185+
) as f:
186+
audio_paths = pickle.load(f) # noqa: S301
187+
158188
# inference imagebind
159-
imagebind_class = ImageBindInference(device=0)
189+
imagebind_class = ImageBindInference()
160190
audio_np = []
161191
img_np = []
162-
for i in range(len(id_order)), BATCH_SIZE:
192+
print(len(id_order))
193+
for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind"):
163194
audios = audio_paths[i : i + BATCH_SIZE]
164195
first_images = first_img_paths[i : i + BATCH_SIZE]
165196
last_images = last_img_paths[i : i + BATCH_SIZE]
166-
audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
167-
first_embs = (
168-
imagebind_class.inference_image_only(first_images).cpu().numpy()
197+
# audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
198+
# first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
199+
first_embs, audio_embs = imagebind_class.inference_image_audio(
200+
first_images, audios
169201
)
170-
last_embs = imagebind_class.inference_image_only(last_images).cpu().numpy()
202+
first_embs = first_embs.cpu().numpy()
203+
audio_embs = audio_embs.cpu().numpy()
204+
last_embs = imagebind_class.inference_image(last_images).cpu().numpy()
171205
img_embs = np.concatenate([first_embs, last_embs], axis=1)
172206
audio_np.append(audio_embs)
173207
img_np.append(img_embs)
208+
# print(img_embs.shape, audio_embs.shape)
174209
audio_np = np.array(audio_np)
175210
img_np = np.array(img_np)
176211
with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_emb_imagebind.pkl").open(
@@ -183,6 +218,7 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
183218
) as f:
184219
pickle.dump(img_np, f)
185220
print("imagebind embeddings saved")
221+
return
186222

187223
shape = video_info_sen_order[0]["audio_np"].shape
188224
audio_np = [
@@ -511,4 +547,4 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
511547

512548
if __name__ == "__main__":
513549
main()
514-
# CUDA_VISIBLE_DEVICES=0 poetry run python mmda/get_embeddings.py
550+
# CUDA_VISIBLE_DEVICES=5 poetry run python mmda/get_embeddings.py

mmda/utils/dataset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def load_msrvtt(
9999
"category": category,
100100
"url": url,
101101
}
102-
num_processes = 32
102+
num_processes = 64
103103
p = Pool(processes=num_processes)
104104
print("num_processes:", num_processes)
105105
data = p.map(

mmda/utils/imagebind_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111

1212
class ImageBindInference:
13-
def __init__(self, device: int = 0):
14-
self.device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
13+
def __init__(self):
14+
self.device = f"cuda" if torch.cuda.is_available() else "cpu"
1515
self.model = imagebind_model.imagebind_huge(pretrained=True)
1616
self.model.eval()
1717
self.model.to(self.device)
1818

19-
def inference_audio(self, image_paths, audio_paths):
19+
def inference_audio(self, audio_paths):
2020
inputs = {
2121
ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, self.device),
2222
}
@@ -44,3 +44,15 @@ def inference_text(self, text_list):
4444
with torch.no_grad():
4545
embeddings = self.model(inputs)
4646
return embeddings[ModalityType.TEXT]
47+
48+
def inference_image_audio(self, image_paths, audio_paths):
49+
inputs = {
50+
ModalityType.VISION: load_and_transform_vision_data(
51+
image_paths, self.device
52+
),
53+
ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, self.device),
54+
}
55+
56+
with torch.no_grad():
57+
embeddings = self.model(inputs)
58+
return embeddings[ModalityType.VISION], embeddings[ModalityType.AUDIO]

mmda/utils/mstvtt_ds_class.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def load_data(self) -> None:
105105
]
106106

107107
# get the indices of videos that have no audio (355 in total).
108+
# TODO: video7010's wav file decodes to an all-zero waveform (like torch.zeros).
108109
null_audio_idx = []
109110
for idx, video_info in enumerate(self.video_info_sen_order):
110111
if video_info["audio_np"] is None and idx % self.step_size == 0:
@@ -114,23 +115,23 @@ def load_data(self) -> None:
114115
self.cfg_dataset.paths.save_path
115116
+ f"MSRVTT_text_emb_{self.img2txt_encoder}.pkl"
116117
).open("rb") as file:
117-
self.txt2img_emb = pickle.load(file) # (59800, 1280) # noqa: S301
118+
self.txt2img_emb = pickle.load(file) # (59800,) # noqa: S301
118119
with Path(
119120
self.cfg_dataset.paths.save_path
120121
+ f"MSRVTT_video_emb_{self.img2txt_encoder}.pkl"
121122
).open("rb") as file:
122-
self.img2txt_emb = pickle.load(file) # noqa: S301
123+
self.img2txt_emb = pickle.load(file) # (47392,) # noqa: S301
123124
print(self.img2txt_emb.shape)
124125
with Path(
125126
self.cfg_dataset.paths.save_path
126127
+ f"MSRVTT_text_emb_{self.audio2txt_encoder}.pkl"
127128
).open("rb") as file:
128-
self.txt2audio_emb = pickle.load(file) # (59800, 512) # noqa: S301
129+
self.txt2audio_emb = pickle.load(file) # (59800,) # noqa: S301
129130
with Path(
130131
self.cfg_dataset.paths.save_path
131132
+ f"MSRVTT_audio_emb_{self.audio2txt_encoder}.pkl"
132133
).open("rb") as file:
133-
self.audio2txt_emb = pickle.load(file) # (???, 512) # noqa: S301
134+
self.audio2txt_emb = pickle.load(file) # (47392,) # noqa: S301
134135
print(self.audio2txt_emb.shape)
135136

136137
# normalize all the embeddings to have unit norm using L2 normalization

0 commit comments

Comments
 (0)