Skip to content

Add JAX 0.8+ and 0.9+ compatibility#208

Merged
mhliu0001 merged 6 commits intomasterfrom
update_jax
Feb 21, 2026
Merged

Add JAX 0.8+ and 0.9+ compatibility#208
mhliu0001 merged 6 commits intomasterfrom
update_jax

Conversation

@mhliu0001
Copy link
Collaborator

@mhliu0001 mhliu0001 commented Jan 14, 2026

This PR updates appletree to be compatible with JAX >= 0.8.0 while preserving backward compatibility with older supported JAX versions.

Summary

  • Remove the JAX version upper bound (<0.8.0 → *) to support JAX 0.8.x and 0.9.x
  • Replace all deprecated jnp.clip(x, a_min=..., a_max=...) keyword arguments with positional arguments, which work across all JAX versions
  • Make the numpyro import in randgen.py conditional -- it is only needed on JAX < 0.4.21 (which lacks jax.random.binomial), so it no longer blocks import on JAX 0.9, where numpyro may not yet be compatible
  • Use jax.default_backend() for platform detection with a fallback for older JAX versions

Testing

Tested for JAX 0.7.2, 0.8.3, 0.9.0.1.

@mhliu0001 mhliu0001 requested a review from dachengx January 14, 2026 18:18
@github-actions
Copy link

Pull Request Test Coverage Report for Build 21005071564

Details

  • 14 of 17 (82.35%) changed or added relevant lines in 3 files are covered.
  • 37 unchanged lines in 2 files lost coverage.
  • Overall coverage decreased (-0.05%) to 84.065%

Changes Missing Coverage Covered Lines Changed/Added Lines %
appletree/config.py 0 1 0.0%
appletree/utils.py 4 6 66.67%
Files with Coverage Reduction New Missed Lines %
appletree/interpolation.py 15 72.05%
appletree/utils.py 22 62.19%
Totals Coverage Status
Change from base Build 18637300865: -0.05%
Covered Lines: 2432
Relevant Lines: 2893

💛 - Coveralls

@coveralls
Copy link

coveralls commented Jan 14, 2026

Pull Request Test Coverage Report for Build 22245185836

Details

  • 20 of 24 (83.33%) changed or added relevant lines in 5 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage decreased (-0.09%) to 84.191%

Changes Missing Coverage Covered Lines Changed/Added Lines %
appletree/config.py 0 1 0.0%
appletree/utils.py 4 5 80.0%
appletree/randgen.py 1 3 33.33%
Files with Coverage Reduction New Missed Lines %
appletree/utils.py 1 62.19%
Totals Coverage Status
Change from base Build 21965575783: -0.09%
Covered Lines: 2471
Relevant Lines: 2935

💛 - Coveralls

@mhliu0001 mhliu0001 changed the title Add support for JAX v0.8 Add JAX 0.8+ and 0.9+ compatibility Feb 21, 2026
@mhliu0001
Copy link
Collaborator Author

The code factor issues will be addressed in a separate PR.

@mhliu0001 mhliu0001 merged commit ef88ba6 into master Feb 21, 2026
6 of 7 checks passed
@mhliu0001 mhliu0001 deleted the update_jax branch February 21, 2026 00:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants