1
1
"""Dataset class for any2any retrieval task."""
2
2
3
- import os
3
+ from pathlib import Path
4
4
5
5
import numpy as np
6
6
import torch
@@ -37,6 +37,7 @@ def __init__(self, cfg: DictConfig) -> None:
37
37
self .shape = (3 , 3 ) # shape of the similarity matrix
38
38
self .shuffle_step = cfg ["KITTI" ].shuffle_step
39
39
self .save_tag = f"_thres_{ Args .threshold_dist } _shuffle_{ self .shuffle_step } "
40
+ self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
40
41
41
42
def preprocess_retrieval_data (self ) -> None :
42
43
"""Preprocess the data for retrieval."""
@@ -94,34 +95,34 @@ def preprocess_retrieval_data(self) -> None:
94
95
self .mask [1 ] = np .random .choice (self .test_size , mask_num , replace = False )
95
96
self .mask [2 ] = np .random .choice (self .test_size , mask_num , replace = False )
96
97
97
- def train_crossmodal_similarity (self , max_epoch : int ) -> None : # noqa: C901
98
+ def train_crossmodal_similarity ( # noqa: C901, PLR0912
99
+ self , max_epoch : int
100
+ ) -> None :
98
101
"""Train the cross-modal similarity, aka the CSA method."""
99
- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
100
-
101
102
data_loader = self .get_joint_dataloader (batch_size = 256 , num_workers = 4 )
102
103
self .define_fc_networks (output_dim = 256 )
103
- self .img_fc .to (device )
104
- self .lidar_fc .to (device )
105
- self .txt_fc .to (device )
104
+ self .img_fc .to (self . device )
105
+ self .lidar_fc .to (self . device )
106
+ self .txt_fc .to (self . device )
106
107
self .optimizer = torch .optim .Adam (
107
108
list (self .img_fc .parameters ())
108
109
+ list (self .lidar_fc .parameters ())
109
110
+ list (self .txt_fc .parameters ()),
110
111
lr = 0.001 ,
111
112
)
112
113
113
- model_path = self .cfg ["KITTI" ].paths .save_path + "models/ "
114
- os . makedirs ( model_path , exist_ok = True )
114
+ model_path = Path ( self .cfg ["KITTI" ].paths .save_path ) / "models"
115
+ model_path . mkdir ( parents = True , exist_ok = True )
115
116
ds_retrieval_cls = KITTI_file_Retrieval ()
116
117
117
118
for epoch in range (max_epoch ):
118
119
for _ , (img , lidar , txt , orig_idx ) in enumerate (data_loader ):
119
120
bs = img .shape [0 ]
120
- img_embed = self .img_fc (img .to (device ))
121
- lidar_embed = self .lidar_fc (lidar .to (device ))
122
- txt_embed = self .txt_fc (txt .to (device ))
121
+ img_embed = self .img_fc (img .to (self . device ))
122
+ lidar_embed = self .lidar_fc (lidar .to (self . device ))
123
+ txt_embed = self .txt_fc (txt .to (self . device ))
123
124
three_embed = torch .stack ([img_embed , lidar_embed , txt_embed ], dim = 0 )
124
- loss = torch .tensor (0.0 , device = device , requires_grad = True )
125
+ loss = torch .tensor (0.0 , device = self . device , requires_grad = True )
125
126
126
127
# get gt labels once
127
128
gt_labels = {}
@@ -194,14 +195,36 @@ def train_crossmodal_similarity(self, max_epoch: int) -> None: # noqa: C901
194
195
def load_fc_models(self, epoch: int) -> None:
    """Load the trained fc networks saved for the given epoch.

    Fresh networks are created with ``define_fc_networks(output_dim=256)``
    and their weights are restored from
    ``<save_path>/models/{img,lidar,txt}_fc_epoch_<epoch>.pth``; each network
    is then moved to ``self.device``.

    Args:
        epoch: epoch number embedded in the checkpoint file names.
    """
    # Build the directory the same way the training code does, so loading
    # works whether or not ``save_path`` ends with a path separator.
    model_dir = Path(self.cfg["KITTI"].paths.save_path) / "models"
    self.define_fc_networks(output_dim=256)
    for name in ("img_fc", "lidar_fc", "txt_fc"):
        network = getattr(self, name)
        state_dict = torch.load(
            model_dir / f"{name}_epoch_{epoch}.pth",
            # Remap storages so checkpoints written on a GPU machine can be
            # restored on a CPU-only host (and vice versa).
            map_location=self.device,
            # Only tensors/primitive containers are unpickled, not arbitrary
            # Python objects.
            weights_only=True,
        )
        network.load_state_dict(state_dict)
        network.to(self.device)
