Skip to content

Fix jacrev/jacfwd varying basis construction for JAX 0.10#43

Open
dalin27 wants to merge 1 commit intomicrosoft:mainfrom
dalin27:dl/fix-jax-010-pvary-pcast
Open

Fix jacrev/jacfwd varying basis construction for JAX 0.10#43
dalin27 wants to merge 1 commit intomicrosoft:mainfrom
dalin27:dl/fix-jax-010-pvary-pcast

Conversation

@dalin27
Copy link
Copy Markdown

@dalin27 dalin27 commented May 1, 2026

jax.lax.pvary was deprecated in JAX 0.8.2 and removed from the public jax.lax API in JAX 0.10.0. JAX 0.10 also removed ShapedArray.vma, replacing it with manual_axis_type.varying.

The new helpers preserve compatibility by:

  • using manual_axis_type.varying when available;
  • falling back to vma on older JAX versions;
  • using jax.lax.pcast(..., to="varying") when available;
  • falling back to jax.lax.pvary(...) on older JAX versions.

@dalin27
Copy link
Copy Markdown
Author

dalin27 commented May 1, 2026

@microsoft-github-policy-service agree

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant