@@ -83,6 +83,7 @@ def __init__(self, cfg: DictConfig) -> None:
8383 self .cali_size = 3_800
8484 self .train_size = 53_000 # TODO: no training data is needed for MSRVTT
8585 self .test_size = 3_000
86+ self .query_step = 5
8687 self .img2txt_encoder = self .cfg_dataset .img_encoder
8788 self .audio2txt_encoder = self .cfg_dataset .audio_encoder
8889 self .save_tag = f"{ self .img2txt_encoder } _{ self .audio2txt_encoder } "
@@ -95,7 +96,7 @@ def load_data(self) -> None:
9596 with Path (self .cfg_dataset .paths .save_path , "MSRVTT_id_order.pkl" ).open (
9697 "rb"
9798 ) as f :
98- self .ref_id_order = pickle .load (f ) # noqa: S301
99+ self .ref_id_order = pickle .load (f )[:: self . query_step ] # noqa: S301
99100 with Path (self .cfg_dataset .paths .save_path , "MSRVTT_null_audio.pkl" ).open (
100101 "rb"
101102 ) as f :
@@ -176,16 +177,16 @@ def preprocess_retrieval_data(self) -> None:
176177 "cali" : self .txt2img_emb [txt_cali_idx ],
177178 }
178179 self .img2txt_emb = {
179- "test" : self .img2txt_emb ,
180- "cali" : self .img2txt_emb ,
180+ "test" : self .img2txt_emb [:: self . query_step ] ,
181+ "cali" : self .img2txt_emb [:: self . query_step ] ,
181182 }
182183 self .txt2audio_emb = {
183184 "test" : self .txt2audio_emb [txt_test_idx ],
184185 "cali" : self .txt2audio_emb [txt_cali_idx ],
185186 }
186187 self .audio2txt_emb = {
187- "test" : self .audio2txt_emb ,
188- "cali" : self .audio2txt_emb ,
188+ "test" : self .audio2txt_emb [:: self . query_step ] ,
189+ "cali" : self .audio2txt_emb [:: self . query_step ] ,
189190 }
190191 # masking missing data in the test set. Mask the whole modality of an instance at a time.
191192 if self .cfg_dataset .mask_ratio != 0 :
@@ -199,6 +200,9 @@ def preprocess_retrieval_data(self) -> None:
199200 self .mask [0 ] = []
200201 self .mask [1 ] = []
201202
203+ # sanity check: the subsampled reference ID order must have exactly one entry per test audio2txt embedding
204+ assert len (self .ref_id_order ) == self .audio2txt_emb ["test" ].shape [0 ]
205+
202206 def check_correct_retrieval (self , q_idx : int , r_idx : int ) -> bool :
203207 """Check if the retrieval is correct.
204208
0 commit comments