bug: jit/imaging/pixelization.py — JIT vs eager mapping matrix shape mismatch #68

@Jammy2211

Summary

On clean main (autolens_workspace_developer@4e99bf9, PyAutoLens 2026.5.14.2), jax_profiling/jit/imaging/pixelization.py crashes at the regularized-reconstruction step with a JAX broadcasting error. The JIT-traced blurred_mapping_matrix has shape (15361, 1285) whereas the eager path produces (15361, 1225), so the downstream curvature matrix (1285×1285) cannot be added to the regularization matrix (1225×1225).

Surfaced by the autolens_profiling Phase 1 follow-up smoke run (issue #67). Split out of that triage because it is a pre-existing JIT-vs-eager bug, independent of the point_source regression drift.

Reproduction

cd autolens_workspace_developer
source ../activate.sh  # or set NUMBA_CACHE_DIR / MPLCONFIGDIR / PYTHONPATH manually
python jax_profiling/jit/imaging/pixelization.py

Output (relevant tail)

  blurred_mapping_matrix (JIT) shape: (15361, 1285)
  blurred_mapping_matrix shape: (15361, 1225)
...
  curvature_matrix shape: (1285, 1285)
...
  regularization_matrix shape: (1225, 1225)

--- Step 12: Regularized reconstruction ---
Traceback (most recent call last):
  File ".../jax_profiling/jit/imaging/pixelization.py", line 660, in <module>
    reconstruction = compute_reconstruction(
  File ".../jax_profiling/jit/imaging/pixelization.py", line 652, in compute_reconstruction
    curvature_reg_matrix = curvature_matrix + regularization_matrix
TypeError: add got incompatible shapes for broadcasting: (1285, 1285), (1225, 1225).
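
The failure can be reproduced in isolation with zero-filled stand-ins for the two matrices. This is a minimal sketch, not the script's code. It uses NumPy rather than JAX (NumPy raises ValueError where JAX raises the TypeError shown above), and it assumes the curvature matrix has one row and one column per column of the blurred mapping matrix, which matches the shapes printed above. The function name and the reduced row count are illustrative.

```python
import numpy as np


def curvature_shape_check(n_source_jit=1285, n_source_eager=1225, n_rows=8):
    """Build stand-in matrices with the reported column counts and try to add them."""
    # Stand-in for the JIT-path blurred mapping matrix; only the column count matters here,
    # so a small row count is used instead of the 15361 image pixels in the report.
    mapping_jit = np.zeros((n_rows, n_source_jit))
    # Assumption: curvature is square in the number of mapping-matrix columns.
    curvature = mapping_jit.T @ mapping_jit
    # Stand-in for the regularization matrix sized from the eager-path pixel count.
    regularization = np.zeros((n_source_eager, n_source_eager))
    try:
        curvature + regularization
    except ValueError as exc:
        return curvature.shape, regularization.shape, str(exc)
    return curvature.shape, regularization.shape, None


curv_shape, reg_shape, error = curvature_shape_check()
print(curv_shape, reg_shape)
print(error)
```

With equal pixel counts on both paths the addition succeeds and the third return value is None, so the same helper can serve as a quick sanity check once the mismatch is fixed.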

Diagnosis hint

  • 1225 = 35×35, likely the native mesh shape of the rectangular pixelization.
  • 1285 - 1225 = 60. Possibly 60 extra source pixels from over-sampling / mesh expansion in the JIT path that the eager path doesn't apply.
  • The blurred_mapping_matrix is being computed with a different mesh shape on the JIT path vs the eager path, even though both should be reading the same Pixelization model.
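
The arithmetic behind these hints can be checked directly. The sketch below lists the integer factor pairs of both pixel counts; the helper name factor_pairs is illustrative. 1225 factors as 35×35, whereas 1285 = 5×257 has no near-square factorization. That would be consistent with (but does not prove) extra pixels being appended to a 35×35 mesh, rather than a different rectangular mesh being built.

```python
import math


def factor_pairs(n):
    """Return all (a, b) with a <= b and a * b == n."""
    return [(d, n // d) for d in range(1, math.isqrt(n) + 1) if n % d == 0]


eager_pixels = 1225  # column count on the eager path
jit_pixels = 1285    # column count on the JIT path

eager_pairs = factor_pairs(eager_pixels)
jit_pairs = factor_pairs(jit_pixels)
extra_pixels = jit_pixels - eager_pixels

print(eager_pairs)
print(jit_pairs)
print(extra_pixels)
```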

Likely culprits to investigate (somewhere in PyAutoArray / PyAutoGalaxy):

  • Mapper-side mesh construction under xp=jnp.
  • Pixelization source-mesh shape derivation when xp=jnp vs xp=np.
  • Over-sampling / image-mesh padding applied conditionally.
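
While these candidates are being investigated, a shape guard placed just before the addition could make the mismatch fail with a clearer message. The sketch below is hypothetical and not part of PyAutoArray or PyAutoGalaxy; the function name check_matching_shapes is illustrative. It only reads the .shape attribute, which JAX arrays also expose (shapes are static during tracing), and it is exercised here with NumPy stand-ins.

```python
import numpy as np


def check_matching_shapes(curvature_matrix, regularization_matrix):
    """Raise a descriptive error if the two matrices cannot be added elementwise."""
    c_shape = tuple(curvature_matrix.shape)
    r_shape = tuple(regularization_matrix.shape)
    if c_shape != r_shape:
        raise ValueError(
            f"curvature_matrix {c_shape} and regularization_matrix {r_shape} "
            f"disagree by {c_shape[0] - r_shape[0]} source pixels"
        )


# Stand-in arrays with the shapes from the report.
try:
    check_matching_shapes(np.zeros((1285, 1285)), np.zeros((1225, 1225)))
    guard_message = None
except ValueError as exc:
    guard_message = str(exc)

print(guard_message)
```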

Context

  • All other JIT scripts on clean main behave consistently:

This is the only pixelization (rectangular) JIT script that crashes — interferometer/pixelization works, which is suspicious enough to point at imaging-specific mapper code.
