You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add a new subsection under "JAX Support" documenting the four classes
of un-threaded `xp` sites that bit analysis-ellipse-jax (PR #412):
@Property chains, inherited methods, convert.py helpers, and
@cached_property on traced arrays.
Also add the validation rule: `jax.jit(fn)(concrete_instance)` is NOT
a sufficient JAX trace check — use `fitness._vmap(jnp.array(params))`
to force tracer propagation. PR #412 needed 5 follow-up fix commits
after the initial ship because the original verification only ran
jit-on-concrete.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy file name to clipboardExpand all lines: CLAUDE.md
+38-23Lines changed: 38 additions & 23 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -147,6 +147,21 @@ When adding a new function that should support JAX:
147
147
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
148
148
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
149
149
150
+
### Threading `xp` through inherited / property / convert-helper calls
151
+
152
+
Adding `xp=np` to a method body and swapping `np.*` for `xp.*` is **only half the work**. Every nested call inside that body — whether to `self.X()`, `obj.X()`, a helper in `convert.py`, an inherited `@property`, or a sibling method — must also receive `xp=xp` if it can route to numpy operations on what would otherwise be JAX tracers. Otherwise the inner call quietly defaults to `xp=np` and fails when a tracer reaches an `np.*` op.
153
+
154
+
Concrete sites that have bitten this codebase (all fixed during prompt 7 of `ellipse_fitting_jax`):
155
+
156
+
-**`@property` chains that hardcode `np`.**`Ellipse.ellipticity` and `Ellipse.minor_axis` are properties (no kwargs possible), so a caller in an xp-aware method must either inline the computation under `if xp is not np:` or convert the property to a method. Pattern: read every `@property` you call from xp-aware code; if it does `np.sqrt(...)` on `self.ell_comps`, it's a hazard.
157
+
-**Inherited methods.**`Ellipse.angle(xp=np)` and `Ellipse.angle_radians(xp=np)` accept `xp` (they're defined on `EllProfile`), but call sites in `EllipseMultipole.points_perturbed_from` used `ellipse.angle()` without passing it. Pattern: grep within xp-aware functions for `self.X(` and `obj.X(`; for each, verify `xp=xp` is passed.
158
+
-**`convert.py` helpers.**`multipole_comps_from`, `multipole_k_m_and_phi_m_from`, `axis_ratio_and_angle_from`, `angle_from` etc. all take `xp=np`; call sites must thread it. They also use Python `&` on JAX bool tracers, which silently calls `__array__()` — replace with `xp.logical_and`.
159
+
-**`@cached_property` on traced arrays.** Caches a tracer in `self.__dict__` which is invalid across `vmap` batch elements (different batches share the cache). Use plain `@property` for any property whose value depends on JAX-traced inputs.
160
+
161
+
**Validation: `jax.jit(fn)(concrete_instance)` is NOT a sufficient JAX trace check.** A `ModelInstance` with concrete float `ell_comps` propagates as floats through `np.*` ops without raising — the bug stays hidden. **Use `jax.vmap(fitness)(jnp.array(params))` instead** (or `Fitness._vmap` on autofit's wrapper). Vmap forces tracer propagation through every leaf and exposes un-threaded `xp` sites.
162
+
163
+
When adding a JAX path to an Analysis class, the workspace_test parity script must include both a `jax.jit(analysis.fit_from)(instance)` round-trip AND a `fitness._vmap(parameters)` batch evaluation. See `autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py` and the ellipse counterpart for the template.
164
+
150
165
### JAX and autoarray wrappers at the `jax.jit` boundary
151
166
152
167
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
0 commit comments