@@ -63,7 +63,6 @@ def main(cfg: DictConfig) -> None:
6363 df .to_csv (df_path , index = False )
6464
6565 if cfg .dataset == "KITTI" :
66- # plot heatmap of single modality retrieval
6766 single_recalls = np .array (list (single1_recalls .values ())).reshape (3 , 3 ) * 100
6867 plt .figure (figsize = (8 , 8 ))
6968 ax = sns .heatmap (
@@ -78,7 +77,7 @@ def main(cfg: DictConfig) -> None:
7877 annot_kws = {"size" : 26 , "weight" : "bold" },
7978 )
8079 elif cfg .dataset == "MSRVTT" :
81- # plot heatmap of single modality retrieval
80+ dir_path = dir_path / f" { cfg_dataset . img_encoder } _ { cfg_dataset . audio_encoder } "
8281 single_recalls = np .array (list (single1_recalls .values ())).reshape (1 , 2 ) * 100
8382 plt .figure (figsize = (8 , 6 ))
8483 ax = sns .heatmap (
@@ -90,20 +89,20 @@ def main(cfg: DictConfig) -> None:
9089 xticklabels = ["Image" , "Audio" ],
9190 yticklabels = ["Text" ],
9291 annot = True ,
93- annot_kws = {"size" : 26 , "weight" : "bold" },
92+ annot_kws = {"size" : 30 , "weight" : "bold" },
9493 )
9594 else :
9695 msg = f"unknown dataset { cfg .dataset } "
9796 raise ValueError (msg )
9897 ax .xaxis .tick_top ()
99- plt .xlabel ("Reference modality" , fontsize = 30 )
100- plt .ylabel ("Query modality" , fontsize = 30 )
101- plt .xticks (fontsize = 26 )
102- plt .yticks (fontsize = 26 )
98+ plt .xlabel ("Reference modality" , fontsize = 34 )
99+ plt .ylabel ("Query modality" , fontsize = 34 )
100+ plt .xticks (fontsize = 30 )
101+ plt .yticks (fontsize = 30 )
103102 ax .xaxis .set_label_position ("top" ) # Move the label to the top
104103 plt .savefig (
105104 dir_path
106- / f"single_modal_recall5_{ cfg_dataset .retrieval_dim } _{ cfg_dataset .mask_ratio } { thres_tag } .png "
105+ / f"single_modal_recall5_{ cfg_dataset .retrieval_dim } _{ cfg_dataset .mask_ratio } { thres_tag } .pdf "
107106 )
108107
109108
0 commit comments