diff --git a/s2fft/precompute_transforms/custom_ops.py b/s2fft/precompute_transforms/custom_ops.py index b6a84da3..03cfe63f 100644 --- a/s2fft/precompute_transforms/custom_ops.py +++ b/s2fft/precompute_transforms/custom_ops.py @@ -86,7 +86,7 @@ def wigner_subset_to_s2( return np.fft.ifft(x, axis=-2, norm="forward") -@partial(jit, static_argnums=(3, 4)) +# @partial(jit, static_argnums=(3, 4)) def wigner_subset_to_s2_jax( flmn: jnp.ndarray, spins: jnp.ndarray, @@ -209,7 +209,7 @@ def so3_to_wigner_subset( return s2_to_wigner_subset(x, spins, DW, L, sampling) -@partial(jit, static_argnums=(3, 4, 5)) +# @partial(jit, static_argnums=(3, 4, 5)) def so3_to_wigner_subset_jax( f: jnp.ndarray, spins: jnp.ndarray, @@ -338,7 +338,7 @@ def s2_to_wigner_subset( return x * (2.0 * np.pi) ** 2 -@partial(jit, static_argnums=(3, 4)) +# @partial(jit, static_argnums=(3, 4)) def s2_to_wigner_subset_jax( fs: jnp.ndarray, spins: jnp.ndarray,