Skip to content

test(jax): consolidate scripts/jax_assertions/ subdir for PyAutoFit#21

Merged
Jammy2211 merged 1 commit intomainfrom
feature/jax-assertions-sweep
May 1, 2026
Merged

test(jax): consolidate scripts/jax_assertions/ subdir for PyAutoFit#21
Jammy2211 merged 1 commit intomainfrom
feature/jax-assertions-sweep

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Reorganize jax assertions into a dedicated scripts/jax_assertions/ subdir mirroring autolens_workspace_test/scripts/jax_assertions/. Adds 3 new scripts holding 20 assertions previously living as pytest tests in PyAutoFit/test_autofit/.

Changes

  • git mv scripts/graphical/jax_assertions.pyscripts/jax_assertions/fitness_dispatch.py (preserves the 5 existing Fitness/JIT/pickle/visualization assertions added in 1caf713)
  • 3 new scripts:
    Script Source What it asserts
    pytrees.py test_autofit/jax/test_pytrees.py Manual jax.tree_util.register_pytree_node_class roundtrips for Prior, Model, Collection, ModelInstance
    enable_pytrees.py test_autofit/jax/test_enable_pytrees.py Public autofit.jax.enable_pytrees / register_model API: roundtrip, JIT compile, constant static-routing, kwarg static, TuplePrior gradient flow, idempotency, Collection roundtrip
    nested.py test_autofit/graphical/functionality/test_nested.py autofit.graphical.utils.nested_* vs jax.tree_util parity (get/set, ordering, items/paths, filter, map, update preserving NamedTuple)

Companion PR

Pairs with PyAutoFit#new-pr, which deletes the source unit tests. Merge this PR first so the assertions exist on main before the source tests are removed.

Test plan

  • Each new script runs successfully under python scripts/jax_assertions/<name>.py (locally verified with the cuda_plugin_extension is not found JAX warning being benign)
  • Moved fitness_dispatch.py runs successfully (no behavior change from the path move)

Reorganize jax assertions into a dedicated subdir mirroring the
autolens_workspace_test/scripts/jax_assertions/ pattern.

- git mv scripts/graphical/jax_assertions.py to
  scripts/jax_assertions/fitness_dispatch.py (preserves the 5 existing
  Fitness/JIT/pickle/visualization assertions added in 1caf713).
- 3 new scripts holding assertions that previously lived as pytest tests
  in PyAutoFit/test_autofit/:
    - pytrees.py: 5 manual jax pytree roundtrip checks for Prior, Model,
      Collection, ModelInstance via the legacy direct-registration path.
    - enable_pytrees.py: 8 checks for the public autofit.jax.enable_pytrees
      / register_model API (model roundtrip, JIT compile, constant
      static-routing, kwarg-constant static, TuplePrior gradient flow,
      idempotency, Collection roundtrip).
    - nested.py: 8 checks comparing autofit.graphical.utils.nested_*
      utilities against jax.tree_util as reference (nested_get/set,
      ordering, items/paths, filter, map, update preserving NamedTuple).

Each script ends with a `print("<name>: all assertions passed")` marker.
All 4 scripts run successfully under PYAUTO_TEST_MODE-free conditions.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@Jammy2211 Jammy2211 merged commit 8bd9991 into main May 1, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax-assertions-sweep branch May 1, 2026 08:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant