@@ -21,7 +21,6 @@ def __init__(self, cfg: DictConfig) -> None:
21
21
Args:
22
22
cfg: configuration file
23
23
"""
24
- super ().__init__ ()
25
24
np .random .seed (0 )
26
25
self .cfg = cfg
27
26
@@ -38,6 +37,7 @@ def __init__(self, cfg: DictConfig) -> None:
38
37
self .shuffle_step = cfg ["KITTI" ].shuffle_step
39
38
self .save_tag = f"_thres_{ Args .threshold_dist } _shuffle_{ self .shuffle_step } "
40
39
self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
40
+ self .model_path = Path (self .cfg ["KITTI" ].paths .save_path ) / "models"
41
41
42
42
def preprocess_retrieval_data (self ) -> None :
43
43
"""Preprocess the data for retrieval."""
@@ -89,11 +89,17 @@ def preprocess_retrieval_data(self) -> None:
89
89
}
90
90
91
91
# masking missing data in the test set. Mask the whole modality of an instance at a time.
92
- mask_num = int (self .test_size / self .cfg_dataset .mask_ratio )
93
- self .mask = {} # modality -> masked idx
94
- self .mask [0 ] = np .random .choice (self .test_size , mask_num , replace = False )
95
- self .mask [1 ] = np .random .choice (self .test_size , mask_num , replace = False )
96
- self .mask [2 ] = np .random .choice (self .test_size , mask_num , replace = False )
92
+ if self .cfg_dataset .mask_ratio > 0 :
93
+ mask_num = int (self .test_size / self .cfg_dataset .mask_ratio )
94
+ self .mask = {} # modality -> masked idx
95
+ self .mask [0 ] = np .random .choice (self .test_size , mask_num , replace = False )
96
+ self .mask [1 ] = np .random .choice (self .test_size , mask_num , replace = False )
97
+ self .mask [2 ] = np .random .choice (self .test_size , mask_num , replace = False )
98
+ else :
99
+ self .mask = {} # modality -> masked idx
100
+ self .mask [0 ] = []
101
+ self .mask [1 ] = []
102
+ self .mask [2 ] = []
97
103
98
104
def train_crossmodal_similarity ( # noqa: C901, PLR0912
99
105
self , max_epoch : int
@@ -111,8 +117,7 @@ def train_crossmodal_similarity( # noqa: C901, PLR0912
111
117
lr = 0.001 ,
112
118
)
113
119
114
- model_path = Path (self .cfg ["KITTI" ].paths .save_path ) / "models"
115
- model_path .mkdir (parents = True , exist_ok = True )
120
+ self .model_path .mkdir (parents = True , exist_ok = True )
116
121
ds_retrieval_cls = KITTI_file_Retrieval ()
117
122
118
123
for epoch in range (max_epoch ):
@@ -180,32 +185,40 @@ def train_crossmodal_similarity( # noqa: C901, PLR0912
180
185
if (epoch + 1 ) % 5 == 0 : # Save models per 5 epochs
181
186
torch .save (
182
187
self .img_fc .state_dict (),
183
- model_path + f"img_fc_epoch_{ epoch + 1 } .pth" ,
188
+ str ( self . model_path / f"img_fc_epoch_{ epoch + 1 } .pth" ) ,
184
189
)
185
190
torch .save (
186
191
self .lidar_fc .state_dict (),
187
- model_path + f"lidar_fc_epoch_{ epoch + 1 } .pth" ,
192
+ str ( self . model_path / f"lidar_fc_epoch_{ epoch + 1 } .pth" ) ,
188
193
)
189
194
torch .save (
190
195
self .txt_fc .state_dict (),
191
- model_path + f"txt_fc_epoch_{ epoch + 1 } .pth" ,
196
+ str ( self . model_path / f"txt_fc_epoch_{ epoch + 1 } .pth" ) ,
192
197
)
193
198
print (f"Models saved at epoch { epoch + 1 } " )
194
199
195
200
def load_fc_models (self , epoch : int ) -> None :
196
201
"""Load the fc models."""
197
- model_path = self .cfg ["KITTI" ].paths .save_path + "models/"
198
202
self .define_fc_networks (output_dim = 256 )
199
203
self .img_fc .load_state_dict (
200
- torch .load (model_path + f"img_fc_epoch_{ epoch } .pth" , weights_only = True )
204
+ torch .load (
205
+ str (self .model_path / f"img_fc_epoch_{ epoch } .pth" ),
206
+ weights_only = True ,
207
+ )
201
208
)
202
209
self .img_fc .to (self .device )
203
210
self .lidar_fc .load_state_dict (
204
- torch .load (model_path + f"lidar_fc_epoch_{ epoch } .pth" , weights_only = True )
211
+ torch .load (
212
+ str (self .model_path / f"lidar_fc_epoch_{ epoch } .pth" ),
213
+ weights_only = True ,
214
+ )
205
215
)
206
216
self .lidar_fc .to (self .device )
207
217
self .txt_fc .load_state_dict (
208
- torch .load (model_path + f"txt_fc_epoch_{ epoch } .pth" , weights_only = True )
218
+ torch .load (
219
+ str (self .model_path / f"txt_fc_epoch_{ epoch } .pth" ),
220
+ weights_only = True ,
221
+ )
209
222
)
210
223
self .txt_fc .to (self .device )
211
224
@@ -472,7 +485,7 @@ def retrieve_data(
472
485
473
486
474
487
if __name__ == "__main__" :
475
- # CUDA_VISIBLE_DEVICES=2 poetry run python mmda/baselines/emma_ds_class .py
488
+ # CUDA_VISIBLE_DEVICES=4 poetry run python mmda/baselines/emma/emma_kitti_class .py
476
489
import pandas as pd
477
490
from omegaconf import OmegaConf
478
491
0 commit comments