Commit 469de70

feat: optional test set logprobs during inference
1 parent 76b0516 commit 469de70

File tree: 3 files changed (+210 additions, -12 deletions)


distance_synth_data_genjaxmix.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# # %%
# %load_ext autoreload
# %autoreload 2

# %%
import jax
import jax.numpy as jnp
import numpy as np
import polars as pl

# %%
dataset_paths = [
    "data/CTGAN/covertype",
    "data/CTGAN/kddcup",
    "data/CTGAN/sydt",
    "data/lpm/CES",
    "data/lpm/PUMS",
    "data/lpm/PUMD",
]

# times = {
#     "covertype": 304.31628918647766,
#     "kddcup": 5809.921473503113,
#     "sydt": 4147.667294740677,
#     "CES": 49.699519634246826,
#     "PUMS": 557.475483417511,
#     "PUMD": 127.56089353561401,
# }
times_single_rejuvenation_100 = {
    "covertype": 5.269169092178345,
    "kddcup": 61.410168170928955,
    "sydt": 46.48610043525696,
    "CES": 2.556819438934326,
    "PUMS": 7.585843563079834,
    "PUMD": 5.1315598487854,
}
times_single_rejuvenation_300 = {
    "covertype": 35.914592266082764,
    "kddcup": 511.4724328517914,
    "sydt": 376.814204454422,
    "CES": 13.747230291366577,
    "PUMS": 57.392295598983765,
    "PUMD": 38.446903228759766,
}
times_single_rejuvenation_500 = {
    "covertype": 95.82071185112,
    "kddcup": 1392.8732736110687,
    "sydt": 1017.3137822151184,
    "CES": 33.89802026748657,
    "PUMS": 159.05469465255737,
    "PUMD": 109.2985634803772,
}



# %%
from minijaxmix.io import load_huggingface, discretize_dataframe, to_dummies
from minijaxmix.infer import sample_categorical
from minijaxmix.distances import js
from functools import partial

# partial_js = partial(js, batch_size=10)
# jit_js = jax.jit(partial_js)
jit_js = jax.jit(js)

dfs = []
for dataset_path in dataset_paths:
    print(dataset_path)
    train_df, test_df = load_huggingface(dataset_path)
    df = pl.concat((train_df, test_df))

    schema, discretized_df, categorical_idxs = discretize_dataframe(df)
    dummies_df = to_dummies(discretized_df)
    data = dummies_df.to_numpy().astype(np.bool_)

    train_data = data[:len(train_df)]
    test_data = data[len(train_df):][:10000]

    files = jnp.load(f"{dataset_path.split('/')[-1]}_single_rejuvenation.npz")

    p_ys = files["p_ys"]
    ws = files["ws"]

    n_sample = 10000

    cs = jax.random.categorical(jax.random.PRNGKey(0), jnp.log(p_ys), shape=(n_sample,))
    sample_ws = ws.take(cs, axis=0)
    n_categories = categorical_idxs.max() + 1

    samples = jax.vmap(sample_categorical, in_axes=(0, 0, None, None))(jax.random.split(jax.random.PRNGKey(0), n_sample), jnp.log(sample_ws), categorical_idxs, n_categories)

    distances = jit_js(jnp.array(test_data), jnp.array(samples))

    dfs.append(pl.DataFrame({
        "distance": np.array(distances),
        "dataset": dataset_path,
        "model": "GenJaxMix",
        "time": times_single_rejuvenation_300[dataset_path.split("/")[-1]]
    }))

# %%
result_df = pl.concat(dfs)

# %%
# times_no_rejuvenation = {
#     "covertype": 35.12278389930725,
#     "kddcup": 495.74342131614685,
#     "sydt": 365.2887644767761,
#     "CES": 13.5840482711792,
#     "PUMS": 55.90764021873474,
#     "PUMD": 38.18173289299011,
# }

# %%
dfs = []
for dataset_path in dataset_paths:
    print(dataset_path)
    train_df, test_df = load_huggingface(dataset_path)
    df = pl.concat((train_df, test_df))

    schema, discretized_df, categorical_idxs = discretize_dataframe(df)
    dummies_df = to_dummies(discretized_df)
    data = dummies_df.to_numpy().astype(np.bool_)

    train_data = data[:len(train_df)]
    test_data = data[len(train_df):][:10000]

    files = jnp.load(f"{dataset_path.split('/')[-1]}_single_rejuvenation_100.npz")

    p_ys = files["p_ys"]
    ws = files["ws"]

    n_sample = 10000

    cs = jax.random.categorical(jax.random.PRNGKey(0), jnp.log(p_ys), shape=(n_sample,))
    sample_ws = ws.take(cs, axis=0)
    n_categories = categorical_idxs.max() + 1

    samples = jax.vmap(sample_categorical, in_axes=(0, 0, None, None))(jax.random.split(jax.random.PRNGKey(0), n_sample), jnp.log(sample_ws), categorical_idxs, n_categories)

    distances = jit_js(jnp.array(test_data), jnp.array(samples))

    dfs.append(pl.DataFrame({
        "distance": np.array(distances),
        "dataset": dataset_path,
        "model": "GenJaxMix",
        "time": times_single_rejuvenation_100[dataset_path.split("/")[-1]]
    }))

# %%
no_rejuvenation_result_df = pl.concat(dfs)

# %%
prev_result_df = pl.read_parquet("distance_synth_data.parquet")

# %%
new_result_df = pl.concat((prev_result_df, result_df, no_rejuvenation_result_df), how="diagonal")

# %%
new_result_df.write_parquet("new_distance_synth_data.parquet")
# %%
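
(Editor's aside, not part of the committed file: the held-out rows prepared above could also be scored under the saved mixture with the same logprob/logsumexp pattern this commit adds to minijaxmix/infer.py below. A minimal sketch, reusing test_data, p_ys, and ws from the last loop iteration:)

from minijaxmix.query import logprob

# Per-row, per-cluster log densities: each test row scored against each cluster's weights in ws.
cluster_logprobs = jax.vmap(jax.vmap(logprob, in_axes=(None, 0)), in_axes=(0, None))(jnp.array(test_data), ws)
# Mixture log density per row (clusters weighted by p_ys), summed over the held-out rows.
heldout_logprob = jnp.sum(jax.nn.logsumexp(cluster_logprobs, b=p_ys, axis=1))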

minijaxmix/infer.py

Lines changed: 12 additions & 12 deletions
@@ -1,14 +1,8 @@
 import jax
-jax.config.update("jax_compilation_cache_dir", "jax_cache")
-jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
-jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
-
 import jax.numpy as jnp
 from functools import partial
-import numpy as np
 from jaxtyping import Array, Float, Bool, Integer
-import time
-from minijaxmix.query import sample_dirichlet, sample_categorical
+from minijaxmix.query import sample_dirichlet, logprob

 ALPHA = 1e-5

@@ -22,8 +16,8 @@ def conditional_entropy(data, c):
     res = - jnp.sum(jnp.where(c, c * p_x_y, 0), axis=0) / jnp.sum(c, axis=0)
     return res

-@partial(jax.jit, static_argnames=("n_clusters", "n_gibbs", "n_categories", "n_branch", "rejuvenation", "minibatch_size"))
-def infer(key, data, categorical_idxs, n_clusters, n_gibbs, n_categories, n_branch=2, rejuvenation=True, minibatch_size=1000):
+@partial(jax.jit, static_argnames=("n_clusters", "n_gibbs", "n_categories", "n_branch", "rejuvenation", "minibatch_size", "test"))
+def infer(key, data, categorical_idxs, n_clusters, n_gibbs, n_categories, n_branch=2, rejuvenation=True, minibatch_size=1000, test=False, test_data=None):
     N, k = data.shape
     p_ys = jnp.zeros(n_clusters)
     p_ys = p_ys.at[0].set(1.)
@@ -77,7 +71,13 @@ def infer_step(carry, key_i):

         total_H_split = jnp.nansum(conditional_H * p_y) - jnp.nansum(p_y * jnp.log(p_y))

-        return (p_y, w, conditional_H), (total_H_split, total_H_hard_clustering)
+        if test:
+            logprobs = jax.vmap(jax.vmap(logprob, in_axes=(None, 0)), in_axes=(0, None))(test_data, w)
+            logprobs = jax.nn.logsumexp(logprobs, b=p_y, axis=1)
+            logprobs = jnp.sum(logprobs)
+            return (p_y, w, conditional_H), (total_H_split, total_H_hard_clustering, logprobs)
+        else:
+            return (p_y, w, conditional_H), (total_H_split, total_H_hard_clustering, None)

     def rejuvenation(carry, key):
         p_y, w, conditional_H = carry
@@ -100,13 +100,13 @@ def rejuvenation_step(p_y_w, key):
     keys = jax.random.split(subkey, n_clusters - 1)
     # we could use lax.scan here, but at the cost of padding each step to the max number of clusters

-    (p_ys, ws, conditional_H), (total_H_split, total_H_hard_clustering) = jax.lax.scan(
+    (p_ys, ws, conditional_H), (total_H_split, total_H_hard_clustering, logprobs) = jax.lax.scan(
         infer_step, (p_ys, ws, conditional_H), (keys, jnp.arange(n_clusters - 1)))

     if rejuvenation:
         (p_ys, ws, conditional_H), total_H_rejuvenation = rejuvenation((p_ys, ws, conditional_H), key)

-    return p_ys, ws, conditional_H, total_H_split, total_H_rejuvenation, total_H_hard_clustering
+    return p_ys, ws, conditional_H, total_H_split, total_H_rejuvenation, total_H_hard_clustering, logprobs

 def make_minibatches(key, data, c, num_clusters, minibatch_size):
     keys = jax.random.split(key, num_clusters)
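
(Editor's aside, not part of the commit: a minimal sketch of how the new test/test_data arguments might be called, reusing the preprocessing from distance_synth_data_genjaxmix.py above. The n_clusters and n_gibbs values are placeholders, not taken from this repository.)

import jax
import numpy as np
import polars as pl

from minijaxmix.io import load_huggingface, discretize_dataframe, to_dummies
from minijaxmix.infer import infer

# Build the boolean dummy encoding exactly as in the evaluation script above.
train_df, test_df = load_huggingface("data/lpm/CES")
df = pl.concat((train_df, test_df))
schema, discretized_df, categorical_idxs = discretize_dataframe(df)
data = to_dummies(discretized_df).to_numpy().astype(np.bool_)
train_data = data[:len(train_df)]
test_data = data[len(train_df):]

# With test=True, the final element of the returned tuple is the summed
# test-set log probability recorded at each split step of the scan.
(p_ys, ws, conditional_H,
 total_H_split, total_H_rejuvenation, total_H_hard_clustering,
 logprobs) = infer(
    jax.random.PRNGKey(0),
    train_data,
    categorical_idxs,
    n_clusters=100,                                # placeholder hyperparameters
    n_gibbs=100,
    n_categories=int(categorical_idxs.max()) + 1,
    test=True,                                     # new in this commit
    test_data=test_data,
)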

plot_time.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# %%
import plotnine as pn
import polars as pl

# %%
df = pl.read_parquet("new_distance_synth_data.parquet")
df

# %%
median_df = df.group_by(["dataset", "model", "time"]).agg(pl.median("distance").alias("median_distance"))
median_df


# %%
dataset_map = {
    "data/CTGAN/covertype": "Covertype",
    "data/CTGAN/kddcup": "KDDCup",
    "data/CTGAN/sydt": "SYDT",
    "data/lpm/CES": "CES",
    "data/lpm/PUMS": "PUMS",
    "data/lpm/PUMD": "PUMD",
}
median_df = median_df.with_columns(pl.col("dataset").replace(dataset_map))

# %%
(
    # pn.ggplot(median_df.filter(pl.col("dataset") == "data/CTGAN/covertype"))
    pn.ggplot(median_df)
    + pn.geom_line(pn.aes(x="time", y="median_distance", color="model", fill="model"))
    + pn.geom_point(pn.aes(x="time", y="median_distance", color="model", fill="model"))
    + pn.labs(y="2D Jensen-Shannon distance between\nreal and synthetic data (median)", x="Training time (seconds)")
    + pn.scale_x_log10()
    + pn.scale_y_log10()
    + pn.facet_wrap("~dataset", scales="free")
)

# %%
