@@ -1,6 +1,6 @@
 """Get feature embeddings for the datasets."""
 
-# ruff: noqa: ERA001
+# ruff: noqa: ERA001, PLR2004
 import os
 import pickle
 from pathlib import Path
@@ -39,8 +39,8 @@
 
 
 def get_video_emb(
-    cfg_dataset: DictConfig, video_dict: dict, use_kaggle: bool = False
-) -> dict[str, np.ndarray]:
+    cfg_dataset: DictConfig, video_dict: dict
+) -> tuple[list[str], list[str], list[str]]:
     """Get video embeddings for the videos in the video_dict.
 
     Args:
@@ -49,35 +49,22 @@ def get_video_emb(
-        use_kaggle: whether to use the kaggle dataset
 
     Returns:
-        video embeddings. dict: video_id -> video_embedding (if use_kaggle)
-        img_paths. list: list of image paths (if not use_kaggle)
+        id_order: list of video ids
+        first_img_paths: list of first image paths
+        last_img_paths: list of last image paths
     """
-    # skip image embeddings (CLIP is already done from the dataset)
-    # load the existing embeddings
-    if use_kaggle:
-        video_emb = {}
-        for video_ids in tqdm(video_dict, desc="Loading video embeddings"):
-            video_np_path = Path(
-                cfg_dataset.paths.dataset_path,
-                f"clip-features-vit-h14/{video_ids}.npy",
-            )
-            # only sample the first and last frame
-            video_np = np.load(video_np_path)[[0, -1], :].reshape(1, -1)
-            video_emb[video_ids] = video_np
-        return video_emb
     id_order = []
     first_img_paths = []
     last_img_paths = []
-    for video_ids in tqdm(sorted(video_dict), desc="loading keyframe paths"):
+    for video_id in tqdm(sorted(video_dict), desc="loading keyframe paths"):
         # video_ids from 7010 to 7990
-        img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_ids)
+        img_dir = Path(cfg_dataset.paths.dataset_path, "keyframes", video_id)
         num_frames = len(os.listdir(img_dir))
         for frame_id in range(0, num_frames, 2):
             if frame_id + 1 >= num_frames:
                 break
             first_img_path = img_dir / f"{frame_id:04d}.jpg"
             last_img_path = img_dir / f"{frame_id + 1:04d}.jpg"
-            id_order.append(video_ids)
+            id_order.append(video_id)
             first_img_paths.append(str(first_img_path))
             last_img_paths.append(str(last_img_path))
     return id_order, first_img_paths, last_img_paths
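
A note on the rewritten loop above: keyframes are consumed in non-overlapping pairs (the "first" and "last" image of each two-frame window), and a clip with an odd number of keyframes silently drops its final frame via the `break`. A minimal sketch of the index arithmetic, using a made-up 5-frame clip:

    num_frames = 5  # made-up example; real clips vary
    pairs = [
        (first, first + 1)
        for first in range(0, num_frames, 2)
        if first + 1 < num_frames
    ]
    assert pairs == [(0, 1), (2, 3)]  # frame 4 has no partner and is skipped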
@@ -135,7 +122,7 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
     _, captions, video_info_sen_order, video_dict = load_msrvtt(cfg_dataset)
 
     id_order, first_img_paths, last_img_paths = get_video_emb(
-        cfg_dataset, video_dict, use_kaggle=False
+        cfg_dataset, video_dict
     )
 
     # get audio embeddings
@@ -190,24 +177,20 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
     audio_np = []
     img_np = []
     print(len(id_order))
-    for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind"):
+    for i in tqdm(range(0, len(id_order), BATCH_SIZE), desc="imagebind inference"):
         audios = audio_paths[i : i + BATCH_SIZE]
         first_images = first_img_paths[i : i + BATCH_SIZE]
         last_images = last_img_paths[i : i + BATCH_SIZE]
-        # audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
-        # first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
-        first_embs, audio_embs = imagebind_class.inference_image_audio(
-            first_images, audios
-        )
-        first_embs = first_embs.cpu().numpy()
-        audio_embs = audio_embs.cpu().numpy()
+        audio_embs = imagebind_class.inference_audio(audios).cpu().numpy()
+        first_embs = imagebind_class.inference_image(first_images).cpu().numpy()
         last_embs = imagebind_class.inference_image(last_images).cpu().numpy()
         img_embs = np.concatenate([first_embs, last_embs], axis=1)
         audio_np.append(audio_embs)
         img_np.append(img_embs)
-        # print(img_embs.shape, audio_embs.shape)
-    audio_np = np.array(audio_np)
-    img_np = np.array(img_np)
+        assert img_embs.shape[1] == 2048, f"img.shape: {img_embs.shape}, {i}"
+        assert audio_embs.shape[1] == 1024, f"audio.shape: {audio_embs.shape}, {i}"
+    audio_np = np.concatenate(audio_np, axis=0)
+    img_np = np.concatenate(img_np, axis=0)
     with Path(cfg_dataset.paths.save_path, "MSRVTT_audio_emb_imagebind.pkl").open(
         "wb"
     ) as f:
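
The switch from `np.array` to `np.concatenate` in this hunk is the substantive fix: each per-batch embedding is a (batch, dim) array, and the final batch is smaller whenever `len(id_order)` is not a multiple of `BATCH_SIZE`, so `np.array` cannot stack the ragged list (and even with equal-sized batches it would produce a 3-D array rather than one row per sample). A self-contained illustration with toy shapes:

    import numpy as np

    # Stand-ins for per-batch outputs; the last batch is smaller.
    batches = [np.ones((4, 1024)), np.ones((4, 1024)), np.ones((2, 1024))]

    stacked = np.concatenate(batches, axis=0)
    assert stacked.shape == (10, 1024)  # one row per sample, as the pickles expect
    # np.array(batches) would raise on these ragged shapes; with equal-sized
    # batches it would silently yield a (3, 4, 1024) array instead.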
@@ -218,7 +201,7 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
     ) as f:
         pickle.dump(img_np, f)
     print("imagebind embeddings saved")
-    return
+    # return
 
     shape = video_info_sen_order[0]["audio_np"].shape
     audio_np = [
@@ -237,7 +220,15 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
     print("CLAP embeddings saved")
 
     # get text embeddings
-    text_emb = imagebind_class.inference_text(captions)
+    imagebind_class = ImageBindInference()
+    text_emb = []
+    for i in tqdm(range(0, len(captions), BATCH_SIZE), desc="imagebind txt"):
+        text_emb.append(
+            imagebind_class.inference_text(captions[i : i + BATCH_SIZE])
+            .cpu()
+            .numpy()
+        )
+    text_emb = np.concatenate(text_emb, axis=0)
     with Path(cfg_dataset.paths.save_path, "MSRVTT_text_emb_imagebind.pkl").open(
         "wb"
     ) as f:
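
The text path now embeds captions in fixed-size chunks rather than one giant call, which keeps peak GPU memory proportional to `BATCH_SIZE` instead of to the full caption list. A sketch of that batching pattern, where `embed_in_batches` and `fake_encode` are hypothetical stand-ins (the real encoder is `imagebind_class.inference_text(...).cpu().numpy()`):

    import numpy as np

    def embed_in_batches(texts, encode, batch_size=32):
        """Encode texts chunk by chunk and stack the per-chunk (batch, dim) arrays."""
        chunks = [
            encode(texts[i : i + batch_size])
            for i in range(0, len(texts), batch_size)
        ]
        return np.concatenate(chunks, axis=0)

    def fake_encode(batch):
        # Pretend encoder: one 4-dim row per caption.
        return np.zeros((len(batch), 4))

    out = embed_in_batches([f"caption {n}" for n in range(70)], fake_encode)
    assert out.shape == (70, 4)  # 70 captions -> 70 rows, regardless of batching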