Skip to content

Commit c920678

Browse files
committed
kitti retrieval running
1 parent 64ea30e commit c920678

File tree

2 files changed

+65
-36
lines changed

2 files changed

+65
-36
lines changed

mmda/utils/emma_ds_class.py renamed to mmda/baselines/emma_kitti_class.py

+63-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Dataset class for any2any retrieval task."""
22

3-
import os
3+
from pathlib import Path
44

55
import numpy as np
66
import torch
@@ -37,6 +37,7 @@ def __init__(self, cfg: DictConfig) -> None:
3737
self.shape = (3, 3) # shape of the similarity matrix
3838
self.shuffle_step = cfg["KITTI"].shuffle_step
3939
self.save_tag = f"_thres_{Args.threshold_dist}_shuffle_{self.shuffle_step}"
40+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4041

4142
def preprocess_retrieval_data(self) -> None:
4243
"""Preprocess the data for retrieval."""
@@ -94,34 +95,34 @@ def preprocess_retrieval_data(self) -> None:
9495
self.mask[1] = np.random.choice(self.test_size, mask_num, replace=False)
9596
self.mask[2] = np.random.choice(self.test_size, mask_num, replace=False)
9697

97-
def train_crossmodal_similarity(self, max_epoch: int) -> None: # noqa: C901
98+
def train_crossmodal_similarity( # noqa: C901, PLR0912
99+
self, max_epoch: int
100+
) -> None:
98101
"""Train the cross-modal similarity, aka the CSA method."""
99-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100-
101102
data_loader = self.get_joint_dataloader(batch_size=256, num_workers=4)
102103
self.define_fc_networks(output_dim=256)
103-
self.img_fc.to(device)
104-
self.lidar_fc.to(device)
105-
self.txt_fc.to(device)
104+
self.img_fc.to(self.device)
105+
self.lidar_fc.to(self.device)
106+
self.txt_fc.to(self.device)
106107
self.optimizer = torch.optim.Adam(
107108
list(self.img_fc.parameters())
108109
+ list(self.lidar_fc.parameters())
109110
+ list(self.txt_fc.parameters()),
110111
lr=0.001,
111112
)
112113

113-
model_path = self.cfg["KITTI"].paths.save_path + "models/"
114-
os.makedirs(model_path, exist_ok=True)
114+
model_path = Path(self.cfg["KITTI"].paths.save_path) / "models"
115+
model_path.mkdir(parents=True, exist_ok=True)
115116
ds_retrieval_cls = KITTI_file_Retrieval()
116117

117118
for epoch in range(max_epoch):
118119
for _, (img, lidar, txt, orig_idx) in enumerate(data_loader):
119120
bs = img.shape[0]
120-
img_embed = self.img_fc(img.to(device))
121-
lidar_embed = self.lidar_fc(lidar.to(device))
122-
txt_embed = self.txt_fc(txt.to(device))
121+
img_embed = self.img_fc(img.to(self.device))
122+
lidar_embed = self.lidar_fc(lidar.to(self.device))
123+
txt_embed = self.txt_fc(txt.to(self.device))
123124
three_embed = torch.stack([img_embed, lidar_embed, txt_embed], dim=0)
124-
loss = torch.tensor(0.0, device=device, requires_grad=True)
125+
loss = torch.tensor(0.0, device=self.device, requires_grad=True)
125126

126127
# get gt labels once
127128
gt_labels = {}
@@ -194,14 +195,36 @@ def train_crossmodal_similarity(self, max_epoch: int) -> None: # noqa: C901
194195
def load_fc_models(self, epoch: int) -> None:
195196
"""Load the fc models."""
196197
model_path = self.cfg["KITTI"].paths.save_path + "models/"
197-
self.img_fc = torch.load(model_path + f"img_fc_epoch_{epoch}.pth")
198-
self.lidar_fc = torch.load(model_path + f"lidar_fc_epoch_{epoch}.pth")
199-
self.txt_fc = torch.load(model_path + f"txt_fc_epoch_{epoch}.pth")
198+
self.define_fc_networks(output_dim=256)
199+
self.img_fc.load_state_dict(
200+
torch.load(model_path + f"img_fc_epoch_{epoch}.pth", weights_only=True)
201+
)
202+
self.img_fc.to(self.device)
203+
self.lidar_fc.load_state_dict(
204+
torch.load(model_path + f"lidar_fc_epoch_{epoch}.pth", weights_only=True)
205+
)
206+
self.lidar_fc.to(self.device)
207+
self.txt_fc.load_state_dict(
208+
torch.load(model_path + f"txt_fc_epoch_{epoch}.pth", weights_only=True)
209+
)
210+
self.txt_fc.to(self.device)
200211

201212
def transform_with_fc(
202-
self, img: torch.Tensor, lidar: torch.Tensor, txt: torch.Tensor
213+
self,
214+
img: torch.Tensor | np.ndarray,
215+
lidar: torch.Tensor | np.ndarray,
216+
txt: torch.Tensor | np.ndarray,
203217
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
204218
"""Transform the data with the fc networks."""
219+
if isinstance(img, np.ndarray):
220+
img = torch.tensor(img)
221+
if isinstance(lidar, np.ndarray):
222+
lidar = torch.tensor(lidar)
223+
if isinstance(txt, np.ndarray):
224+
txt = torch.tensor(txt)
225+
img = img.to(self.device)
226+
lidar = lidar.to(self.device)
227+
txt = txt.to(self.device)
205228
self.img_fc.eval()
206229
self.lidar_fc.eval()
207230
self.txt_fc.eval()
@@ -327,6 +350,8 @@ def eval_similarity(
327350
q_feats[q_modality].reshape(1, -1),
328351
r_feats[r_modality].reshape(1, -1),
329352
)
353+
if cnt == 0:
354+
return -1
330355
return sim_score / cnt
331356

332357
def retrieve_data(
@@ -348,10 +373,12 @@ def retrieve_data(
348373
maps = {5: [], 20: []}
349374
ds_retrieval_cls = KITTI_file_Retrieval()
350375

351-
for idx_q in tqdm(
352-
self.test_idx,
353-
desc="Retrieving data",
354-
leave=True,
376+
for ii, idx_q in enumerate(
377+
tqdm(
378+
self.test_idx,
379+
desc="Retrieving data",
380+
leave=True,
381+
)
355382
):
356383
ds_idx_q = self.shuffle2idx[idx_q]
357384
retrieved_pairs = []
@@ -361,29 +388,29 @@ def retrieve_data(
361388
for modality in range(3):
362389
if ds_idx_q in self.mask[modality]:
363390
q_missing_modalities.append(modality)
364-
q_feats = np.concatenate(
391+
q_feats = np.stack(
365392
[
366-
self.imgdata["test"][ds_idx_q],
367-
self.lidardata["test"][ds_idx_q],
368-
self.txtdata["test"][ds_idx_q],
393+
self.imgdata["test"][ii].reshape(1, -1),
394+
self.lidardata["test"][ii].reshape(1, -1),
395+
self.txtdata["test"][ii].reshape(1, -1),
369396
],
370397
axis=0,
371398
)
372399
assert q_feats.shape[0:2] == (3, 1), f"{q_feats.shape}"
373400

374-
for idx_r in self.test_idx:
401+
for jj, idx_r in enumerate(self.test_idx):
375402
if idx_r == idx_q: # cannot retrieve itself
376403
continue
377404
ds_idx_r = self.shuffle2idx[idx_r]
378405
r_missing_modalities = []
379406
for modality in range(3):
380407
if ds_idx_r in self.mask[modality]:
381408
r_missing_modalities.append(modality)
382-
r_feats = np.concatenate(
409+
r_feats = np.stack(
383410
[
384-
self.imgdata["test"][ds_idx_r],
385-
self.lidardata["test"][ds_idx_r],
386-
self.txtdata["test"][ds_idx_r],
411+
self.imgdata["test"][jj].reshape(1, -1),
412+
self.lidardata["test"][jj].reshape(1, -1),
413+
self.txtdata["test"][jj].reshape(1, -1),
387414
],
388415
axis=0,
389416
)
@@ -433,18 +460,20 @@ def retrieve_data(
433460

434461

435462
if __name__ == "__main__":
436-
# CUDA_VISIBLE_DEVICES=2 poetry run python mmda/utils/emma_ds_class.py
463+
# CUDA_VISIBLE_DEVICES=2 poetry run python mmda/baselines/emma_kitti_class.py
437464
from omegaconf import OmegaConf
438465

439466
cfg = OmegaConf.load("config/main.yaml")
440467
ds = KITTIEMMADataset(cfg)
441468
ds.preprocess_retrieval_data()
442-
ds.train_crossmodal_similarity(max_epoch=100)
443-
exit()
469+
if False:
470+
ds.train_crossmodal_similarity(max_epoch=100)
444471
ds.load_fc_models(epoch=100)
445472
img_transformed, lidar_transformed, txt_transformed = ds.transform_with_fc(
446473
ds.imgdata["test"], ds.lidardata["test"], ds.txtdata["test"]
447474
)
448-
print(img_transformed.shape, lidar_transformed.shape, txt_transformed.shape)
475+
ds.imgdata["test"] = img_transformed
476+
ds.lidardata["test"] = lidar_transformed
477+
ds.txtdata["test"] = txt_transformed
449478
maps, precisions, recalls = ds.retrieve_data()
450479
print(maps, precisions, recalls)

mmda/utils/sim_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def cosine_sim(x: np.ndarray, y: np.ndarray) -> np.ndarray:
1616
assert (
1717
x.shape == y.shape
1818
), f"x and y should have the same number of shape, but got {x.shape} and {y.shape}"
19-
x = x / np.linalg.norm(x, axis=1, keepdims=True)
20-
y = y / np.linalg.norm(y, axis=1, keepdims=True)
19+
x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-10)
20+
y = y / (np.linalg.norm(y, axis=1, keepdims=True) + 1e-10)
2121
return np.sum(x * y, axis=1)
2222

2323

0 commit comments

Comments
 (0)