Skip to content

Commit 0530202

Browse files
authored
Btc (#16)
* BTC done, but text performance is poor (why?) * Fix typo * fix import typo * Remove unused code and add 2 more masking for BTC dataset * correct masking logic * Change the logic of trend * np.max for BTC * np.max not np.mean * Remove title in correlation plots * trying diff trend settings * manually plot * News split into two parts for query * Rename plotting file name
1 parent 2f33846 commit 0530202

13 files changed

+771
-67
lines changed

bash_scripts/btc_script.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
# Run the any2any conformal retrieval experiment on the BTC dataset once per
# retrieval dimension, in increasing order.
for dim in 10 25 50 75 100; do
    poetry run python mmda/any2any_conformal_retrieval.py dataset=BTC BTC.retrieval_dim="${dim}"
done

config/main.yaml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class_level_datasets: [sop]
1313
object_level_datasets: [pitts, sop]
1414
mislabeled_datasets: [imagenet, cosmos, tiil]
1515
retrieval_datasets: [flickr]
16-
any_retrieval_datasets: [KITTI, MSRVTT]
16+
any_retrieval_datasets: [KITTI, MSRVTT, BTC]
1717
shuffle_llava_datasets: [pitts, sop] # datasets whose plots contains llava
1818
mislabel_llava_datasets: [imagenet]
1919
classification_datasets: [imagenet, leafy_spurge]
@@ -27,6 +27,19 @@ dataset_size: {
2727
flickr: 155070
2828
}
2929

30+
BTC:
31+
retrieval_dim: 100
32+
equal_weights: False
33+
img_encoder: ""
34+
audio_encoder: ""
35+
horizon: 120
36+
mask_ratio: 2 # ratio of the missing data : size of test data
37+
paths:
38+
dataset_path: "/nas/timeseries/timeseries_synthesis/sameep_store/btc/split_fresh_large_120/"
39+
save_path: ${BTC.paths.dataset_path}/any2any/
40+
plots_path: ${repo_root}plots/BTC/
41+
42+
3043
MSRVTT:
3144
img_encoder: "clip"
3245
audio_encoder: "clap"

mmda/any2any_conformal_retrieval.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ def main(cfg: DictConfig) -> None:
4848
}
4949
df = pd.DataFrame(data)
5050
dir_path = Path(cfg_dataset.paths.plots_path)
51-
if cfg.dataset == "KITTI":
51+
if cfg.dataset == "MSRVTT":
5252
df_path = (
5353
dir_path
54-
/ f"any2any_retrieval_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}{thres_tag}.csv"
54+
/ f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}"
55+
/ f"any2any_retrieval_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}.csv"
5556
)
56-
elif cfg.dataset == "MSRVTT":
57+
else:
5758
df_path = (
5859
dir_path
59-
/ f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}"
60-
/ f"any2any_retrieval_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}.csv"
60+
/ f"any2any_retrieval_{cfg_dataset.retrieval_dim}_{cfg_dataset.mask_ratio}{thres_tag}.csv"
6161
)
6262
df_path.parent.mkdir(parents=True, exist_ok=True)
6363
df.to_csv(df_path, index=False)
@@ -91,6 +91,20 @@ def main(cfg: DictConfig) -> None:
9191
annot=True,
9292
annot_kws={"size": 30, "weight": "bold"},
9393
)
94+
elif cfg.dataset == "BTC":
95+
single_recalls = np.array(list(single1_recalls.values())).reshape(2, 2) * 100
96+
plt.figure(figsize=(8, 8))
97+
ax = sns.heatmap(
98+
single_recalls,
99+
fmt=".1f",
100+
cmap="YlGnBu",
101+
cbar=False,
102+
square=True,
103+
xticklabels=["Time", "Stats"],
104+
yticklabels=["Text", "Trend"],
105+
annot=True,
106+
annot_kws={"size": 34, "weight": "bold"},
107+
)
94108
else:
95109
msg = f"unknown dataset {cfg.dataset}"
96110
raise ValueError(msg)

mmda/plot.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

mmda/plot_nonconformity.ipynb

Lines changed: 8 additions & 8 deletions
Large diffs are not rendered by default.

mmda/plot_single_modal.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Plot functions."""
2+
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import seaborn as sns
8+
from omegaconf import DictConfig
9+
10+
import hydra
11+
12+
13+
# Font sizes shared by every single-modal recall heatmap.
_CELL_SIZE = 30  # numbers annotated inside the heatmap cells
_LABEL_SIZE = 30  # axis labels
_TICKS_SIZE = 28  # modality names on the ticks


def _recall_pdf_name(cfg_dataset: DictConfig) -> str:
    """Return the output PDF file name for the given dataset config."""
    return (
        f"single_modal_recall5_{cfg_dataset.retrieval_dim}"
        f"_{cfg_dataset.mask_ratio}.pdf"
    )


def _save_recall_heatmap(
    recalls: list,
    xticklabels: list,
    yticklabels: list,
    figsize: tuple,
    save_path: Path,
    *,
    adjust_bottom: bool = False,
) -> None:
    """Draw one annotated recall heatmap and save it to ``save_path``.

    Args:
        recalls: 2-D nested list of recall values; rows follow
            ``yticklabels`` and columns follow ``xticklabels``.
        xticklabels: names of the reference modalities (columns).
        yticklabels: names of the query modalities (rows).
        figsize: ``(width, height)`` of the figure in inches.
        save_path: destination file; missing parent directories are created.
        adjust_bottom: when True, apply ``subplots_adjust(bottom=-0.05)``
            after ``tight_layout`` (used by the square KITTI and BTC plots).
    """
    fig = plt.figure(figsize=figsize)
    ax = sns.heatmap(
        np.asarray(recalls),
        fmt=".1f",
        cmap="YlGnBu",
        cbar=False,
        square=True,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        annot=True,
        annot_kws={"size": _CELL_SIZE, "weight": "bold"},
    )
    ax.xaxis.tick_top()
    plt.xlabel("Reference modality", fontsize=_LABEL_SIZE)
    plt.ylabel("Query modality", fontsize=_LABEL_SIZE)
    plt.xticks(fontsize=_TICKS_SIZE)
    plt.yticks(fontsize=_TICKS_SIZE)
    plt.tight_layout()
    ax.xaxis.set_label_position("top")  # Move the label to the top
    if adjust_bottom:
        plt.subplots_adjust(bottom=-0.05)
    # savefig does not create directories, so make sure the target exists.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path)
    # Release the figure so the three plots do not accumulate in memory.
    plt.close(fig)


@hydra.main(version_base=None, config_path="../config", config_name="main")
def plot_single_modal_recall(cfg: DictConfig) -> None:
    """Plot single-modal recall heatmaps for KITTI, MSRVTT and BTC.

    The recall values are hard-coded in this function; the config is only
    used to resolve each dataset's plot directory and the output file name.

    Args:
        cfg: hydra config providing the ``KITTI``, ``MSRVTT`` and ``BTC``
            dataset sections (``paths.plots_path``, ``retrieval_dim``,
            ``mask_ratio``, and for MSRVTT ``img_encoder``/``audio_encoder``).
    """
    cfg_dataset = cfg["KITTI"]
    _save_recall_heatmap(
        [[31.9, 31.9, 31.7], [32.4, 32.4, 31.8], [33.7, 32.8, 32.2]],
        xticklabels=["Image", "Lidar", "Text"],
        yticklabels=["Image", "Lidar", "Text"],
        figsize=(9, 9),
        save_path=Path(cfg_dataset.paths.plots_path) / _recall_pdf_name(cfg_dataset),
        adjust_bottom=True,
    )

    cfg_dataset = cfg["MSRVTT"]
    # MSRVTT plots live in a per-encoder-pair sub-directory.
    msrvtt_dir = (
        Path(cfg_dataset.paths.plots_path)
        / f"{cfg_dataset.img_encoder}_{cfg_dataset.audio_encoder}"
    )
    _save_recall_heatmap(
        [[49.3, 2.6]],
        xticklabels=["Image", "Audio"],
        yticklabels=["Text"],
        figsize=(6, 4.5),
        save_path=msrvtt_dir / _recall_pdf_name(cfg_dataset),
    )

    cfg_dataset = cfg["BTC"]
    _save_recall_heatmap(
        [[4.1, 4.7], [3.4, 4.7]],
        xticklabels=["Time", "Stats"],
        yticklabels=["Prev News", "Text (2)"],
        figsize=(6, 6),
        save_path=Path(cfg_dataset.paths.plots_path) / _recall_pdf_name(cfg_dataset),
        adjust_bottom=True,
    )
108+
109+
110+
if __name__ == "__main__":
111+
plot_single_modal_recall()

mmda/utils/any2any_ds_class.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def __init__(self) -> None:
2727

2828
def preprocess_retrieval_data(self) -> None:
2929
"""Preprocess the data for retrieval."""
30+
# create the save path if not exists
31+
Path(self.cfg_dataset.paths.save_path).mkdir(parents=True, exist_ok=True)
3032

3133
def train_crossmodal_similarity(self) -> None:
3234
"""Train the cross-modal similarity, aka the CSA method."""

0 commit comments

Comments
 (0)