Skip to content

Commit 34a929d

Browse files
committed
feat: small improvements to the masked normalizing flow
1 parent 16fde95 commit 34a929d

File tree

3 files changed

+87
-42
lines changed

3 files changed

+87
-42
lines changed

python/nutpie/normalizing_flow.py

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,28 @@ def step(carry, bijection):
579579
)
580580
return y, log_det
581581

582+
def inverse_gradient_and_val(
    self,
    y: Array,
    y_grad: Array,
    y_logp: Array,
    condition: Array | None = None,
) -> tuple[Array, Array, Array]:
    """Pull a point, its gradient and its log-density back through the scan.

    Every bijection stored in ``self.bijection`` is handed, one after the
    other, to ``nutpie.transform_adapter.inverse_gradient_and_val``; the
    triple it returns becomes the input for the next bijection. The scan is
    run with ``reverse=True``.

    Args:
        y: Position to transform.
        y_grad: Gradient associated with ``y``.
        y_logp: Log-density value associated with ``y``.
        condition: Accepted for signature compatibility; it is not used by
            this implementation.

    Returns:
        The ``(y, y_grad, y_logp)`` triple produced by the last scan step.
    """
    # Imported at call time rather than at module level: transform_adapter
    # imports from this module, so a top-level import would presumably be
    # circular.
    from nutpie.transform_adapter import inverse_gradient_and_val as pull_back

    def apply_layer(state, layer):
        # `state` is the running (position, gradient, logp) triple.
        position, gradient, logp = state
        return pull_back(layer, position, gradient, logp), None

    initial_state = (y, y_grad, y_logp)
    final_state, _ = _filter_scan(
        apply_layer,
        initial_state,
        self.bijection,
        reverse=True,
        filter_spec=self.filter_spec,
    )
    out_y, out_grad, out_logp = final_state
    return out_y, out_grad, out_logp
603+
582604
@property
583605
def shape(self):
584606
return self.bijection.shape
@@ -961,6 +983,16 @@ def make_mlp(out_size):
961983
)
962984

963985

986+
class Add(eqx.Module):
    """Equinox module that adds a stored bias to its input.

    ``make_layer`` places an instance after ``embed_back`` inside an
    ``eqx.nn.Sequential`` so the conditioner output gets a learnable offset.
    """

    # Offset added to every input. As the only declared field it is also the
    # single argument of the generated initialiser: ``Add(bias)``.
    bias: Array

    def __call__(self, x: Array, *, key=None) -> Array:
        """Return ``x + bias``.

        ``key`` is accepted so the module can be called the way
        ``eqx.nn.Sequential`` calls its layers; it is ignored.
        """
        shifted = x + self.bias
        return shifted
994+
995+
964996
def make_flow_scan(
965997
key,
966998
n_dim,
@@ -984,15 +1016,47 @@ def make_flow_scan(
9841016
if nn_depth is None:
9851017
nn_depth = 1
9861018

1019+
def make_transformer():
    """Build the element-wise transformer used by the coupling layers.

    For each location offset (currently only ``0.0``) an ``AsymmetricAffine``
    is created, its ``scale`` and ``theta`` nodes are replaced by
    ``Parameterize`` wrappers around ``x + sqrt(1 + x**2)`` with the raw
    value initialised to zero, and the result is wrapped in
    ``bijections.Invert``.

    Returns:
        The single inverted affine when exactly one element was built,
        otherwise a ``bijections.Chain`` of all elements.
    """
    # NOTE: a standalone ``bijections.Loc(jnp.zeros(()))`` element used to be
    # prepended here; it is currently disabled.

    def positive_parameter():
        # A fresh wrapper (and a fresh lambda) for every call, exactly as if
        # it were written out inline at each use.
        return Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))

    def build_element(offset):
        scale_param = positive_parameter()
        theta_param = positive_parameter()

        base = AsymmetricAffine(
            jnp.zeros(()) + offset,
            jnp.ones(()),
            jnp.ones(()),
        )
        # Swap the constructor-made nodes for the wrappers, scale first.
        with_scale = eqx.tree_at(
            where=lambda aff: aff.scale,
            pytree=base,
            replace=scale_param,
        )
        with_both = eqx.tree_at(
            where=lambda aff: aff.theta,
            pytree=with_scale,
            replace=theta_param,
        )
        return bijections.Invert(with_both)

    elements = [build_element(offset) for offset in (0.0,)]

    if len(elements) != 1:
        return bijections.Chain(elements)
    return elements[0]
1049+
9871050
# Just to get at the size
988-
transformer = AsymmetricAffine()
1051+
transformer = make_transformer()
9891052
size = MaskedCoupling.conditioner_output_size(dim, transformer)
9901053

9911054
key, key1 = jax.random.split(key)
9921055
embed = eqx.nn.Sequential(
9931056
[
9941057
eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32),
995-
eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32),
1058+
# Activation(_NN_ACTIVATION),
1059+
# eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32),
9961060
]
9971061
)
9981062
key, key1 = jax.random.split(key)
@@ -1005,37 +1069,15 @@ def make_flow_scan(
10051069
for i in range(len(mask)):
10061070
mask[i, order[i, : counts[i]]] = True
10071071

1008-
def make_transformer():
1009-
scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
1010-
theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
1011-
1012-
affine = AsymmetricAffine(
1013-
jnp.zeros(()),
1014-
jnp.ones(()),
1015-
jnp.ones(()),
1016-
)
1017-
1018-
affine = eqx.tree_at(
1019-
where=lambda aff: aff.scale,
1020-
pytree=affine,
1021-
replace=scale,
1022-
)
1023-
affine = eqx.tree_at(
1024-
where=lambda aff: aff.theta,
1025-
pytree=affine,
1026-
replace=theta,
1027-
)
1028-
1029-
return bijections.Invert(affine)
1030-
10311072
def make_mvscale(key, n_dim):
    """Create an ``MvScale`` from a random vector normalised to unit length.

    Args:
        key: JAX PRNG key used to draw the vector.
        n_dim: Length of the parameter vector.

    Returns:
        ``MvScale`` built from a standard-normal draw of shape ``(n_dim,)``
        divided by its norm.
    """
    direction = jax.random.normal(key, (n_dim,))
    unit_direction = direction / jnp.linalg.norm(direction)
    return MvScale(unit_direction)
10351076

10361077
def make_layer(key, mask, embed, embed_back):
1037-
key1, key2, key3, key4 = jax.random.split(key, 4)
1078+
key1, key2, key3, key4, key5 = jax.random.split(key, 5)
10381079
transformer = make_transformer()
1080+
bias = Add(jax.random.normal(key5, (size,)) * 0.01)
10391081

10401082
conditioner = eqx.nn.Sequential(
10411083
[
@@ -1049,7 +1091,12 @@ def make_layer(key, mask, embed, embed_back):
10491091
dtype=jnp.float32,
10501092
activation=_NN_ACTIVATION,
10511093
),
1052-
embed_back,
1094+
eqx.nn.Sequential(
1095+
[
1096+
embed_back,
1097+
bias,
1098+
]
1099+
),
10531100
]
10541101
)
10551102

@@ -1083,7 +1130,7 @@ def make_layer(key, mask, embed, embed_back):
10831130
replace=None,
10841131
)
10851132
out_axes = eqx.tree_at(
1086-
lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1],
1133+
lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0],
10871134
pytree=out_axes,
10881135
replace=None,
10891136
)
@@ -1100,7 +1147,7 @@ def make_layer(key, mask, embed, embed_back):
11001147
replace=False,
11011148
)
11021149
vectorize = eqx.tree_at(
1103-
lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1],
1150+
lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0],
11041151
pytree=vectorize,
11051152
replace=False,
11061153
)
@@ -1234,10 +1281,6 @@ def make_flow(
12341281
return
12351282

12361283
n_draws, n_dim = positions.shape
1237-
1238-
if n_dim < 2:
1239-
n_layers = 0
1240-
12411284
assert positions.shape == gradients.shape
12421285

12431286
if n_draws == 0:

python/nutpie/transform_adapter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import optax
3030
from paramax import unwrap, NonTrainable
3131

32-
from nutpie.normalizing_flow import Coupling, extend_flow, make_flow
32+
from nutpie.normalizing_flow import Coupling, Scan, extend_flow, make_flow
3333
import tqdm
3434

3535
_BIJECTION_TRACE = []
@@ -241,6 +241,8 @@ def inner(bijection, y, y_grad, y_logp):
241241
axis_size=bijection.axis_size,
242242
)(bijection.bijection, draw, grad, jnp.zeros(()))
243243
return y, y_grad, jnp.sum(log_det) + logp
244+
elif isinstance(bijection, Scan):
245+
return bijection.inverse_gradient_and_val(draw, grad, logp)
244246
elif isinstance(bijection, bijections.Sandwich):
245247
draw, grad, logp = inverse_gradient_and_val(
246248
bijections.Invert(bijection.outer), draw, grad, logp
@@ -880,8 +882,8 @@ def make_transform_adapter(
880882
show_progress=False,
881883
nn_depth=None,
882884
nn_width=None,
883-
num_layers=9,
884-
num_diag_windows=9,
885+
num_layers=20,
886+
num_diag_windows=6,
885887
learning_rate=5e-4,
886888
untransformed_dim=None,
887889
zero_init=True,

tests/test_pymc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,17 @@ def test_normalizing_flow(kind):
278278
compiled = nutpie.compile_pymc_model(
279279
model, backend="jax", gradient_backend="jax"
280280
).with_transform_adapt(
281-
num_diag_windows=6,
282281
verbose=True,
283282
coupling_type=kind,
283+
num_layers=2,
284284
)
285285
trace = nutpie.sample(
286286
compiled,
287287
chains=1,
288288
transform_adapt=True,
289-
window_switch_freq=150,
290-
tune=600,
289+
window_switch_freq=128,
291290
seed=1,
291+
draws=2000,
292292
)
293293
draws = trace.posterior.x.isel(x_dim_0=0, chain=0)
294294
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
@@ -309,17 +309,17 @@ def test_normalizing_flow_1d(kind):
309309
compiled = nutpie.compile_pymc_model(
310310
model, backend="jax", gradient_backend="jax"
311311
).with_transform_adapt(
312-
num_diag_windows=6,
313312
verbose=True,
314313
coupling_type=kind,
314+
num_layers=2,
315315
)
316316
trace = nutpie.sample(
317317
compiled,
318318
chains=1,
319319
transform_adapt=True,
320-
window_switch_freq=150,
321-
tune=600,
320+
window_switch_freq=128,
322321
seed=1,
322+
draws=2000,
323323
)
324324
draws = trace.posterior.x.isel(chain=0)
325325
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)

0 commit comments

Comments
 (0)