You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: CLAUDE.md
+67Lines changed: 67 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -78,6 +78,32 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct
78
78
79
79
These inherit from `AnalysisDataset` → `Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`.
80
80
81
+
### Decorator System (from autoarray)
82
+
83
+
Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**:
The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body.
y = grid.array[:, 0] # use .array to get raw numpy/jax array
99
+
x = grid.array[:, 1]
100
+
return...# return raw array; decorator wraps it
101
+
```
102
+
103
+
**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends.
104
+
105
+
See PyAutoArray's `CLAUDE.md` for full details on the decorator internals.
106
+
81
107
### JAX Support
82
108
83
109
The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.
@@ -100,6 +126,28 @@ When adding a new function that should support JAX:
100
126
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
101
127
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
102
128
129
+
### JAX and autoarray wrappers at the `jax.jit` boundary
130
+
131
+
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
132
+
133
+
- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing)
134
+
-**Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`
135
+
136
+
Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary.
137
+
138
+
The solution is the **`if xp is np:` guard** in the function body:
This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output.
150
+
103
151
### Linear Light Profiles & Inversions
104
152
105
153
`LightProfileLinear` subclasses do not take an `intensity` parameter—it is solved via a linear inversion (provided by `autoarray`). The `GalaxiesToInversion` class (`galaxy/to_inversion.py`) handles converting galaxies with linear profiles or pixelizations into the inversion objects needed by `autoarray`.
@@ -147,3 +195,22 @@ When importing `autogalaxy as ag`:
0 commit comments