Skip to content

Commit a196842

Browse files
committed
multi-class roc
1 parent 40b2fb6 commit a196842

11 files changed

+205
-87
lines changed

bash_scripts/handwriting_script.sh

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# classification
2-
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=10
3-
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=25
2+
# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=10
3+
# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=25
4+
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=30
45
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=50
5-
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=100
6-
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=200
7-
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=700
6+
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=60
7+
CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=70
8+
# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=100
9+
# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=200
10+
# CUDA_VISIBLE_DEVICES=1 poetry run python mmda/bimodal_classification.py dataset=handwriting handwriting.sim_dim=500

config/main.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ imagenet:
106106
handwriting:
107107
sim_dim: 700 # dimension of the similarity score and the CCA transformation
108108
equal_weights: False
109-
img_encoder: "chronos"
109+
img_encoder: "tsfresh"
110110
text_encoder: "clip"
111-
train_test_ratios: [0.9]
111+
train_test_ratios: [0.85]
112112
shuffle: False
113113
paths:
114114
dataset_path: "/nas/pohan/datasets/Handwriting/"

mmda/bimodal_classification.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,9 @@ def main(cfg: DictConfig) -> None: # noqa: C901, PLR0915, PLR0912
5252
f.write(f"{shuffle_ratio},{cca_accs},{asif_accs}\n")
5353
else:
5454
for train_test_ratio in cfg_dataset.train_test_ratios:
55-
asif_accs = asif_classification(cfg, train_test_ratio)
5655
cca_accs = cca_classification(cfg, train_test_ratio)
57-
clip_accs = (
58-
clip_like_classification(cfg, train_test_ratio)
59-
if cfg.dataset != "handwriting"
60-
else 0
61-
)
56+
asif_accs = 0 if True else asif_classification(cfg, train_test_ratio)
57+
clip_accs = 0 if True else clip_like_classification(cfg, train_test_ratio)
6258
# write accuracy to file
6359
if not csv_save_path.exists():
6460
# create the file and write the header

