|
| 1 | +"""Plot functions.""" |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import numpy as np |
| 7 | +import seaborn as sns |
| 8 | +from omegaconf import DictConfig |
| 9 | + |
| 10 | +import hydra |
| 11 | + |
| 12 | + |
| 13 | +@hydra.main(version_base=None, config_path="../config", config_name="main") |
| 14 | +def plot_single_modal_recall(cfg: DictConfig) -> None: |
| 15 | + """Plot single-modal recall.""" |
| 16 | + cfg_dataset = cfg[cfg.dataset] |
| 17 | + dir_path = ( |
| 18 | + Path(cfg_dataset.paths.plots_path) |
| 19 | + / f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}" |
| 20 | + ) |
| 21 | + single1_recalls = [49.3, 2.6] |
| 22 | + single_recalls = np.array(single1_recalls).reshape(1, 2) |
| 23 | + plt.figure(figsize=(6, 4.3)) |
| 24 | + ax = sns.heatmap( |
| 25 | + single_recalls, |
| 26 | + fmt=".1f", |
| 27 | + cmap="YlGnBu", |
| 28 | + cbar=False, |
| 29 | + square=True, |
| 30 | + xticklabels=["Image", "Audio"], |
| 31 | + yticklabels=["Text"], |
| 32 | + annot=True, |
| 33 | + annot_kws={"size": 26, "weight": "bold"}, |
| 34 | + ) |
| 35 | + ax.xaxis.tick_top() |
| 36 | + plt.xlabel("Reference modality", fontsize=30) |
| 37 | + plt.ylabel("Query modality", fontsize=30) |
| 38 | + plt.xticks(fontsize=26) |
| 39 | + plt.yticks(fontsize=26) |
| 40 | + plt.tight_layout() |
| 41 | + ax.xaxis.set_label_position("top") # Move the label to the top |
| 42 | + plt.savefig( |
| 43 | + dir_path |
| 44 | + / f"single_modal_recall5_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}.pdf" |
| 45 | + ) |
| 46 | + print(f"Single-modal recall plot saved to {dir_path}") |
| 47 | + |
| 48 | + |
| 49 | +if __name__ == "__main__": |
| 50 | + plot_single_modal_recall() |
0 commit comments