
Commit bd4cd86

MSR done
1 parent 79ca98c commit bd4cd86

File tree

4 files changed: +606 -21 lines changed

config/main.yaml  (+4 -4)

@@ -40,10 +40,10 @@ BTC:
   plots_path: ${repo_root}plots/BTC/

 MSRVTT:
-  img_encoder: "clip"
-  audio_encoder: "clap"
+  img_encoder: "clip" # clip, imagebind
+  audio_encoder: "clap" # clap, imagebind
   retrieval_dim: "" # we use all the dimensions for retrieval
-  mask_ratio: 4 # ratio of the missing data : size of test data
+  mask_ratio: 0 # ratio of the missing data : size of test data
   paths:
     dataset_path: "/nas/pohan/datasets/MSR-VTT/"
     # dataset_path: "/home/po-han/Downloads/MSR-VTT/"
@@ -58,7 +58,7 @@ KITTI:
   lidar_encoder: "liploc"
   text_encoder: "gtr"
   shuffle_step: 20
-  mask_ratio: 2 # ratio of the missing data : size of test data
+  mask_ratio: 0 # ratio of the missing data : size of test data
   paths:
     dataset_path: "/nas/pohan/datasets/KITTI/"
     save_path: ${KITTI.paths.dataset_path}embeddings/
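For context on the mask_ratio change above: the value is used downstream as a divisor of the test-set size (see the KITTI baseline diff below), so mask_ratio: 4 masks one quarter of the test instances per modality, while the new mask_ratio: 0 disables masking entirely. A minimal sketch of that interpretation, using hypothetical standalone names (build_mask, test_size) in place of the class attributes:

import numpy as np

def build_mask(test_size: int, mask_ratio: float, n_modalities: int = 3) -> dict:
    """Hypothetical helper mirroring the masking logic in preprocess_retrieval_data."""
    mask = {}  # modality index -> indices of masked test instances
    if mask_ratio > 0:
        # mask test_size / mask_ratio instances per modality, drawn without replacement
        mask_num = int(test_size / mask_ratio)
        for m in range(n_modalities):
            mask[m] = np.random.choice(test_size, mask_num, replace=False)
    else:
        # mask_ratio == 0: simulate no missing data
        for m in range(n_modalities):
            mask[m] = []
    return mask

print(len(build_mask(800, 4)[0]))  # 200 masked instances per modality
print(len(build_mask(800, 0)[0]))  # 0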

mmda/baselines/emma_kitti_class.py renamed to mmda/baselines/emma/emma_kitti_class.py  (+29 -16)

@@ -21,7 +21,6 @@ def __init__(self, cfg: DictConfig) -> None:
         Args:
             cfg: configuration file
         """
-        super().__init__()
         np.random.seed(0)
         self.cfg = cfg

@@ -38,6 +37,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self.shuffle_step = cfg["KITTI"].shuffle_step
         self.save_tag = f"_thres_{Args.threshold_dist}_shuffle_{self.shuffle_step}"
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model_path = Path(self.cfg["KITTI"].paths.save_path) / "models"

     def preprocess_retrieval_data(self) -> None:
         """Preprocess the data for retrieval."""
@@ -89,11 +89,17 @@ def preprocess_retrieval_data(self) -> None:
         }

         # masking missing data in the test set. Mask the whole modality of an instance at a time.
-        mask_num = int(self.test_size / self.cfg_dataset.mask_ratio)
-        self.mask = {}  # modality -> masked idx
-        self.mask[0] = np.random.choice(self.test_size, mask_num, replace=False)
-        self.mask[1] = np.random.choice(self.test_size, mask_num, replace=False)
-        self.mask[2] = np.random.choice(self.test_size, mask_num, replace=False)
+        if self.cfg_dataset.mask_ratio > 0:
+            mask_num = int(self.test_size / self.cfg_dataset.mask_ratio)
+            self.mask = {}  # modality -> masked idx
+            self.mask[0] = np.random.choice(self.test_size, mask_num, replace=False)
+            self.mask[1] = np.random.choice(self.test_size, mask_num, replace=False)
+            self.mask[2] = np.random.choice(self.test_size, mask_num, replace=False)
+        else:
+            self.mask = {}  # modality -> masked idx
+            self.mask[0] = []
+            self.mask[1] = []
+            self.mask[2] = []

     def train_crossmodal_similarity(  # noqa: C901, PLR0912
         self, max_epoch: int
@@ -111,8 +117,7 @@ def train_crossmodal_similarity(  # noqa: C901, PLR0912
             lr=0.001,
         )

-        model_path = Path(self.cfg["KITTI"].paths.save_path) / "models"
-        model_path.mkdir(parents=True, exist_ok=True)
+        self.model_path.mkdir(parents=True, exist_ok=True)
         ds_retrieval_cls = KITTI_file_Retrieval()

         for epoch in range(max_epoch):
@@ -180,32 +185,40 @@ def train_crossmodal_similarity(  # noqa: C901, PLR0912
             if (epoch + 1) % 5 == 0:  # Save models per 5 epochs
                 torch.save(
                     self.img_fc.state_dict(),
-                    model_path + f"img_fc_epoch_{epoch+1}.pth",
+                    str(self.model_path / f"img_fc_epoch_{epoch+1}.pth"),
                 )
                 torch.save(
                     self.lidar_fc.state_dict(),
-                    model_path + f"lidar_fc_epoch_{epoch+1}.pth",
+                    str(self.model_path / f"lidar_fc_epoch_{epoch+1}.pth"),
                 )
                 torch.save(
                     self.txt_fc.state_dict(),
-                    model_path + f"txt_fc_epoch_{epoch+1}.pth",
+                    str(self.model_path / f"txt_fc_epoch_{epoch+1}.pth"),
                 )
                 print(f"Models saved at epoch {epoch+1}")

     def load_fc_models(self, epoch: int) -> None:
         """Load the fc models."""
-        model_path = self.cfg["KITTI"].paths.save_path + "models/"
         self.define_fc_networks(output_dim=256)
         self.img_fc.load_state_dict(
-            torch.load(model_path + f"img_fc_epoch_{epoch}.pth", weights_only=True)
+            torch.load(
+                str(self.model_path / f"img_fc_epoch_{epoch}.pth"),
+                weights_only=True,
+            )
         )
         self.img_fc.to(self.device)
         self.lidar_fc.load_state_dict(
-            torch.load(model_path + f"lidar_fc_epoch_{epoch}.pth", weights_only=True)
+            torch.load(
+                str(self.model_path / f"lidar_fc_epoch_{epoch}.pth"),
+                weights_only=True,
+            )
         )
         self.lidar_fc.to(self.device)
         self.txt_fc.load_state_dict(
-            torch.load(model_path + f"txt_fc_epoch_{epoch}.pth", weights_only=True)
+            torch.load(
+                str(self.model_path / f"txt_fc_epoch_{epoch}.pth"),
+                weights_only=True,
+            )
         )
         self.txt_fc.to(self.device)

@@ -472,7 +485,7 @@ def retrieve_data(


 if __name__ == "__main__":
-    # CUDA_VISIBLE_DEVICES=2 poetry run python mmda/baselines/emma_ds_class.py
+    # CUDA_VISIBLE_DEVICES=4 poetry run python mmda/baselines/emma/emma_kitti_class.py
     import pandas as pd
     from omegaconf import OmegaConf
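The other substantive change above replaces string concatenation on the checkpoint directory with a pathlib.Path stored once as self.model_path, so train_crossmodal_similarity and load_fc_models resolve the same location; note that the old save calls added an f-string to a Path with "+", which would raise a TypeError. A standalone sketch of the new pattern, using a placeholder path and a dummy torch.nn.Linear head instead of the real fc networks:

from pathlib import Path

import torch

save_path = "/tmp/kitti_embeddings/"       # placeholder for cfg["KITTI"].paths.save_path
model_path = Path(save_path) / "models"    # same construction as self.model_path
model_path.mkdir(parents=True, exist_ok=True)

img_fc = torch.nn.Linear(512, 256)         # dummy stand-in for the real fc head
epoch = 5

# save: join with "/" on the Path, then hand torch.save a string path
torch.save(img_fc.state_dict(), str(model_path / f"img_fc_epoch_{epoch}.pth"))

# load: weights_only=True restricts unpickling to plain tensors/state dicts
state = torch.load(str(model_path / f"img_fc_epoch_{epoch}.pth"), weights_only=True)
img_fc.load_state_dict(state)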
