Skip to content

Commit e0aa645

Browse files
committed
Shrink size of query set
1 parent 1015a89 commit e0aa645

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

config/main.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ defaults:
44
seed: 42
55
train_test_ratio: 0.7
66
noisy_train_set: True
7-
repo_root: "/home/pl22767/Project/MMDA/"
7+
# repo_root: "/home/pl22767/Project/MMDA/"
8+
repo_root: "/home/po-han/Desktop/Projects/MMDA/"
89

910
dataset: "MSRVTT"
1011
dataset_level_datasets: [pitts, imagenet, cosmos, sop, tiil, musiccaps, flickr]
@@ -32,7 +33,8 @@ MSRVTT:
3233
retrieval_dim: "" # we use all the dimensions for retrieval
3334
mask_ratio: 0 # ratio of missing data to the size of the test data
3435
paths:
35-
dataset_path: "/nas/pohan/datasets/MSR-VTT/"
36+
# dataset_path: "/nas/pohan/datasets/MSR-VTT/"
37+
dataset_path: "/home/po-han/Downloads/MSR-VTT/"
3638
save_path: ${MSRVTT.paths.dataset_path}embeddings/
3739
plots_path: ${repo_root}plots/MSR-VTT/
3840

mmda/baselines/asif_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def normalize_sparse(
8888
return torch.sparse_coo_tensor(tensor_idx, v.t().flatten(), tensor.shape)
8989

9090

91-
def zero_shot_classification( # noqa: PLR0913
91+
def zero_shot_classification(
9292
zimgs: torch.Tensor,
9393
ztxts: torch.Tensor,
9494
aimgs: torch.Tensor,

mmda/utils/mstvtt_ds_class.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self, cfg: DictConfig) -> None:
8383
self.cali_size = 3_800
8484
self.train_size = 53_000 # TODO: no training data is needed for MSRVTT
8585
self.test_size = 3_000
86+
self.query_step = 5
8687
self.img2txt_encoder = self.cfg_dataset.img_encoder
8788
self.audio2txt_encoder = self.cfg_dataset.audio_encoder
8889
self.save_tag = f"{self.img2txt_encoder}_{self.audio2txt_encoder}"
@@ -95,7 +96,7 @@ def load_data(self) -> None:
9596
with Path(self.cfg_dataset.paths.save_path, "MSRVTT_id_order.pkl").open(
9697
"rb"
9798
) as f:
98-
self.ref_id_order = pickle.load(f) # noqa: S301
99+
self.ref_id_order = pickle.load(f)[:: self.query_step] # noqa: S301
99100
with Path(self.cfg_dataset.paths.save_path, "MSRVTT_null_audio.pkl").open(
100101
"rb"
101102
) as f:
@@ -176,16 +177,16 @@ def preprocess_retrieval_data(self) -> None:
176177
"cali": self.txt2img_emb[txt_cali_idx],
177178
}
178179
self.img2txt_emb = {
179-
"test": self.img2txt_emb,
180-
"cali": self.img2txt_emb,
180+
"test": self.img2txt_emb[:: self.query_step],
181+
"cali": self.img2txt_emb[:: self.query_step],
181182
}
182183
self.txt2audio_emb = {
183184
"test": self.txt2audio_emb[txt_test_idx],
184185
"cali": self.txt2audio_emb[txt_cali_idx],
185186
}
186187
self.audio2txt_emb = {
187-
"test": self.audio2txt_emb,
188-
"cali": self.audio2txt_emb,
188+
"test": self.audio2txt_emb[:: self.query_step],
189+
"cali": self.audio2txt_emb[:: self.query_step],
189190
}
190191
# masking missing data in the test set. Mask the whole modality of an instance at a time.
191192
if self.cfg_dataset.mask_ratio != 0:
@@ -199,6 +200,9 @@ def preprocess_retrieval_data(self) -> None:
199200
self.mask[0] = []
200201
self.mask[1] = []
201202

203+
# check that the subsampled reference id order has one entry per test audio embedding
204+
assert len(self.ref_id_order) == self.audio2txt_emb["test"].shape[0]
205+
202206
def check_correct_retrieval(self, q_idx: int, r_idx: int) -> bool:
203207
"""Check if the retrieval is correct.
204208

ruff.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ select = [
8181
"RUF",
8282
]
8383

84-
ignore = ["ANN101","ANN102","COM","EXE","PD","S307","FBT001","FBT002","G004","ISC001","S101","T201","NPY002","I001"]
84+
ignore = ["ANN101","ANN102","COM","EXE","PD","S307","FBT001","FBT002","G004","ISC001","S101","T201","NPY002","I001","PLR0913"]
8585

8686

8787
# Allow fix for all enabled rules (when `--fix` is provided).

0 commit comments

Comments
 (0)