Skip to content

Commit 1015a89

Browse files
committed
Refactor MSRVTTDataset class and remove unused variables and logic
1 parent f1f0e69 commit 1015a89

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

mmda/utils/mstvtt_ds_class.py

+13-26
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def process_similarity_pair(inputs: tuple) -> dict:
2929
range_idx_q (range): The index of the query.
3030
r_size (int): The size of the retrieval range.
3131
idx_offset (int): The offset for the indices.
32-
step_size (int): The step size for the retrieval indices.
3332
3433
Returns:
3534
dict: A dictionary containing the cosine similarities and ground truth labels for each pair.
@@ -43,7 +42,6 @@ def process_similarity_pair(inputs: tuple) -> dict:
4342
range_idx_q,
4443
r_size,
4544
idx_offset,
46-
step_size,
4745
) = inputs
4846
sim_mat_i = {}
4947
for idx_q in tqdm(range_idx_q):
@@ -82,10 +80,9 @@ def __init__(self, cfg: DictConfig) -> None:
8280
self.text2audio = "clap"
8381

8482
self.shape = (1, 2) # shape of the similarity matrix
85-
self.cali_size = 3800
86-
self.train_size = 53_000 # no training data is needed for MSRVTT
83+
self.cali_size = 3_800
84+
self.train_size = 53_000 # TODO: no training data is needed for MSRVTT
8785
self.test_size = 3_000
88-
self.step_size = 20 # 20 duplicates of different captions of a video
8986
self.img2txt_encoder = self.cfg_dataset.img_encoder
9087
self.audio2txt_encoder = self.cfg_dataset.audio_encoder
9188
self.save_tag = f"{self.img2txt_encoder}_{self.audio2txt_encoder}"
@@ -95,21 +92,16 @@ def load_data(self) -> None:
9592
self.sen_ids, self.captions, self.video_info_sen_order, self.video_dict = (
9693
load_msrvtt(self.cfg_dataset)
9794
)
98-
with Path(self.cfg_dataset.paths.save_path, "MSRVTT_ref_video_ids.pkl").open(
95+
with Path(self.cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
9996
"rb"
100-
) as file:
101-
self.ref_id_order = pickle.load(file) # noqa: S301
102-
null_audio_idx = [
103-
self.video_dict[video_id]["audio_np"] is None
104-
for video_id in self.ref_id_order
105-
]
97+
) as f:
98+
self.ref_id_order = pickle.load(f) # noqa: S301
99+
with Path(self.cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
100+
"rb"
101+
) as f:
102+
# get video idx which has no audio. 355 in total.
103+
null_audio_idx = pickle.load(f) # noqa: S301 list of bool in ref_id_order
106104

107-
# get video idx which has no audio. 355 in total.
108-
# TODO: video7010 has torch.zeros wav files.
109-
null_audio_idx = []
110-
for idx, video_info in enumerate(self.video_info_sen_order):
111-
if video_info["audio_np"] is None and idx % self.step_size == 0:
112-
null_audio_idx.append(int(idx / self.step_size))
113105
# load data
114106
with Path(
115107
self.cfg_dataset.paths.save_path
@@ -176,8 +168,7 @@ def preprocess_retrieval_data(self) -> None:
176168
), f"{self.test_size} + {self.cali_size} + {self.train_size} != {self.num_data}"
177169

178170
# train/test/calibration split only on the query size (59_800)
179-
# Shuffle the array to ensure randomness
180-
idx = np.arange(self.num_data) # 2990
171+
idx = np.arange(self.num_data) # 59800
181172
txt_test_idx = idx[self.train_size : self.cali_size]
182173
txt_cali_idx = idx[-self.cali_size :]
183174
self.txt2img_emb = {
@@ -218,16 +209,13 @@ def check_correct_retrieval(self, q_idx: int, r_idx: int) -> bool:
218209
Returns:
219210
True if the retrieval is correct, False otherwise
220211
"""
221-
return (
222-
self.video_info_sen_order[q_idx]["video_id"]
223-
== self.ref_id_order[r_idx]["video_id"]
224-
)
212+
return self.video_info_sen_order[q_idx]["video_id"] == self.ref_id_order[r_idx]
225213

226214
def calculate_pairs_data_similarity(
227215
self,
228216
data_lists: list[np.ndarray],
229217
idx_offset: int,
230-
num_workers: int = 2,
218+
num_workers: int = 1,
231219
) -> np.ndarray:
232220
"""Calculate the similarity matrix for the pairs of modalities.
233221
@@ -264,7 +252,6 @@ def calculate_pairs_data_similarity(
264252
range_idx_q,
265253
r_size,
266254
idx_offset,
267-
self.step_size,
268255
)
269256
for range_idx_q in range_idx_qs
270257
],

0 commit comments

Comments
 (0)