200
211
201
212
def transform_with_fc (
202
- self , img : torch .Tensor , lidar : torch .Tensor , txt : torch .Tensor
213
+ self ,
214
+ img : torch .Tensor | np .ndarray ,
215
+ lidar : torch .Tensor | np .ndarray ,
216
+ txt : torch .Tensor | np .ndarray ,
203
217
) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
204
218
"""Transform the data with the fc networks."""
219
+ if isinstance (img , np .ndarray ):
220
+ img = torch .tensor (img )
221
+ if isinstance (lidar , np .ndarray ):
222
+ lidar = torch .tensor (lidar )
223
+ if isinstance (txt , np .ndarray ):
224
+ txt = torch .tensor (txt )
225
+ img = img .to (self .device )
226
+ lidar = lidar .to (self .device )
227
+ txt = txt .to (self .device )
205
228
self .img_fc .eval ()
206
229
self .lidar_fc .eval ()
207
230
self .txt_fc .eval ()
@@ -327,6 +350,8 @@ def eval_similarity(
327
350
q_feats [q_modality ].reshape (1 , - 1 ),
328
351
r_feats [r_modality ].reshape (1 , - 1 ),
329
352
)
353
+ if cnt == 0 :
354
+ return - 1
330
355
return sim_score / cnt
331
356
332
357
def retrieve_data (
@@ -348,10 +373,12 @@ def retrieve_data(
348
373
maps = {5 : [], 20 : []}
349
374
ds_retrieval_cls = KITTI_file_Retrieval ()
350
375
351
- for idx_q in tqdm (
352
- self .test_idx ,
353
- desc = "Retrieving data" ,
354
- leave = True ,
376
+ for ii , idx_q in enumerate (
377
+ tqdm (
378
+ self .test_idx ,
379
+ desc = "Retrieving data" ,
380
+ leave = True ,
381
+ )
355
382
):
356
383
ds_idx_q = self .shuffle2idx [idx_q ]
357
384
retrieved_pairs = []
@@ -361,29 +388,29 @@ def retrieve_data(
361
388
for modality in range (3 ):
362
389
if ds_idx_q in self .mask [modality ]:
363
390
q_missing_modalities .append (modality )
364
- q_feats = np .concatenate (
391
+ q_feats = np .stack (
365
392
[
366
- self .imgdata ["test" ][ds_idx_q ] ,
367
- self .lidardata ["test" ][ds_idx_q ] ,
368
- self .txtdata ["test" ][ds_idx_q ] ,
393
+ self .imgdata ["test" ][ii ]. reshape ( 1 , - 1 ) ,
394
+ self .lidardata ["test" ][ii ]. reshape ( 1 , - 1 ) ,
395
+ self .txtdata ["test" ][ii ]. reshape ( 1 , - 1 ) ,
369
396
],
370
397
axis = 0 ,
371
398
)
372
399
assert q_feats .shape [0 :2 ] == (3 , 1 ), f"{ q_feats .shape } "
373
400
374
- for idx_r in self .test_idx :
401
+ for jj , idx_r in enumerate ( self .test_idx ) :
375
402
if idx_r == idx_q : # cannot retrieve itself
376
403
continue
377
404
ds_idx_r = self .shuffle2idx [idx_r ]
378
405
r_missing_modalities = []
379
406
for modality in range (3 ):
380
407
if ds_idx_r in self .mask [modality ]:
381
408
r_missing_modalities .append (modality )
382
- r_feats = np .concatenate (
409
+ r_feats = np .stack (
383
410
[
384
- self .imgdata ["test" ][ds_idx_r ] ,
385
- self .lidardata ["test" ][ds_idx_r ] ,
386
- self .txtdata ["test" ][ds_idx_r ] ,
411
+ self .imgdata ["test" ][jj ]. reshape ( 1 , - 1 ) ,
412
+ self .lidardata ["test" ][jj ]. reshape ( 1 , - 1 ) ,
413
+ self .txtdata ["test" ][jj ]. reshape ( 1 , - 1 ) ,
387
414
],
388
415
axis = 0 ,
389
416
)
@@ -433,18 +460,20 @@ def retrieve_data(
433
460
434
461
435
462
if __name__ == "__main__":
    # Usage: CUDA_VISIBLE_DEVICES=2 poetry run python mmda/baselines/emma_ds_class.py
    from omegaconf import OmegaConf

    # Set to True to (re)train the fc networks before evaluating; when False
    # the checkpoints for MAX_EPOCH are expected to already be on disk.
    RUN_TRAINING = False
    # Single source of truth for the epoch used by both training and loading.
    MAX_EPOCH = 100

    cfg = OmegaConf.load("config/main.yaml")
    ds = KITTIEMMADataset(cfg)
    ds.preprocess_retrieval_data()
    if RUN_TRAINING:
        ds.train_crossmodal_similarity(max_epoch=MAX_EPOCH)
    ds.load_fc_models(epoch=MAX_EPOCH)

    transformed = ds.transform_with_fc(
        ds.imgdata["test"], ds.lidardata["test"], ds.txtdata["test"]
    )
    # retrieve_data() stacks per-sample features with np.stack, which cannot
    # consume CUDA or grad-tracking tensors, so store plain NumPy arrays.
    # NOTE(review): transform_with_fc is annotated as returning torch.Tensor;
    # the isinstance guard keeps this a no-op if it already returns ndarrays.
    for store, feats in zip((ds.imgdata, ds.lidardata, ds.txtdata), transformed):
        if isinstance(feats, torch.Tensor):
            feats = feats.detach().cpu().numpy()
        store["test"] = feats

    maps, precisions, recalls = ds.retrieve_data()
    print(maps, precisions, recalls)
0 commit comments