@@ -579,6 +579,28 @@ def step(carry, bijection):
579579 )
580580 return y , log_det
581581
    def inverse_gradient_and_val(
        self,
        y: Array,
        y_grad: Array,
        y_logp: Array,
        condition: Array | None = None,
    ) -> tuple[Array, Array, Array]:
        """Pull (value, gradient, log-density) back through the whole chain.

        Scans over ``self.bijection`` in reverse, applying the free-function
        ``inverse_gradient_and_val`` of each constituent bijection to the
        running ``(y, y_grad, y_logp)`` carry.

        Parameters
        ----------
        y : Array
            Point in the transformed (output) space.
        y_grad : Array
            Gradient carried alongside ``y`` (presumably of the log-density
            w.r.t. ``y`` — TODO confirm against nutpie.transform_adapter).
        y_logp : Array
            Log-density value carried alongside ``y``.
        condition : Array | None
            Unused here; kept for signature compatibility with the other
            bijection methods in this class.

        Returns
        -------
        tuple[Array, Array, Array]
            The carry after all bijections have been inverted.
        """

        def step(carry, bijection):
            # Local import avoids a circular dependency between this module
            # and nutpie.transform_adapter at import time.
            from nutpie.transform_adapter import inverse_gradient_and_val

            carry = inverse_gradient_and_val(bijection, *carry)
            # Scan expects (carry, per-step output); we collect no outputs.
            return (carry, None)

        # reverse=True: invert the chain back-to-front, mirroring the
        # forward scan used elsewhere in this class.
        (y, y_grad, y_logp), _ = _filter_scan(
            step,
            (y, y_grad, y_logp),
            self.bijection,
            reverse=True,
            filter_spec=self.filter_spec,
        )
        return y, y_grad, y_logp
603+
582604 @property
583605 def shape (self ):
584606 return self .bijection .shape
@@ -961,6 +983,16 @@ def make_mlp(out_size):
961983 )
962984
963985
class Add(eqx.Module):
    """Equinox module that adds a learned, fixed-shape bias to its input.

    Used to append a trainable offset after the ``embed_back`` projection in
    the coupling-layer conditioner.
    """

    # Trainable offset; broadcast against the input on every call.
    bias: Array

    def __init__(self, bias):
        self.bias = bias

    def __call__(self, x: Array, *, key=None) -> Array:
        # ``key`` is accepted (and ignored) so the module composes with
        # eqx.nn.Sequential, which threads a PRNG key through each layer.
        shifted = x + self.bias
        return shifted
995+
964996def make_flow_scan (
965997 key ,
966998 n_dim ,
@@ -984,15 +1016,47 @@ def make_flow_scan(
9841016 if nn_depth is None :
9851017 nn_depth = 1
9861018
    def make_transformer():
        """Build the elementwise transformer used inside each coupling layer.

        Constructs one (or, if the loop is extended, several) inverted
        AsymmetricAffine bijections whose scale/theta parameters are wrapped
        in ``Parameterize`` so the raw parameters live on an unconstrained
        scale.
        """
        elemwises = []
        # loc = bijections.Loc(jnp.zeros(()))
        # elemwises.append(loc)

        # Single fixed location for now; the loop makes it easy to stack
        # several affine components with different locations later.
        for loc in [0.0]:
            # x + sqrt(1 + x^2) maps R -> R_+ smoothly and equals 1 at x=0,
            # so the bijection starts out as (approximately) the identity.
            scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
            theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))

            affine = AsymmetricAffine(
                jnp.zeros(()) + loc,
                jnp.ones(()),
                jnp.ones(()),
            )

            # Swap the plain scale/theta leaves for the unconstrained
            # parameterizations; tree_at replaces the leaf in-place in a
            # new pytree without touching anything else.
            affine = eqx.tree_at(
                where=lambda aff: aff.scale,
                pytree=affine,
                replace=scale,
            )
            affine = eqx.tree_at(
                where=lambda aff: aff.theta,
                pytree=affine,
                replace=theta,
            )
            elemwises.append(bijections.Invert(affine))

        # Avoid a needless Chain wrapper for the common single-element case.
        if len(elemwises) == 1:
            return elemwises[0]
        return bijections.Chain(elemwises)
1049+
9871050 # Just to get at the size
988- transformer = AsymmetricAffine ()
1051+ transformer = make_transformer ()
9891052 size = MaskedCoupling .conditioner_output_size (dim , transformer )
9901053
9911054 key , key1 = jax .random .split (key )
9921055 embed = eqx .nn .Sequential (
9931056 [
9941057 eqx .nn .Linear (dim , n_embed , key = key1 , dtype = jnp .float32 ),
995- eqx .nn .LayerNorm (shape = (n_embed ,), dtype = jnp .float32 ),
1058+ # Activation(_NN_ACTIVATION),
1059+ # eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32),
9961060 ]
9971061 )
9981062 key , key1 = jax .random .split (key )
@@ -1005,37 +1069,15 @@ def make_flow_scan(
10051069 for i in range (len (mask )):
10061070 mask [i , order [i , : counts [i ]]] = True
10071071
1008- def make_transformer ():
1009- scale = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
1010- theta = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
1011-
1012- affine = AsymmetricAffine (
1013- jnp .zeros (()),
1014- jnp .ones (()),
1015- jnp .ones (()),
1016- )
1017-
1018- affine = eqx .tree_at (
1019- where = lambda aff : aff .scale ,
1020- pytree = affine ,
1021- replace = scale ,
1022- )
1023- affine = eqx .tree_at (
1024- where = lambda aff : aff .theta ,
1025- pytree = affine ,
1026- replace = theta ,
1027- )
1028-
1029- return bijections .Invert (affine )
1030-
10311072 def make_mvscale (key , n_dim ):
10321073 params = jax .random .normal (key , (n_dim ,))
10331074 params = params / jnp .linalg .norm (params )
10341075 return MvScale (params )
10351076
10361077 def make_layer (key , mask , embed , embed_back ):
1037- key1 , key2 , key3 , key4 = jax .random .split (key , 4 )
1078+ key1 , key2 , key3 , key4 , key5 = jax .random .split (key , 5 )
10381079 transformer = make_transformer ()
1080+ bias = Add (jax .random .normal (key5 , (size ,)) * 0.01 )
10391081
10401082 conditioner = eqx .nn .Sequential (
10411083 [
@@ -1049,7 +1091,12 @@ def make_layer(key, mask, embed, embed_back):
10491091 dtype = jnp .float32 ,
10501092 activation = _NN_ACTIVATION ,
10511093 ),
1052- embed_back ,
1094+ eqx .nn .Sequential (
1095+ [
1096+ embed_back ,
1097+ bias ,
1098+ ]
1099+ ),
10531100 ]
10541101 )
10551102
@@ -1083,7 +1130,7 @@ def make_layer(key, mask, embed, embed_back):
10831130 replace = None ,
10841131 )
10851132 out_axes = eqx .tree_at (
1086- lambda tree : tree .bijections [0 ].conditioner .layers [1 ].layers [- 1 ],
1133+ lambda tree : tree .bijections [0 ].conditioner .layers [1 ].layers [- 1 ]. layers [ 0 ] ,
10871134 pytree = out_axes ,
10881135 replace = None ,
10891136 )
@@ -1100,7 +1147,7 @@ def make_layer(key, mask, embed, embed_back):
11001147 replace = False ,
11011148 )
11021149 vectorize = eqx .tree_at (
1103- lambda tree : tree .bijections [0 ].conditioner .layers [1 ].layers [- 1 ],
1150+ lambda tree : tree .bijections [0 ].conditioner .layers [1 ].layers [- 1 ]. layers [ 0 ] ,
11041151 pytree = vectorize ,
11051152 replace = False ,
11061153 )
@@ -1234,10 +1281,6 @@ def make_flow(
12341281 return
12351282
12361283 n_draws , n_dim = positions .shape
1237-
1238- if n_dim < 2 :
1239- n_layers = 0
1240-
12411284 assert positions .shape == gradients .shape
12421285
12431286 if n_draws == 0 :