Commit 71e4ba4

Jammy2211 and claude authored and committed
docs: capture xp-threading pitfalls in CLAUDE.md
Add a new subsection under "JAX Support" documenting the four classes of un-threaded `xp` sites that bit analysis-ellipse-jax (PR #412): @property chains, inherited methods, convert.py helpers, and @cached_property on traced arrays.

Also add the validation rule: `jax.jit(fn)(concrete_instance)` is NOT a sufficient JAX trace check — use `fitness._vmap(jnp.array(params))` to force tracer propagation. PR #412 needed 5 follow-up fix commits after the initial ship because the original verification only ran jit-on-concrete.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b71a609 commit 71e4ba4

1 file changed

Lines changed: 38 additions & 23 deletions

CLAUDE.md

@@ -147,6 +147,21 @@ When adding a new function that should support JAX:
 4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
 5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
 
+### Threading `xp` through inherited / property / convert-helper calls
+
+Adding `xp=np` to a method body and swapping `np.*` for `xp.*` is **only half the work**. Every nested call inside that body — whether to `self.X()`, `obj.X()`, a helper in `convert.py`, an inherited `@property`, or a sibling method — must also receive `xp=xp` if it can route to numpy operations on what would otherwise be JAX tracers. Otherwise the inner call quietly defaults to `xp=np` and fails when a tracer reaches an `np.*` op.
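
The failure mode in the paragraph above can be sketched as follows. This is a minimal illustration, not code from the repository: `axis_ratio_from_sketch`, `minor_axis_unthreaded`, and `minor_axis_threaded` are hypothetical stand-ins, and the formula is only a plausible example. The outer function accepts `xp`, but the un-threaded variant does not forward it, so the helper falls back to `np` and NumPy tries to convert a tracer.

```python
import jax
import jax.numpy as jnp
import numpy as np


def axis_ratio_from_sketch(ell_comps, xp=np):
    # Hypothetical helper in the style of convert.py: defaults to NumPy.
    fac = xp.sqrt(ell_comps[0] ** 2 + ell_comps[1] ** 2)
    return (1.0 - fac) / (1.0 + fac)


def minor_axis_unthreaded(ell_comps, major_axis, xp=np):
    # Hazard: `xp` is accepted here but not forwarded, so the helper uses np.
    return major_axis * axis_ratio_from_sketch(ell_comps)


def minor_axis_threaded(ell_comps, major_axis, xp=np):
    # Fix: forward `xp` so the helper uses the same array namespace.
    return major_axis * axis_ratio_from_sketch(ell_comps, xp=xp)


params = jnp.array([0.0, 0.6])

# Under jit the parameters are tracers, so the hidden np.sqrt call raises.
try:
    jax.jit(lambda p: minor_axis_unthreaded(p, 2.0, xp=jnp))(params)
    unthreaded_raised = False
except jax.errors.TracerArrayConversionError:
    unthreaded_raised = True

# The threaded variant traces cleanly because every op goes through jnp.
threaded_value = jax.jit(lambda p: minor_axis_threaded(p, 2.0, xp=jnp))(params)
```

Both variants return the same number when called eagerly with NumPy inputs, which is why the missing `xp=xp` is easy to overlook in ordinary unit tests.
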
+
+Concrete sites that have bitten this codebase (all fixed during prompt 7 of `ellipse_fitting_jax`):
+
+- **`@property` chains that hardcode `np`.** `Ellipse.ellipticity` and `Ellipse.minor_axis` are properties (no kwargs possible), so a caller in an xp-aware method must either inline the computation under `if xp is not np:` or convert the property to a method. Pattern: read every `@property` you call from xp-aware code; if it does `np.sqrt(...)` on `self.ell_comps`, it's a hazard.
+- **Inherited methods.** `Ellipse.angle(xp=np)` and `Ellipse.angle_radians(xp=np)` accept `xp` (they're defined on `EllProfile`), but call sites in `EllipseMultipole.points_perturbed_from` used `ellipse.angle()` without passing it. Pattern: grep within xp-aware functions for `self.X(` and `obj.X(`; for each, verify `xp=xp` is passed.
+- **`convert.py` helpers.** `multipole_comps_from`, `multipole_k_m_and_phi_m_from`, `axis_ratio_and_angle_from`, `angle_from` etc. all take `xp=np`; call sites must thread it. They also use Python `&` on JAX bool tracers, which silently calls `__array__()` — replace with `xp.logical_and`.
+- **`@cached_property` on traced arrays.** Caches a tracer in `self.__dict__` which is invalid across `vmap` batch elements (different batches share the cache). Use plain `@property` for any property whose value depends on JAX-traced inputs.
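
The `@cached_property` item in the list above can be illustrated with a small sketch. `CachedEllipse`, `PlainEllipse`, and `axis_ratios_for` are hypothetical stand-ins, not the real `Ellipse` class, and the example runs eagerly for simplicity. It shows only the shared-cache half of the problem: once a value is stored in `self.__dict__`, later evaluations with different inputs keep reading it. Under `vmap` or `jit`, the stored object would be a tracer rather than a stale number.

```python
from functools import cached_property

import jax.numpy as jnp


class CachedEllipse:
    # Hypothetical stand-in: the first computed value is stored in __dict__.
    def __init__(self, ell_comps):
        self.ell_comps = ell_comps

    @cached_property
    def axis_ratio(self):
        fac = jnp.sqrt(self.ell_comps[0] ** 2 + self.ell_comps[1] ** 2)
        return (1.0 - fac) / (1.0 + fac)


class PlainEllipse:
    # Same computation, but recomputed on every access.
    def __init__(self, ell_comps):
        self.ell_comps = ell_comps

    @property
    def axis_ratio(self):
        fac = jnp.sqrt(self.ell_comps[0] ** 2 + self.ell_comps[1] ** 2)
        return (1.0 - fac) / (1.0 + fac)


def axis_ratios_for(ellipse, batch):
    # Reuse one object for several parameter vectors, as a shared instance would be.
    values = []
    for ell_comps in batch:
        ellipse.ell_comps = ell_comps
        values.append(float(ellipse.axis_ratio))
    return values


batch = jnp.array([[0.0, 0.6], [0.0, 0.0]])
cached_values = axis_ratios_for(CachedEllipse(batch[0]), batch)  # second entry is stale
plain_values = axis_ratios_for(PlainEllipse(batch[0]), batch)    # second entry is fresh
```
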
+
+**Validation: `jax.jit(fn)(concrete_instance)` is NOT a sufficient JAX trace check.** A `ModelInstance` with concrete float `ell_comps` propagates as floats through `np.*` ops without raising — the bug stays hidden. **Use `jax.vmap(fitness)(jnp.array(params))` instead** (or `Fitness._vmap` on autofit's wrapper). Vmap forces tracer propagation through every leaf and exposes un-threaded `xp` sites.
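
A rough sketch of why the jit-on-concrete check passes while `vmap` over parameters fails. The function `log_likelihood_sketch` is hypothetical, and the concrete parameters reach it through a closure here; how a real `ModelInstance` carries its floats is not shown in this commit, so treat that detail as an assumption. Plain `jax.vmap` is used in place of `Fitness._vmap`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood_sketch(ell_comps, data, xp=np):
    # Un-threaded site: np.sqrt is hardcoded even though `xp` is available.
    fac = np.sqrt(ell_comps[0] ** 2 + ell_comps[1] ** 2)
    return -0.5 * xp.sum((data - fac) ** 2)


data = jnp.ones(4)
concrete_params = (0.3, 0.4)  # plain Python floats, never traced

# jit with concrete parameters: np.sqrt only sees floats, so nothing raises.
jit_value = jax.jit(
    lambda d: log_likelihood_sketch(concrete_params, d, xp=jnp)
)(data)

# vmap over parameters: each parameter is a tracer, so np.sqrt raises.
batch = jnp.array([[0.3, 0.4], [0.0, 0.6]])
try:
    jax.vmap(lambda p: log_likelihood_sketch(p, data, xp=jnp))(batch)
    vmap_raised = False
except jax.errors.TracerArrayConversionError:
    vmap_raised = True
```

The jit call returns a finite value and the bug stays invisible; only the batched call surfaces the hardcoded `np.sqrt`.
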
+
+When adding a JAX path to an Analysis class, the workspace_test parity script must include both a `jax.jit(analysis.fit_from)(instance)` round-trip AND a `fitness._vmap(parameters)` batch evaluation. See `autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py` and the ellipse counterpart for the template.
+
 ### JAX and autoarray wrappers at the `jax.jit` boundary
 
 Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
@@ -236,26 +251,26 @@ find . -type f -name "*.py" | xargs dos2unix
 
 Prefer simple shell commands.
 Avoid chaining with && or pipes.
-## Never rewrite history
-
-NEVER perform these operations on any repo with a remote:
-
-- `git init` in a directory already tracked by git
-- `rm -rf .git && git init`
-- Commit with subject "Initial commit", "Fresh start", "Start fresh", "Reset
-for AI workflow", or any equivalent message on a branch with a remote
-- `git push --force` to `main` (or any branch tracked as `origin/HEAD`)
-- `git filter-repo` / `git filter-branch` on shared branches
-- `git rebase -i` rewriting commits already pushed to a shared branch
-
-If the working tree needs a clean state, the **only** correct sequence is:
-
-git fetch origin
-git reset --hard origin/main
-git clean -fd
-
-This applies equally to humans, local Claude Code, cloud Claude agents, Codex,
-and any other agent. The "Initial commit — fresh start for AI workflow" pattern
-that appeared independently on origin and local for three workspace repos is
-exactly what this rule prevents — it costs ~40 commits of redundant local work
-every time it happens.
+## Never rewrite history
+
+NEVER perform these operations on any repo with a remote:
+
+- `git init` in a directory already tracked by git
+- `rm -rf .git && git init`
+- Commit with subject "Initial commit", "Fresh start", "Start fresh", "Reset
+for AI workflow", or any equivalent message on a branch with a remote
+- `git push --force` to `main` (or any branch tracked as `origin/HEAD`)
+- `git filter-repo` / `git filter-branch` on shared branches
+- `git rebase -i` rewriting commits already pushed to a shared branch
+
+If the working tree needs a clean state, the **only** correct sequence is:
+
+git fetch origin
+git reset --hard origin/main
+git clean -fd
+
+This applies equally to humans, local Claude Code, cloud Claude agents, Codex,
+and any other agent. The "Initial commit — fresh start for AI workflow" pattern
+that appeared independently on origin and local for three workspace repos is
+exactly what this rule prevents — it costs ~40 commits of redundant local work
+every time it happens.
