Skip to content

Commit 2f33846

Browse files
committed
Plotting done
1 parent ab5fa8d commit 2f33846

File tree

3 files changed

+152
-68
lines changed

3 files changed

+152
-68
lines changed

config/main.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ dataset_size: {
2828
}
2929

3030
MSRVTT:
31-
img_encoder: "imagebind"
32-
audio_encoder: "imagebind"
31+
img_encoder: "clip"
32+
audio_encoder: "clap"
3333
retrieval_dim: "" # we use all the dimensions for retrieval
3434
mask_ratio: 4 # ratio of the missing data : size of test data
3535
paths:

mmda/plot.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Plot functions."""
2+
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import seaborn as sns
8+
from omegaconf import DictConfig
9+
10+
import hydra
11+
12+
13+
@hydra.main(version_base=None, config_path="../config", config_name="main")
14+
def plot_single_modal_recall(cfg: DictConfig) -> None:
15+
"""Plot single-modal recall."""
16+
cfg_dataset = cfg[cfg.dataset]
17+
dir_path = (
18+
Path(cfg_dataset.paths.plots_path)
19+
/ f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}"
20+
)
21+
single1_recalls = [49.3, 2.6]
22+
single_recalls = np.array(single1_recalls).reshape(1, 2)
23+
plt.figure(figsize=(6, 4.3))
24+
ax = sns.heatmap(
25+
single_recalls,
26+
fmt=".1f",
27+
cmap="YlGnBu",
28+
cbar=False,
29+
square=True,
30+
xticklabels=["Image", "Audio"],
31+
yticklabels=["Text"],
32+
annot=True,
33+
annot_kws={"size": 26, "weight": "bold"},
34+
)
35+
ax.xaxis.tick_top()
36+
plt.xlabel("Reference modality", fontsize=30)
37+
plt.ylabel("Query modality", fontsize=30)
38+
plt.xticks(fontsize=26)
39+
plt.yticks(fontsize=26)
40+
plt.tight_layout()
41+
ax.xaxis.set_label_position("top") # Move the label to the top
42+
plt.savefig(
43+
dir_path
44+
/ f"single_modal_recall5_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}.pdf"
45+
)
46+
print(f"Single-modal recall plot saved to {dir_path}")
47+
48+
49+
if __name__ == "__main__":
50+
plot_single_modal_recall()

mmda/plot_nonconformity.ipynb

Lines changed: 100 additions & 66 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)