# # %%
# %load_ext autoreload
# %autoreload 2

# %%
from functools import lru_cache, partial

import jax
import jax.numpy as jnp
import numpy as np
import polars as pl

from minijaxmix.distances import js
from minijaxmix.infer import sample_categorical
from minijaxmix.io import discretize_dataframe, load_huggingface, to_dummies

# %%
# Datasets to evaluate. The last path component doubles as the key into the
# timing tables below and as the prefix of the checkpoint (.npz) file name.
dataset_paths = [
    "data/CTGAN/covertype",
    "data/CTGAN/kddcup",
    "data/CTGAN/sydt",
    "data/lpm/CES",
    "data/lpm/PUMS",
    "data/lpm/PUMD",
]

# Recorded run times per dataset, copied into the "time" column of the
# results. NOTE(review): units are presumably wall-clock seconds and the
# 100/300/500 suffix presumably identifies the run configuration -- confirm.
#
# Earlier measurements, kept for reference:
# times = {
#     "covertype": 304.31628918647766,
#     "kddcup": 5809.921473503113,
#     "sydt": 4147.667294740677,
#     "CES": 49.699519634246826,
#     "PUMS": 557.475483417511,
#     "PUMD": 127.56089353561401,
# }
times_single_rejuvenation_100 = {
    "covertype": 5.269169092178345,
    "kddcup": 61.410168170928955,
    "sydt": 46.48610043525696,
    "CES": 2.556819438934326,
    "PUMS": 7.585843563079834,
    "PUMD": 5.1315598487854,
}
times_single_rejuvenation_300 = {
    "covertype": 35.914592266082764,
    "kddcup": 511.4724328517914,
    "sydt": 376.814204454422,
    "CES": 13.747230291366577,
    "PUMS": 57.392295598983765,
    "PUMD": 38.446903228759766,
}
# Not referenced by the cells below; kept as recorded data.
times_single_rejuvenation_500 = {
    "covertype": 95.82071185112,
    "kddcup": 1392.8732736110687,
    "sydt": 1017.3137822151184,
    "CES": 33.89802026748657,
    "PUMS": 159.05469465255737,
    "PUMD": 109.2985634803772,
}

# Maximum number of held-out rows scored per dataset, and number of synthetic
# rows drawn from each checkpoint.
N_TEST = 10000
N_SAMPLE = 10000

# %%
# partial_js = partial(js, batch_size=10)
# jit_js = jax.jit(partial_js)
jit_js = jax.jit(js)


@lru_cache(maxsize=None)
def _load_test_split(dataset_path: str):
    """Return ``(test_data, categorical_idxs)`` for one dataset.

    Train and test frames are concatenated before discretization and dummy
    encoding so both splits share one encoding; the rows after the training
    rows (at most ``N_TEST`` of them) are returned as a boolean array.

    The result is cached because every checkpoint evaluated for a dataset
    needs the same split, and loading plus discretizing is repeated work.
    """
    train_df, test_df = load_huggingface(dataset_path)
    combined_df = pl.concat((train_df, test_df))

    _schema, discretized_df, categorical_idxs = discretize_dataframe(combined_df)
    data = to_dummies(discretized_df).to_numpy().astype(np.bool_)

    # Copy so the cache holds only the small slice, not a view that keeps the
    # full encoded dataset alive.
    test_data = data[len(train_df):][:N_TEST].copy()
    return test_data, categorical_idxs


def _sample_synthetic(p_ys, ws, categorical_idxs, n_sample: int, key):
    """Draw ``n_sample`` synthetic rows from a saved checkpoint.

    A row of ``ws`` is selected for each sample by drawing an index from the
    categorical distribution with logits ``log(p_ys)``; ``sample_categorical``
    is then applied once per sample with that row's log-weights.
    NOTE(review): ``p_ys`` / ``ws`` look like mixture weights and
    per-component category probabilities -- confirm against the code that
    writes the checkpoints.
    """
    # Split the key so the index draw and the per-row draws use independent
    # randomness; JAX keys must not be consumed by more than one call.
    index_key, row_key = jax.random.split(key)

    cs = jax.random.categorical(index_key, jnp.log(p_ys), shape=(n_sample,))
    sample_ws = ws.take(cs, axis=0)
    n_categories = categorical_idxs.max() + 1

    draw_rows = jax.vmap(sample_categorical, in_axes=(0, 0, None, None))
    return draw_rows(
        jax.random.split(row_key, n_sample),
        jnp.log(sample_ws),
        categorical_idxs,
        n_categories,
    )


def evaluate_checkpoints(checkpoint_suffix: str, times: dict) -> pl.DataFrame:
    """Score one family of checkpoints against every dataset's test split.

    For each entry of ``dataset_paths`` this loads
    ``"<dataset name><checkpoint_suffix>.npz"`` from the working directory,
    samples ``N_SAMPLE`` synthetic rows from it, and computes ``js`` between
    the test rows and the synthetic rows.

    Args:
        checkpoint_suffix: Text between the dataset name and ``.npz`` in the
            checkpoint file name.
        times: Mapping from dataset name to the run time to record.

    Returns:
        One frame with columns ``distance``, ``dataset``, ``model`` and
        ``time``, concatenated over all datasets.

    Raises:
        KeyError: If ``times`` has no entry for a dataset name, or a
            checkpoint lacks the ``p_ys`` / ``ws`` arrays.
    """
    frames = []
    for dataset_path in dataset_paths:
        print(dataset_path)
        dataset_name = dataset_path.split("/")[-1]
        test_data, categorical_idxs = _load_test_split(dataset_path)

        # The .npz archive keeps a file handle open; close it once both
        # arrays have been read.
        with jnp.load(f"{dataset_name}{checkpoint_suffix}.npz") as files:
            p_ys = files["p_ys"]
            ws = files["ws"]

        samples = _sample_synthetic(
            p_ys, ws, categorical_idxs, N_SAMPLE, jax.random.PRNGKey(0)
        )
        distances = jit_js(jnp.array(test_data), jnp.array(samples))

        frames.append(pl.DataFrame({
            "distance": np.array(distances),
            "dataset": dataset_path,
            "model": "GenJaxMix",
            "time": times[dataset_name],
        }))
    return pl.concat(frames)


# %%
# NOTE(review): the unsuffixed checkpoint is paired with the "300" timings;
# confirm that this file really comes from that run.
result_df = evaluate_checkpoints(
    "_single_rejuvenation", times_single_rejuvenation_300
)

# %%
# times_no_rejuvenation = {
#     "covertype": 35.12278389930725,
#     "kddcup": 495.74342131614685,
#     "sydt": 365.2887644767761,
#     "CES": 13.5840482711792,
#     "PUMS": 55.90764021873474,
#     "PUMD": 38.18173289299011,
# }

# %%
# NOTE(review): despite its name this holds the "_single_rejuvenation_100"
# run, and both runs are labelled "GenJaxMix", so they can only be told apart
# by the "time" column after concatenation. Names and labels are left as-is
# to keep the written parquet unchanged.
no_rejuvenation_result_df = evaluate_checkpoints(
    "_single_rejuvenation_100", times_single_rejuvenation_100
)

# %%
prev_result_df = pl.read_parquet("distance_synth_data.parquet")

# %%
# "diagonal" concatenation tolerates the previous results having columns the
# new frames lack (and vice versa).
new_result_df = pl.concat(
    (prev_result_df, result_df, no_rejuvenation_result_df), how="diagonal"
)

# %%
new_result_df.write_parquet("new_distance_synth_data.parquet")
# %%