Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions folx/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,30 @@ def is_tree_complex(tree):
return any(jnp.iscomplexobj(leaf) for leaf in leaves)


def _varying_axes(x: jax.Array) -> tuple:
if not hasattr(jax, 'typeof'):
return ()

typ = jax.typeof(x)
manual_axis_type = getattr(typ, 'manual_axis_type', None)
if manual_axis_type is not None:
return tuple(getattr(manual_axis_type, 'varying', ()))

return tuple(getattr(typ, 'vma', ()))


def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array:
axes = _varying_axes(like)
if not axes:
return x

if hasattr(jax.lax, 'pcast'):
return jax.lax.pcast(x, axes, to='varying')
if hasattr(jax.lax, 'pvary'):
return jax.lax.pvary(x, axes)
return x


def vjp_rc(fun, *primals: jax.Array):
def real_fun(*primals):
return jnp.real(fun(*primals))
Expand Down Expand Up @@ -72,8 +96,7 @@ def flat_f(x):
out = flat_f(flat_primals)

eye = jnp.eye(out.size, dtype=out.dtype)
if hasattr(jax.lax, 'pvary'):
eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
eye = _mark_varying_like(eye, out)
result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
result = jax.vmap(unravel, out_axes=0)(result)
if len(primals) == 1:
Expand All @@ -94,8 +117,7 @@ def jvp_fun(s):
return jax.jvp(f, primals, unravel(s))[1]

eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
if hasattr(jax.lax, 'pvary'):
eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma))
eye = _mark_varying_like(eye, flat_primals)
J = jax.vmap(jvp_fun, out_axes=-1)(eye)
return J

Expand Down