mmda/exps/classification.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ def cca_classification(
2020
Returns:
2121
data_size2accuracy: {data_size: accuracy}
2222
"""
23-
print("CCA")
2423
cfg_dataset = cfg[cfg.dataset]
24+
print(f"CCA {cfg_dataset.sim_dim}")
2525
ds = load_classification_dataset(cfg)
2626
ds.load_data(train_test_ratio, clip_bool=False, shuffle_ratio=shuffle_ratio)
27-
cca = ReNormalizedCCA() if cfg.dataset == "handwriting" else NormalizedCCA()
27+
cca = ReNormalizedCCA() if True else NormalizedCCA()
2828
ds.train_img, ds.train_text, corr = cca.fit_transform_train_data(
2929
cfg_dataset, ds.train_img, ds.train_text
3030
)
31-
print("corr", corr)
3231
ds.test_img, ds.test_text = cca.transform_data(ds.test_img, ds.test_text)
3332

3433
ds.get_labels_emb()
@@ -50,6 +49,7 @@ def clip_like_classification(cfg: DictConfig, train_test_ratio: float) -> float:
5049
Returns:
5150
data_size2accuracy: {data_size: accuracy}
5251
"""
52+
print("CLIP-like")
5353
ds = load_classification_dataset(cfg)
5454
ds.load_data(train_test_ratio, clip_bool=True)
5555
ds.get_labels_emb()
@@ -68,6 +68,7 @@ def asif_classification(
6868
Returns:
6969
data_size2accuracy: {data_size: accuracy}
7070
"""
71+
print("ASIF")
7172
ds = load_classification_dataset(cfg)
7273
ds.load_data(train_test_ratio, clip_bool=False, shuffle_ratio=shuffle_ratio)
7374
ds.get_labels_emb()

mmda/get_embeddings.py

+27-38
Original file line numberDiff line numberDiff line change
@@ -560,39 +560,35 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
560560
print("CLIP embeddings saved")
561561

562562
elif dataset == "handwriting":
563-
# sentence_26 = {
564-
# 1: "apple.",
565-
# 2: "ball.",
566-
# 3: "cat.",
567-
# 4: "dog.",
568-
# 5: "elephant.",
569-
# 6: "fish.",
570-
# 7: "giraffe.",
571-
# 8: "hat.",
572-
# 9: "ice cream.",
573-
# 10: "jaguar.",
574-
# 11: "kangaroo.",
575-
# 12: "lion.",
576-
# 13: "monkey.",
577-
# 14: "nest.",
578-
# 15: "owl.",
579-
# 16: "penguin.",
580-
# 17: "queen.",
581-
# 18: "rabbit.",
582-
# 19: "snake.",
583-
# 20: "tiger.",
584-
# 21: "umbrella.",
585-
# 22: "vase.",
586-
# 23: "whale.",
587-
# 24: "x-ray.",
588-
# 25: "yak.",
589-
# 26: "zebra.",
590-
# }
591563
data, labels, num2alphabet, alphabets_hand = load_handwriting(cfg_dataset)
592-
# sentences = [sentence_26[int(label.split(".")[0])] for label in labels]
593-
# int_labels = [int(label.split(".")[0]) - 1 for label in labels]
564+
# save data
565+
with Path(cfg_dataset.paths.save_path, "Handwriting_data.pkl").open("wb") as f:
566+
pickle.dump(data, f)
567+
print("Handwriting data saved")
568+
return
594569

595-
embeddings = chronos_ts(data) if False else data.reshape(data.shape[0], -1)
570+
embeddings = clip_imgs(alphabets_hand, 256)
571+
print("text shape:", embeddings.shape)
572+
with Path(cfg_dataset.paths.save_path, "Handwriting_emb_clip.pkl").open(
573+
"wb"
574+
) as f:
575+
pickle.dump(embeddings, f)
576+
print("CLIP embeddings saved")
577+
578+
sentences = [f"Alphabet {num2alphabet[label]}." for label in labels]
579+
print(sentences[15:21])
580+
embeddings = gtr_text(sentences)
581+
assert np.allclose(
582+
embeddings[15], embeddings[20], atol=1e-3, rtol=1e-4
583+
), f"{embeddings[15].shape}!={embeddings[20].shape}"
584+
with Path(cfg_dataset.paths.save_path, "Handwriting_emb_gtr.pkl").open(
585+
"wb"
586+
) as f:
587+
pickle.dump(embeddings, f)
588+
print("GTR shape:", embeddings.shape)
589+
print("GTR embeddings saved")
590+
591+
embeddings = chronos_ts(data)
596592
# check if embeddings has unique rows
597593
assert embeddings.shape[0] == len(
598594
np.unique(embeddings, axis=0)
@@ -604,13 +600,6 @@ def main(cfg: DictConfig) -> None: # noqa: PLR0915, C901, PLR0912
604600
pickle.dump(embeddings, f)
605601
print("Chronos embeddings saved")
606602

607-
embeddings = clip_imgs(alphabets_hand, 256)
608-
print("text shape:", embeddings.shape)
609-
with Path(cfg_dataset.paths.save_path, "Handwriting_text_emb_clip.pkl").open(
610-
"wb"
611-
) as f:
612-
pickle.dump(embeddings, f)
613-
print("CLIP embeddings saved")
614603
# TODO: add more datasets
615604
else:
616605
msg = f"Dataset {dataset} not supported."

mmda/tsfresh_features.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Extract tsfresh features from the Handwriting dataset."""
2+
3+
import pickle
4+
from pathlib import Path
5+
6+
import kagglehub
7+
import numpy as np
8+
import pandas as pd
9+
from aeon.datasets import load_classification
10+
from PIL import Image
11+
from tsfresh import extract_features
12+
13+
PATH = "/nas/pohan/datasets/Handwriting/"
14+
PATH_SAVE = "/nas/pohan/datasets/Handwriting/embeddings/"
15+
16+
17+
def load_handwriting() -> tuple[np.ndarray, np.ndarray, dict[str, str], list[Path]]:
    """Load the Handwriting time series and pair every sample with a letter image.

    The train and test splits of the aeon "Handwriting" classification dataset
    are merged into a single array. For each sample, one handwritten-letter
    image of the same class is drawn at random from the Kaggle
    "A-Z Handwritten Alphabets" CSV, written to ``PATH`` as a PNG, and its
    path is recorded.

    Side effects: downloads the Kaggle dataset (via ``kagglehub``) and writes
    one PNG per sample under ``PATH``. The random draw uses NumPy's global
    RNG, so seed it beforehand if reproducible pairings are needed.

    Returns:
        data: merged time series (train followed by test); the original
            comment gives the split shapes as (150, 3, 152) and (850, 3, 152).
        labels: class labels as strings, e.g. "1.0".
        num2alphabet: maps a label string ("1.0" .. "26.0") to its letter
            ("A" .. "Z").
        alphabets_hand: one PNG path per sample, in the same order as ``data``.
    """

    def load_alphabets_img() -> tuple[pd.DataFrame, pd.Series]:
        """Load the Kaggle A-Z handwritten alphabets CSV.

        Returns:
            data: every column except the first (reshaped to 28x28 by the caller).
            labels: the first column, compared against 0..25 by the caller.
        """
        # Download latest version
        dataset_dir = kagglehub.dataset_download(
            "sachinpatel21/az-handwritten-alphabets-in-csv-format"
        )
        frame = pd.read_csv(Path(dataset_dir, "A_Z Handwritten Data.csv"))
        return frame.iloc[:, 1:], frame.iloc[:, 0]

    # train_x.shape: (150, 3, 152), test_x.shape: (850, 3, 152)
    train_x, train_y = load_classification(
        "Handwriting", split="train"
    )  # np.ndarray, list[str]
    test_x, test_y = load_classification("Handwriting", split="test")
    # merge train and test
    x = np.concatenate([train_x, test_x], axis=0)
    y = np.concatenate([train_y, test_y], axis=0)
    num2alphabet = {f"{i + 1}.0": chr(65 + i) for i in range(26)}

    alphabets_x, alphabets_y = load_alphabets_img()
    # Time-series labels are 1-based ("1.0" == "A") while the image CSV labels
    # are 0-based, hence the i + 1 key. Keep at most 100 candidates per letter.
    alphabets_img = {
        i + 1: alphabets_x[alphabets_y == i][:100] for i in range(26)
    }

    alphabets_hand = []
    for raw_label in y:
        label = int(raw_label.split(".")[0])
        candidates = alphabets_img[label]
        random_idx = np.random.choice(candidates.shape[0])
        pixels = candidates.iloc[random_idx].to_numpy()
        pixels = pixels.reshape(28, 28).astype(np.uint8)
        # save image to png; samples that draw the same (label, index) pair
        # share one file, which is simply rewritten with identical content
        path = Path(PATH, f"alphabet_{label}_{random_idx}.png")
        Image.fromarray(pixels, mode="L").save(path)
        alphabets_hand.append(path)
    return (
        x,
        y,
        num2alphabet,
        alphabets_hand,
    )
78+
79+
80+
def tsfresh_features() -> np.ndarray:
    """Extract tsfresh features from the Handwriting time series and pickle them.

    The (num_samples, num_channels, num_steps) array returned by
    ``load_handwriting`` is converted to the flat layout passed to tsfresh:
    one row per (sample, time step) with an ``id`` column, a ``time`` column
    and one ``channel_<k>`` column per channel. That table is cached as
    ``Handwriting_tsfresh.csv`` under ``PATH_SAVE``. Feature columns containing
    NaN are dropped before the feature table is pickled.

    Returns:
        features: the extracted feature matrix, one row per sample. The
            pickled object is the pandas DataFrame itself.
    """
    # Only the time series is needed here; the call is kept as-is because it
    # also produces the per-sample letter images.
    data, _labels, _num2alphabet, _alphabets_hand = load_handwriting()
    num_samples, num_channels, num_steps = data.shape
    expected_rows = num_samples * num_steps

    path = Path(PATH_SAVE, "Handwriting_tsfresh.csv")

    df = pd.read_csv(path) if path.exists() else None
    # A cache whose row count is not num_samples * num_steps cannot be the
    # table for this data (e.g. a file left behind by an earlier run that kept
    # a single time step per sample), so it is rebuilt instead of reused.
    if df is None or len(df) != expected_rows:
        # column_id: id, column_sort: time, values: one column per channel.
        # Row-major flattening of data[:, c, :] yields sample 0's full series,
        # then sample 1's, ... which matches the repeat/tile order below.
        columns = {
            "id": np.repeat(np.arange(num_samples), num_steps),
            "time": np.tile(np.arange(num_steps), num_samples),
        }
        for channel in range(num_channels):
            columns[f"channel_{channel + 1}"] = data[:, channel, :].reshape(-1)
        df = pd.DataFrame(columns)
        print(df.head())
        print(df.tail())

        df.to_csv(path, index=False)

    ts_features = extract_features(df, column_id="id", column_sort="time")
    # Drop every feature that is undefined (NaN) for at least one sample.
    ts_features = ts_features.dropna(axis=1)
    print(type(ts_features))
    print(ts_features.head())
    print("ts_features shape:", ts_features.shape)
    # NOTE(review): the doubled ".pkl.pkl" suffix looks accidental; it is kept
    # unchanged because the code that loads this file is not visible here.
    with Path(PATH_SAVE, "Handwriting_emb_tsfresh.pkl.pkl").open("wb") as f:
        pickle.dump(ts_features, f)
    print("TSFresh features saved")
    return ts_features.to_numpy()
116+
117+
118+
if __name__ == "__main__":
119+
tsfresh_features()

mmda/utils/cca_class.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def fit_transform_train_data(
6767
corr_coeff >= 0
6868
).all(), f"Correlation should be non-negative. {corr_coeff}"
6969
assert (
70-
corr_coeff <= 1
70+
corr_coeff <= 1.05 # noqa: PLR2004
7171
).all(), f"Correlation should be less than 1. {corr_coeff}"
7272
self.corr_coeff = corr_coeff
7373
self.traindata1, self.traindata2 = traindata1, traindata2
@@ -141,6 +141,8 @@ def fit_transform_train_data(
141141
corr_coeff: the correlation coefficient. shape: (dim,)
142142
"""
143143
# Check the shape of the training data
144+
traindata1 = traindata1.astype(np.float32)
145+
traindata2 = traindata2.astype(np.float32)
144146
# zero mean data
145147
traindata1, traindata1_mean = origin_centered(traindata1)
146148
traindata2, traindata2_mean = origin_centered(traindata2)
@@ -155,23 +157,15 @@ def fit_transform_train_data(
155157
), f"traindata2align not zero mean: {max(abs(traindata2.mean(axis=0)))}"
156158

157159
# CCA dimensionality reduction
158-
print((traindata1.T @ traindata1).shape)
159-
sigma_z1_inv = np.linalg.inv(traindata1.T @ traindata1)
160+
sigma_z1_inv = np.linalg.inv(
161+
traindata1.T @ traindata1 + np.eye(traindata1.shape[1]) * 1e-5
162+
)
160163
sigma_z1_inv_sqrt = sqrtm(sigma_z1_inv)
161-
assert np.allclose(
162-
sigma_z1_inv_sqrt @ sigma_z1_inv_sqrt, sigma_z1_inv
163-
), "sigma_z1_inv_sqrt is not the square root of sigma_z1_inv"
164164
sigma_z2_inv = np.linalg.inv(traindata2.T @ traindata2)
165165
sigma_z2_inv_sqrt = sqrtm(sigma_z2_inv)
166-
assert np.allclose(
167-
sigma_z2_inv_sqrt @ sigma_z2_inv_sqrt, sigma_z2_inv
168-
), "sigma_z2_inv_sqrt is not the square root of sigma_z2_inv"
169166

170167
svd_mat = sigma_z1_inv_sqrt @ traindata1.T @ traindata2 @ sigma_z2_inv_sqrt
171168
u, s, vh = np.linalg.svd(svd_mat)
172-
assert np.allclose(
173-
u @ np.diag(s) @ vh, svd_mat
174-
), "svd_mat is not the SVD of svd_mat"
175169

176170
self.A = u @ sigma_z1_inv_sqrt
177171
self.B = vh @ sigma_z2_inv_sqrt
@@ -180,13 +174,12 @@ def fit_transform_train_data(
180174
assert (
181175
corr_coeff >= 0
182176
).all(), f"Correlation should be non-negative. {corr_coeff}"
183-
assert (
184-
corr_coeff <= 1
185-
).all(), f"Correlation should be less than 1. {corr_coeff}"
186177
self.corr_coeff = corr_coeff
178+
if self.sim_dim is None:
179+
self.sim_dim = cfg_dataset.sim_dim
187180
self.traindata1, self.traindata2 = (
188-
(self.A @ traindata1.T).T,
189-
(self.B @ traindata2.T).T,
181+
(self.A @ traindata1.T).T[:, : self.sim_dim],
182+
(self.B @ traindata2.T).T[:, : self.sim_dim],
190183
)
191184
return self.traindata1, self.traindata2, corr_coeff
192185

@@ -203,12 +196,15 @@ def transform_data(
203196
data1: the first transformed data. shape: (num_samples, dim)
204197
data2: the second transformed data. shape: (num_samples, dim)
205198
"""
199+
data1 = data1.astype(np.float32)
200+
data2 = data2.astype(np.float32)
206201
assert self.traindata1_mean is not None, "Please fit the cca model first."
207202
assert self.traindata2_mean is not None, "Please fit the cca model first."
208203
# zero mean data and transform
209204
data1 = data1 - self.traindata1_mean
210205
data2 = data2 - self.traindata2_mean
211-
data1, data2 = (self.A @ data1.T).T, (self.B @ data2.T).T
206+
data1 = (self.A @ data1.T).T[:, : self.sim_dim]
207+
data2 = (self.B @ data2.T).T[:, : self.sim_dim]
212208
return data1, data2
213209

214210
def save_model(self, path: str | Path) -> None:

0 commit comments

Comments
 (0)