@@ -63,7 +63,6 @@ def main(cfg: DictConfig) -> None:
63
63
df .to_csv (df_path , index = False )
64
64
65
65
if cfg .dataset == "KITTI" :
66
- # plot heatmap of single modality retrieval
67
66
single_recalls = np .array (list (single1_recalls .values ())).reshape (3 , 3 ) * 100
68
67
plt .figure (figsize = (8 , 8 ))
69
68
ax = sns .heatmap (
@@ -78,7 +77,7 @@ def main(cfg: DictConfig) -> None:
78
77
annot_kws = {"size" : 26 , "weight" : "bold" },
79
78
)
80
79
elif cfg .dataset == "MSRVTT" :
81
- # plot heatmap of single modality retrieval
80
+ dir_path = dir_path / f" { cfg_dataset . img_encoder } _ { cfg_dataset . audio_encoder } "
82
81
single_recalls = np .array (list (single1_recalls .values ())).reshape (1 , 2 ) * 100
83
82
plt .figure (figsize = (8 , 6 ))
84
83
ax = sns .heatmap (
@@ -90,20 +89,20 @@ def main(cfg: DictConfig) -> None:
90
89
xticklabels = ["Image" , "Audio" ],
91
90
yticklabels = ["Text" ],
92
91
annot = True ,
93
- annot_kws = {"size" : 26 , "weight" : "bold" },
92
+ annot_kws = {"size" : 30 , "weight" : "bold" },
94
93
)
95
94
else :
96
95
msg = f"unknown dataset { cfg .dataset } "
97
96
raise ValueError (msg )
98
97
ax .xaxis .tick_top ()
99
- plt .xlabel ("Reference modality" , fontsize = 30 )
100
- plt .ylabel ("Query modality" , fontsize = 30 )
101
- plt .xticks (fontsize = 26 )
102
- plt .yticks (fontsize = 26 )
98
+ plt .xlabel ("Reference modality" , fontsize = 34 )
99
+ plt .ylabel ("Query modality" , fontsize = 34 )
100
+ plt .xticks (fontsize = 30 )
101
+ plt .yticks (fontsize = 30 )
103
102
ax .xaxis .set_label_position ("top" ) # Move the label to the top
104
103
plt .savefig (
105
104
dir_path
106
- / f"single_modal_recall5_{ cfg_dataset .retrieval_dim } _{ cfg_dataset .mask_ratio } { thres_tag } .png "
105
+ / f"single_modal_recall5_{ cfg_dataset .retrieval_dim } _{ cfg_dataset .mask_ratio } { thres_tag } .pdf "
107
106
)
108
107
109
108
0 commit comments