Skip to content

Commit ab5fa8d

Browse files
committed
Optimize plot
1 parent cb8119c commit ab5fa8d

File tree

4 files changed

+16
-17
lines changed

4 files changed

+16
-17
lines changed

config/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ KITTI:
4545
img_encoder: "liploc"
4646
lidar_encoder: "liploc"
4747
text_encoder: "gtr"
48-
shuffle_step: 0
48+
shuffle_step: 20
4949
mask_ratio: 2 # ratio of the missing data : size of test data
5050
paths:
5151
dataset_path: "/nas/pohan/datasets/KITTI/"

mmda/any2any_conformal_retrieval.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def main(cfg: DictConfig) -> None:
6363
df.to_csv(df_path, index=False)
6464

6565
if cfg.dataset == "KITTI":
66-
# plot heatmap of single modality retrieval
6766
single_recalls = np.array(list(single1_recalls.values())).reshape(3, 3) * 100
6867
plt.figure(figsize=(8, 8))
6968
ax = sns.heatmap(
@@ -78,7 +77,7 @@ def main(cfg: DictConfig) -> None:
7877
annot_kws={"size": 26, "weight": "bold"},
7978
)
8079
elif cfg.dataset == "MSRVTT":
81-
# plot heatmap of single modality retrieval
80+
dir_path = dir_path / f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}"
8281
single_recalls = np.array(list(single1_recalls.values())).reshape(1, 2) * 100
8382
plt.figure(figsize=(8, 6))
8483
ax = sns.heatmap(
@@ -90,20 +89,20 @@ def main(cfg: DictConfig) -> None:
9089
xticklabels=["Image", "Audio"],
9190
yticklabels=["Text"],
9291
annot=True,
93-
annot_kws={"size": 26, "weight": "bold"},
92+
annot_kws={"size": 30, "weight": "bold"},
9493
)
9594
else:
9695
msg = f"unknown dataset {cfg.dataset}"
9796
raise ValueError(msg)
9897
ax.xaxis.tick_top()
99-
plt.xlabel("Reference modality", fontsize=30)
100-
plt.ylabel("Query modality", fontsize=30)
101-
plt.xticks(fontsize=26)
102-
plt.yticks(fontsize=26)
98+
plt.xlabel("Reference modality", fontsize=34)
99+
plt.ylabel("Query modality", fontsize=34)
100+
plt.xticks(fontsize=30)
101+
plt.yticks(fontsize=30)
103102
ax.xaxis.set_label_position("top") # Move the label to the top
104103
plt.savefig(
105104
dir_path
106-
/ f"single_modal_recall5_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}{thres_tag}.png"
105+
/ f"single_modal_recall5_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}{thres_tag}.pdf"
107106
)
108107

109108

mmda/plot_nonconformity.ipynb

Lines changed: 7 additions & 7 deletions
Large diffs are not rendered by default.

mmda/utils/liploc_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class CFG:
7777
class Args:
7878
expid: str = "exp_default"
7979
eval_sequence = ["04", "05", "06", "07", "08", "09", "10"]
80-
threshold_dist: int = 20
80+
threshold_dist: int = 5
8181

8282

8383
model_import_path = f"mmda.liploc.models.{CFG.model}"

0 commit comments

Comments
 (0)