
Commit b4c532f: Iclr-rebuttal (#17)

* Create branch
* bimodal results for rebuttal
* t-SNE visualizations
* Fix ruff error
* Linear classifier
* optimize plots
* Ablation study of encoder architectures
* rewrite CCA...
* multi-class roc
* handwriting done

1 parent dec2290 commit b4c532f

21 files changed, +1075 −37 lines

.gitignore (+2 −1)

@@ -171,4 +171,5 @@ plots/*
 # lock files
 *.lock
 .checkpoints/
-.assets/
+.assets/
+*.keras

bash_scripts/handwriting_script.sh (new file, +7)

# classification
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=10
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=25
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=50
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=100
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=200
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=500
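Aside: the six commands differ only in `handwriting.sim_dim`. A minimal Python sketch of the same sweep, for reference; the wrapper itself is hypothetical and not part of this commit:

# sweep_handwriting.py -- hypothetical helper, not in this commit.
# Replays bash_scripts/handwriting_script.sh from Python.
import os
import subprocess

for sim_dim in (10, 25, 50, 100, 200, 500):
    subprocess.run(
        [
            "poetry", "run", "python", "mmda/bimodal_classification.py",
            "dataset=handwriting", f"handwriting.sim_dim={sim_dim}",
        ],
        env={**os.environ, "CUDA_VISIBLE_DEVICES": "1"},  # pin to GPU 1, as in the script
        check=True,  # abort the sweep if any run fails
    )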

bash_scripts/imagenet_script.sh (+9 −2)

@@ -1,3 +1,4 @@
+# mislabeled data
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_mislabeled_data.py dataset=imagenet imagenet.sim_dim=10
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_mislabeled_data.py dataset=imagenet imagenet.sim_dim=25
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_mislabeled_data.py dataset=imagenet imagenet.sim_dim=50
@@ -15,6 +16,12 @@
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_mislabeled_data.py dataset=imagenet train_test_ratio=0.1 imagenet.sim_dim=500
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_mislabeled_data.py dataset=imagenet train_test_ratio=0.1 imagenet.sim_dim=700

-# # classification
+# classification
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=imagenet
-CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=imagenet imagenet.shuffle=True
+# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=imagenet imagenet.shuffle=True
+
+# SVM
+# poetry run python mmda/linear_svm_clip.py train_test_ratio=0.1
+# poetry run python mmda/linear_svm_clip.py train_test_ratio=0.3
+# poetry run python mmda/linear_svm_clip.py train_test_ratio=0.5
+# poetry run python mmda/linear_svm_clip.py train_test_ratio=0.7

bash_scripts/leafy_script.sh (+6 −1)

@@ -1,4 +1,9 @@
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=leafy_spurge leafy_spurge.sim_dim=10
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=leafy_spurge leafy_spurge.sim_dim=50
 # CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=leafy_spurge leafy_spurge.sim_dim=100
-CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=leafy_spurge leafy_spurge.sim_dim=250
+# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=leafy_spurge leafy_spurge.sim_dim=250
+
+CUDA_VISIBLE_DEVICES=1 poetry run python mmda/linear_svm_clip.py dataset=leafy_spurge leafy_spurge.sim_dim=250 train_test_ratio=0.4
+CUDA_VISIBLE_DEVICES=1 poetry run python mmda/linear_svm_clip.py dataset=leafy_spurge leafy_spurge.sim_dim=250 train_test_ratio=0.6
+CUDA_VISIBLE_DEVICES=1 poetry run python mmda/linear_svm_clip.py dataset=leafy_spurge leafy_spurge.sim_dim=250 train_test_ratio=0.7
+CUDA_VISIBLE_DEVICES=1 poetry run python mmda/linear_svm_clip.py dataset=leafy_spurge leafy_spurge.sim_dim=250 train_test_ratio=0.888

config/main.yaml (+19 −6)

@@ -7,7 +7,7 @@ noisy_train_set: True
 repo_root: "/home/pl22767/Project/MMDA/"
 # repo_root: "/home/po-han/Desktop/Projects/MMDA/"

-dataset: "MSRVTT"
+dataset: "handwriting"
 dataset_level_datasets: [pitts, imagenet, cosmos, sop, tiil, musiccaps, flickr]
 class_level_datasets: [sop]
 object_level_datasets: [pitts, sop]
@@ -16,7 +16,7 @@ retrieval_datasets: [flickr]
 any_retrieval_datasets: [KITTI, MSRVTT, BTC]
 shuffle_llava_datasets: [pitts, sop] # datasets whose plots contains llava
 mislabel_llava_datasets: [imagenet]
-classification_datasets: [imagenet, leafy_spurge]
+classification_datasets: [imagenet, leafy_spurge, handwriting]
 dataset_size: {
   sop: 56222,
   musiccaps: 5397,
@@ -39,7 +39,6 @@ BTC:
     save_path: ${BTC.paths.dataset_path}/any2any/
     plots_path: ${repo_root}plots/BTC/

-
 MSRVTT:
   img_encoder: "clip"
   audio_encoder: "clap"
@@ -92,9 +91,10 @@ musiccaps:
 imagenet:
   sim_dim: 700 # dimension of the similarity score and the CCA transformation
   equal_weights: False
-  img_encoder: "dino"
-  text_encoder: "gtr"
-  train_test_ratios: [0.7] #, 0.3, 0.5, 0.7]
+  img_encoder: "clipopenai"
+  text_encoder: "clipdatacomp_xl_s13b_b90k"
+  model_name: "openai"
+  train_test_ratios: [0.3, 0.5, 0.7]
   shuffle_ratios: [0.1, 0.3, 0.5, 0.7, 1.0]
   shuffle: False
   paths:
@@ -103,6 +103,19 @@ imagenet:
     plots_path: ${repo_root}plots/ImageNet/
     label_embeddings: ${imagenet.paths.dataset_path}_${text_encoder}_label_embeddings.npy

+handwriting:
+  sim_dim: 50 # dimension of the similarity score and the CCA transformation
+  equal_weights: True
+  img_encoder: "tsfresh"
+  text_encoder: "gtr"
+  train_test_ratios: [0.85]
+  shuffle: False
+  paths:
+    dataset_path: "/nas/pohan/datasets/Handwriting/"
+    save_path: ${handwriting.paths.dataset_path}embeddings/
+    plots_path: ${repo_root}plots/handwriting/
+    label_embeddings: ${handwriting.paths.dataset_path}_${text_encoder}_label_embeddings.npy
+
 leafy_spurge:
   sim_dim: 700 # dimension of the similarity score and the CCA transformation
   equal_weights: False
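Aside: the new `handwriting` block follows the same shape as `imagenet`; the scripts select the active block via `cfg[cfg.dataset]`, and the `${...}` references are OmegaConf interpolations resolved on access. A minimal sketch with values copied from the block above (trimmed to just the keys needed to show resolution):

from omegaconf import OmegaConf

# Trimmed-down copy of config/main.yaml, enough to show interpolation.
cfg = OmegaConf.create(
    {
        "repo_root": "/home/pl22767/Project/MMDA/",
        "dataset": "handwriting",
        "handwriting": {
            "sim_dim": 50,
            "paths": {
                "dataset_path": "/nas/pohan/datasets/Handwriting/",
                "save_path": "${handwriting.paths.dataset_path}embeddings/",
                "plots_path": "${repo_root}plots/handwriting/",
            },
        },
    }
)
cfg_dataset = cfg[cfg.dataset]  # how the scripts pick the active dataset block
print(cfg_dataset.paths.save_path)   # /nas/pohan/datasets/Handwriting/embeddings/
print(cfg_dataset.paths.plots_path)  # /home/pl22767/Project/MMDA/plots/handwriting/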

mmda/bimodal_classification.py (+16 −6)

@@ -15,7 +15,7 @@


 @hydra.main(version_base=None, config_path="../config", config_name="main")
-def main(cfg: DictConfig) -> None:
+def main(cfg: DictConfig) -> None:  # noqa: C901, PLR0915, PLR0912
     """Main function to generate the classification results of the bimodal datasets.

     Args:
@@ -27,7 +27,12 @@ def main(cfg: DictConfig) -> None:
     ), f"{cfg.dataset} is not for classification."
     cfg_dataset = cfg[cfg.dataset]
     shuffle_tag = "shuffled" if cfg_dataset.shuffle else ""
-    ds_size = 50_000 if cfg.dataset == "imagenet" else 900
+    if cfg.dataset == "imagenet":
+        ds_size = 50_000
+    elif cfg.dataset == "leafy_spurge":
+        ds_size = 900
+    elif cfg.dataset == "handwriting":
+        ds_size = 1000
     csv_save_path = (
         Path(cfg_dataset.paths.plots_path)
         / f"classify_{cfg_dataset.text_encoder}_{cfg_dataset.img_encoder}/"
@@ -47,9 +52,9 @@ def main(cfg: DictConfig) -> None:
             f.write(f"{shuffle_ratio},{cca_accs},{asif_accs}\n")
     else:
         for train_test_ratio in cfg_dataset.train_test_ratios:
-            asif_accs = asif_classification(cfg, train_test_ratio)
             cca_accs = cca_classification(cfg, train_test_ratio)
-            clip_accs = clip_like_classification(cfg, train_test_ratio)
+            asif_accs = 0 if True else asif_classification(cfg, train_test_ratio)
+            clip_accs = 0 if True else clip_like_classification(cfg, train_test_ratio)
             # write accuracy to file
             if not csv_save_path.exists():
                 # create the file and write the header
@@ -77,7 +82,7 @@ def main(cfg: DictConfig) -> None:
             label="CSA (ours)",
             color="blue",
         )
-        if not cfg_dataset.shuffle:
+        if not cfg_dataset.shuffle and cfg.dataset != "handwriting":
             clip_accs = df["clip_accs"]
             ax.plot(
                 ratios,
@@ -99,7 +104,12 @@ def main(cfg: DictConfig) -> None:
     ax.set_ylabel("Classification accuracy", fontsize=20)
     ax.xaxis.set_tick_params(labelsize=15)
     ax.yaxis.set_tick_params(labelsize=15)
-    ax.set_ylim(0, 1.03) if cfg.dataset == "imagenet" else ax.set_ylim(0.4, 0.65)
+    if cfg.dataset == "imagenet":
+        ax.set_ylim(0, 1.03)
+    elif cfg.dataset == "leafy_spurge":
+        ax.set_ylim(0.4, 0.65)
+    else:
+        ax.set_ylim(0, 1.03)
     (
         ax.legend(loc="lower right", fontsize=18)
         if not cfg_dataset.shuffle
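Aside: the `0 if True else ...` assignments disable the ASIF and CLIP baselines while keeping their call sites intact, which is easy to miss on a quick read. A sketch of the same switch made explicit, assuming a hypothetical `run_baselines` config key that is not in this commit:

# Hypothetical: gate the baselines on a config flag instead of `0 if True else ...`.
run_baselines = cfg_dataset.get("run_baselines", False)  # key not in this commit's config
cca_accs = cca_classification(cfg, train_test_ratio)
if run_baselines:
    asif_accs = asif_classification(cfg, train_test_ratio)
    clip_accs = clip_like_classification(cfg, train_test_ratio)
else:
    asif_accs = clip_accs = 0  # placeholders keep the CSV columns aligned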

mmda/exps/classification.py (+8 −7)

@@ -3,14 +3,14 @@
 import numpy as np
 from omegaconf import DictConfig

-from mmda.utils.cca_class import NormalizedCCA
+from mmda.utils.cca_class import NormalizedCCA, ReNormalizedCCA
 from mmda.utils.classification_dataset_class import load_classification_dataset
 from mmda.utils.sim_utils import cosine_sim, weighted_corr_sim


 def cca_classification(
     cfg: DictConfig, train_test_ratio: float, shuffle_ratio: float = 0.0
-) -> tuple[dict[float:float], dict[float : dict[float:float]]]:
+) -> float:
     """Retrieve data using the proposed CCA method.

     Args:
@@ -21,9 +21,10 @@ def cca_classification(
         data_size2accuracy: {data_size: accuracy}
     """
     cfg_dataset = cfg[cfg.dataset]
+    print(f"CCA {cfg_dataset.sim_dim}")
     ds = load_classification_dataset(cfg)
     ds.load_data(train_test_ratio, clip_bool=False, shuffle_ratio=shuffle_ratio)
-    cca = NormalizedCCA()
+    cca = ReNormalizedCCA() if True else NormalizedCCA()
     ds.train_img, ds.train_text, corr = cca.fit_transform_train_data(
         cfg_dataset, ds.train_img, ds.train_text
     )
@@ -39,9 +40,7 @@ def sim_fn(x: np.array, y: np.array, corr: np.array = corr) -> np.array:
     return ds.classification(sim_fn=sim_fn)


-def clip_like_classification(
-    cfg: DictConfig, train_test_ratio: float
-) -> tuple[dict[float:float], dict[float:float]]:
+def clip_like_classification(cfg: DictConfig, train_test_ratio: float) -> float:
     """Retrieve data using the CLIP-like method.

     Args:
@@ -50,6 +49,7 @@ def clip_like_classification(
     Returns:
         data_size2accuracy: {data_size: accuracy}
     """
+    print("CLIP-like")
     ds = load_classification_dataset(cfg)
     ds.load_data(train_test_ratio, clip_bool=True)
     ds.get_labels_emb()
@@ -58,7 +58,7 @@

 def asif_classification(
     cfg: DictConfig, train_test_ratio: float, shuffle_ratio: float = 0.0
-) -> tuple[dict[float:float], dict[float:float]]:
+) -> float:
     """Retrieve data using the CLIP-like method.

     Args:
@@ -68,6 +68,7 @@ def asif_classification(
     Returns:
         data_size2accuracy: {data_size: accuracy}
     """
+    print("ASIF")
     ds = load_classification_dataset(cfg)
     ds.load_data(train_test_ratio, clip_bool=False, shuffle_ratio=shuffle_ratio)
     ds.get_labels_emb()
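Aside: after this change each experiment function returns a bare accuracy `float` rather than the nested dicts the old annotations promised, so callers can compare methods directly. A minimal usage sketch, assuming `cfg` is a composed Hydra/OmegaConf config as in `bimodal_classification.py`:

# Hypothetical driver snippet, not part of this commit.
from mmda.exps.classification import cca_classification

for ratio in (0.3, 0.5, 0.7):
    acc = cca_classification(cfg, train_test_ratio=ratio)  # now returns a float
    print(f"train/test ratio {ratio}: CCA accuracy {acc:.3f}")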

mmda/exps/mislabel_align.py (+2)

@@ -118,6 +118,8 @@ def __init__(self, *args, **kwargs):  # noqa: ANN204, ANN002, ANN003
             "valdata2align": valdata2align,
             "valdata1unalign": valdata1unalign,
             "valdata2unalign": valdata2unalign,
+            "train_idx": train_idx,
+            "train_wrong_labels_bool": train_wrong_labels_bool,
         }
     )
     return alldata

mmda/get_embeddings.py (+66 −2)

@@ -15,6 +15,7 @@
 from mmda.utils.dataset_utils import (
     load_cosmos,
     load_flickr,
+    load_handwriting,
     load_imagenet,
     load_kitti,
     load_leafy_spurge,
@@ -25,22 +26,24 @@
     load_tiil,
 )
 from mmda.utils.embed_data import (
+    chronos_ts,
     clap_audio,
     clap_text,
     clip_imgs,
     clip_text,
     cosplace_img,
     dinov2,
+    fair_clip_imgs,
+    fair_clip_text,
     gtr_text,
 )
-from mmda.utils.imagebind_utils import ImageBindInference
 from mmda.utils.video_audio_utils import (
     get_video_emb,
     prepare_audio_for_imagebind,
     process_audio,
 )

-BATCH_SIZE = 256
+BATCH_SIZE = 758


 @hydra.main(version_base=None, config_path="../config", config_name="main")
@@ -92,6 +95,8 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
             pickle.dump(clap_audio_features, f)

     elif dataset == "MSRVTT":
+        from mmda.utils.imagebind_utils import ImageBindInference
+
         _, captions, video_info_sen_order, video_dict = load_msrvtt(cfg_dataset)
         id_order, img_paths, audio_start_secs, audio_num_secs = get_video_emb(
             cfg_dataset, video_dict
@@ -372,6 +377,24 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
         text_descriptions = ["An image of " + label + "." for label in orig_labels]

         # get text embeddings
+        model = "openai"
+
+        img_emb = fair_clip_imgs(img_path, BATCH_SIZE, model_name=("ViT-L-14", model))
+        with Path(
+            cfg_dataset.paths.save_path, f"ImageNet_img_emb_clip{model}.pkl"
+        ).open("wb") as f:
+            pickle.dump(img_emb, f)
+        print("FairCLIP embeddings saved")
+
+        text_emb = fair_clip_text(
+            text_descriptions, BATCH_SIZE, model_name=("ViT-L-14", model)
+        )
+        with Path(
+            cfg_dataset.paths.save_path, f"ImageNet_text_emb_clip{model}.pkl"
+        ).open("wb") as f:
+            pickle.dump(text_emb, f)
+        print("FairCLIP embeddings saved")
+
         text_emb = clip_text(text_descriptions, BATCH_SIZE)
         with Path(cfg_dataset.paths.save_path, "ImageNet_text_emb_clip.pkl").open(
             "wb"
@@ -536,6 +559,47 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
             pickle.dump(img_emb, f)
         print("CLIP embeddings saved")

+    elif dataset == "handwriting":
+        data, labels, num2alphabet, alphabets_hand = load_handwriting(cfg_dataset)
+        # save data
+        with Path(cfg_dataset.paths.save_path, "Handwriting_data.pkl").open("wb") as f:
+            pickle.dump(data, f)
+        print("Handwriting data saved")
+        return
+
+        embeddings = clip_imgs(alphabets_hand, 256)
+        print("text shape:", embeddings.shape)
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_clip.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("CLIP embeddings saved")
+
+        sentences = [f"Alphabet {num2alphabet[label]}." for label in labels]
+        print(sentences[15:21])
+        embeddings = gtr_text(sentences)
+        assert np.allclose(
+            embeddings[15], embeddings[20], atol=1e-3, rtol=1e-4
+        ), f"{embeddings[15].shape}!={embeddings[20].shape}"
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_gtr.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("GTR shape:", embeddings.shape)
+        print("GTR embeddings saved")
+
+        embeddings = chronos_ts(data)
+        # check if embeddings has unique rows
+        assert embeddings.shape[0] == len(
+            np.unique(embeddings, axis=0)
+        ), f"Embeddings has repeated entries. {embeddings.shape[0]}!={len(np.unique(embeddings, axis=0))}"
+        print("Chronos shape:", embeddings.shape)
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_chronos.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("Chronos embeddings saved")
+
     # TODO: add more datasets
     else:
         msg = f"Dataset {dataset} not supported."
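Aside: note the bare `return` right after `Handwriting_data.pkl` is written; as committed, the CLIP/GTR/Chronos embedding code below it is unreachable until that line is removed. Once the pickles do exist, downstream code can load them directly. A minimal sketch, assuming the `save_path` configured in config/main.yaml above:

import pickle
from pathlib import Path

# handwriting.paths.save_path as resolved from the config above
save_path = "/nas/pohan/datasets/Handwriting/embeddings/"
with Path(save_path, "Handwriting_emb_gtr.pkl").open("rb") as f:
    gtr_emb = pickle.load(f)  # GTR sentence embeddings, one row per sample
print(gtr_emb.shape)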

mmda/handwriting_baseline.py (new file, +29)

"""This script is for the handwriting baseline."""

import numpy as np
from aeon.classification.deep_learning import InceptionTimeClassifier
from omegaconf import DictConfig
from sklearn.metrics import accuracy_score

import hydra
from mmda.utils.dataset_utils import load_handwriting


@hydra.main(version_base=None, config_path="../config", config_name="main")
def main(cfg: DictConfig) -> None:
    """Train the handwriting baseline."""
    x, labels, _ = load_handwriting(cfg_dataset=cfg.handwriting)
    inception = InceptionTimeClassifier()
    for train_test_ratio in cfg.handwriting.train_test_ratios:
        np.random.seed(42)
        train_size = int(train_test_ratio * x.shape[0])
        print(x.shape, labels.shape)
        inception.fit(x[:train_size], labels[:train_size])
        y_pred = inception.predict(x[train_size:])
        accuracy = accuracy_score(labels[train_size:], y_pred)
        print(f"train_test_ratio: {train_test_ratio}, accuracy: {accuracy}")


if __name__ == "__main__":
    main()
# CUDA_VISIBLE_DEVICES="" poetry run python mmda/handwriting_baseline.py
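Aside: in the baseline above, `np.random.seed(42)` is set but the split itself never samples, so `x[:train_size]` is a deterministic prefix split; the seed only matters if the classifier draws from NumPy's global RNG internally. If a shuffled split is the intent (an assumption on my part, not stated in the commit), a minimal sketch:

# Hypothetical shuffled split; the committed code uses a plain prefix split.
rng = np.random.default_rng(42)
perm = rng.permutation(x.shape[0])
x_shuf, labels_shuf = x[perm], labels[perm]
inception.fit(x_shuf[:train_size], labels_shuf[:train_size])
y_pred = inception.predict(x_shuf[train_size:])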
