from pathlib import Path

import numpy as np
+ import torch
+ import torchaudio
from omegaconf import DictConfig
from tqdm import tqdm

@@ -33,7 +35,7 @@
)
from mmda.utils.imagebind_utils import ImageBindInference

- BATCH_SIZE = 128
+ BATCH_SIZE = 256
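+ # number of items per inference batch when slicing image/audio paths below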


def get_video_emb(
@@ -70,7 +72,7 @@ def get_video_emb(
        # video_ids from 7010 to 7990
        img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_ids)
        num_frames = len(os.listdir(img_dir))
-         for frame_id in range(num_frames, 2):
+         for frame_id in range(0, num_frames, 2):
            if frame_id + 1 >= num_frames:
                break
            first_img_path = img_dir / f"{frame_id:04d}.jpg"
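            # the next keyframe (frame_id + 1, checked above) supplies the clip's last frame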
@@ -135,42 +137,75 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
    id_order, first_img_paths, last_img_paths = get_video_emb(
        cfg_dataset, video_dict, use_kaggle=False
    )
-     first_img_emb = clip_imgs(first_img_paths, BATCH_SIZE)
-     last_img_emb = clip_imgs(last_img_paths, BATCH_SIZE)
-     video_id_emb = np.concatenate([first_img_emb, last_img_emb], axis=1)
-     with Path(cfg_dataset.paths.save_path, "MSRVTT_video_emb_clip.pkl").open(
-         "wb"
-     ) as f:
-         pickle.dump(video_id_emb, f)
-     print("CLIP embeddings saved")
-     with Path(cfg_dataset.paths.save_path, "MSRVTT_ref_video_ids.pkl").open(
-         "wb"
-     ) as f:
-         pickle.dump(id_order, f)

    # get audio embeddings
-     audio_paths = []
-     for video_id in id_order:
-         audio_path = str(
-             Path(cfg_dataset.paths.dataset_path, f"TestVideo/{video_id}.wav")
-         )
-         audio_paths.append(audio_path)
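+     # cache id_order, null_audio, and audio_paths so later runs can skip the audio scan;
+     # missing or silent (all-zero) tracks are flagged in null_audio and mapped to a placeholder wav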
+     if not (
+         Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").exists()
+         and Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").exists()
+         and Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").exists()
+     ):
+         audio_paths, null_audio = [], []
+         for video_id in tqdm(id_order, desc="process id_order"):
+             audio_path = Path(
+                 cfg_dataset.paths.dataset_path, f"TestVideo/{video_id}.wav"
+             )
+             if (
+                 not audio_path.exists()
+                 or torch.sum(torchaudio.load(str(audio_path))[0]) == 0
+             ):
+                 null_audio.append(True)
+                 # just a placeholder for wav path
+                 audio_paths.append(".assets/bird_audio.wav")
+             else:
+                 null_audio.append(False)
+                 audio_paths.append(str(audio_path))
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
+             "wb"
+         ) as f:
+             pickle.dump(id_order, f)
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
+             "wb"
+         ) as f:
+             pickle.dump(null_audio, f)
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").open(
+             "wb"
+         ) as f:
+             pickle.dump(audio_paths, f)
+     else:
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
+             "rb"
+         ) as f:
+             id_order = pickle.load(f)  # noqa: S301
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
+             "rb"
+         ) as f:
+             null_audio = pickle.load(f)  # noqa: S301
+         with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_paths.pkl").open(
+             "rb"
+         ) as f:
+             audio_paths = pickle.load(f)  # noqa: S301
+
    # inference imagebind
-     imagebind_class = ImageBindInference(device=0)
+     imagebind_class = ImageBindInference()
    audio_np = []
    img_np = []
-     for i in range(len(id_order)), BATCH_SIZE:
+     print(len(id_order))
+     for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind"):
        audios = audio_paths[i : i + BATCH_SIZE]
        first_images = first_img_paths[i : i + BATCH_SIZE]
        last_images = last_img_paths[i : i + BATCH_SIZE]
-         audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
-         first_embs = (
-             imagebind_class.inference_image_only(first_images).cpu().numpy()
+         # audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
+         # first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
+         first_embs, audio_embs = imagebind_class.inference_image_audio(
+             first_images, audios
        )
-         last_embs = imagebind_class.inference_image_only(last_images).cpu().numpy()
+         first_embs = first_embs.cpu().numpy()
+         audio_embs = audio_embs.cpu().numpy()
+         last_embs = imagebind_class.inference_image(last_images).cpu().numpy()
        img_embs = np.concatenate([first_embs, last_embs], axis=1)
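        # per-video image embedding: first- and last-frame features concatenated along the feature axis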
        audio_np.append(audio_embs)
        img_np.append(img_embs)
+         # print(img_embs.shape, audio_embs.shape)
    audio_np = np.array(audio_np)
    img_np = np.array(img_np)
    with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_emb_imagebind.pkl").open(
@@ -183,6 +218,7 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
    ) as f:
        pickle.dump(img_np, f)
    print("imagebind embeddings saved")
+     return

    shape = video_info_sen_order[0]["audio_np"].shape
    audio_np = [
@@ -511,4 +547,4 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912

if __name__ == "__main__":
    main()
-     # CUDA_VISIBLE_DEVICES=0 poetry run python mmda/get_embeddings.py
+     # CUDA_VISIBLE_DEVICES=5 poetry run python mmda/get_embeddings.py