@@ -83,6 +83,7 @@ def __init__(self, cfg: DictConfig) -> None:
83
83
self .cali_size = 3_800
84
84
self .train_size = 53_000 # TODO: no training data is needed for MSRVTT
85
85
self .test_size = 3_000
86
+ self .query_step = 5
86
87
self .img2txt_encoder = self .cfg_dataset .img_encoder
87
88
self .audio2txt_encoder = self .cfg_dataset .audio_encoder
88
89
self .save_tag = f"{ self .img2txt_encoder } _{ self .audio2txt_encoder } "
@@ -95,7 +96,7 @@ def load_data(self) -> None:
95
96
with Path (self .cfg_dataset .paths .save_path , "MSRVTT_id_order.pkl" ).open (
96
97
"rb"
97
98
) as f :
98
- self .ref_id_order = pickle .load (f ) # noqa: S301
99
+ self .ref_id_order = pickle .load (f )[:: self . query_step ] # noqa: S301
99
100
with Path (self .cfg_dataset .paths .save_path , "MSRVTT_null_audio.pkl" ).open (
100
101
"rb"
101
102
) as f :
@@ -176,16 +177,16 @@ def preprocess_retrieval_data(self) -> None:
176
177
"cali" : self .txt2img_emb [txt_cali_idx ],
177
178
}
178
179
self .img2txt_emb = {
179
- "test" : self .img2txt_emb ,
180
- "cali" : self .img2txt_emb ,
180
+ "test" : self .img2txt_emb [:: self . query_step ] ,
181
+ "cali" : self .img2txt_emb [:: self . query_step ] ,
181
182
}
182
183
self .txt2audio_emb = {
183
184
"test" : self .txt2audio_emb [txt_test_idx ],
184
185
"cali" : self .txt2audio_emb [txt_cali_idx ],
185
186
}
186
187
self .audio2txt_emb = {
187
- "test" : self .audio2txt_emb ,
188
- "cali" : self .audio2txt_emb ,
188
+ "test" : self .audio2txt_emb [:: self . query_step ] ,
189
+ "cali" : self .audio2txt_emb [:: self . query_step ] ,
189
190
}
190
191
# masking missing data in the test set. Mask the whole modality of an instance at a time.
191
192
if self .cfg_dataset .mask_ratio != 0 :
@@ -199,6 +200,9 @@ def preprocess_retrieval_data(self) -> None:
199
200
self .mask [0 ] = []
200
201
self .mask [1 ] = []
201
202
203
+ # check the length of the reference order
204
+ assert len (self .ref_id_order ) == self .audio2txt_emb ["test" ].shape [0 ]
205
+
202
206
def check_correct_retrieval (self , q_idx : int , r_idx : int ) -> bool :
203
207
"""Check if the retrieval is correct.
204
208
0 commit comments