Skip to content

Propagate xp through Grid2DIrregular.grid_2d_via_deflection_grid_from#287

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/grid-irregular-xp-propagation
Apr 18, 2026
Merged

Propagate xp through Grid2DIrregular.grid_2d_via_deflection_grid_from#287
Jammy2211 merged 1 commit into
mainfrom
feature/grid-irregular-xp-propagation

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Grid2DIrregular.grid_2d_via_deflection_grid_from previously constructed the result grid without propagating xp=self._xp, so a JAX-backed receiver produced a numpy-backed result whose values were still JAX tracers. Downstream self._xp.square(...) then called np.square on a tracer and raised TracerArrayConversionError, blocking JIT of point-source pipelines (see #286).

This one-line change passes xp=self._xp so the new grid inherits the receiver's backend, and adds a unit test covering both the np and jnp round-trips.

API Changes

None — internal change only. Public signature unchanged; behaviour is only affected when the receiver's _xp is not numpy (in which case previously the result silently downgraded to numpy).
See full details below.

Test Plan

  • Full PyAutoArray test suite passes (733 passed)
  • New test__grid_2d_via_deflection_grid_from__propagates_xp covers numpy + JAX round-trips
  • Downstream PyAutoLens point-source JIT pipeline traces end-to-end (was TracerArrayConversionError before, see stacked PR on PyAutoLens)
Full API Changes (for automation & release notes)

Removed

None.

Added

None (unit test added, no public API additions).

Renamed

None.

Changed Signature

None.

Changed Behaviour

  • Grid2DIrregular.grid_2d_via_deflection_grid_from — the returned grid now has _xp matching the receiver's _xp. Previously always defaulted to numpy, silently downgrading JAX-backed callers.

Migration

None — callers that were numpy-backed are unchanged. Callers that were JAX-backed and worked around the downgrade (e.g. by manually re-wrapping the result with xp=jnp) can drop the workaround.

Follows up: #286

🤖 Generated with Claude Code

Previously constructed the new Grid2DIrregular without passing xp, so the
resulting grid defaulted to _xp=np even when called on a JAX-backed
receiver. Downstream calls to xp.square on the values (which were JAX
tracers under JIT) raised TracerArrayConversionError.

Pass xp=self._xp so the result inherits the receiver's backend.

Co-Authored-By: Claude Opus 4.7 <[email protected]>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

@Jammy2211 Jammy2211 merged commit 1897037 into main Apr 18, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/grid-irregular-xp-propagation branch April 18, 2026 20:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant