diff --git a/folx/ad.py b/folx/ad.py
index 6e06fc0..0b39b2c 100644
--- a/folx/ad.py
+++ b/folx/ad.py
@@ -9,6 +9,30 @@ def is_tree_complex(tree):
     return any(jnp.iscomplexobj(leaf) for leaf in leaves)
 
 
+def _varying_axes(x: jax.Array) -> tuple:
+    if not hasattr(jax, 'typeof'):
+        return ()
+
+    typ = jax.typeof(x)
+    manual_axis_type = getattr(typ, 'manual_axis_type', None)
+    if manual_axis_type is not None:
+        return tuple(getattr(manual_axis_type, 'varying', ()))
+
+    return tuple(getattr(typ, 'vma', ()))
+
+
+def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array:
+    axes = _varying_axes(like)
+    if not axes:
+        return x
+
+    if hasattr(jax.lax, 'pcast'):
+        return jax.lax.pcast(x, axes, to='varying')
+    if hasattr(jax.lax, 'pvary'):
+        return jax.lax.pvary(x, axes)
+    return x
+
+
 def vjp_rc(fun, *primals: jax.Array):
     def real_fun(*primals):
         return jnp.real(fun(*primals))
@@ -72,8 +96,7 @@ def flat_f(x):
 
         out = flat_f(flat_primals)
         eye = jnp.eye(out.size, dtype=out.dtype)
-        if hasattr(jax.lax, 'pvary'):
-            eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
+        eye = _mark_varying_like(eye, out)
         result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
         result = jax.vmap(unravel, out_axes=0)(result)
         if len(primals) == 1:
@@ -94,8 +117,7 @@ def jvp_fun(s):
             return jax.jvp(f, primals, unravel(s))[1]
 
         eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
-        if hasattr(jax.lax, 'pvary'):
-            eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma))
+        eye = _mark_varying_like(eye, flat_primals)
         J = jax.vmap(jvp_fun, out_axes=-1)(eye)
         return J
 