You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: CLAUDE.md
+90-1Lines changed: 90 additions & 1 deletion
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -78,9 +78,75 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct
78
78
79
79
These inherit from `AnalysisDataset` → `Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`.
80
80
81
+
### Decorator System (from autoarray)
82
+
83
+
Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**:
The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body.
y = grid.array[:, 0] # use .array to get raw numpy/jax array
99
+
x = grid.array[:, 1]
100
+
return...# return raw array; decorator wraps it
101
+
```
102
+
103
+
**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends.
104
+
105
+
See PyAutoArray's `CLAUDE.md` for full details on the decorator internals.
106
+
81
107
### JAX Support
82
108
83
-
JAX is integrated via the `xp` parameter pattern throughout the codebase. Fit classes accept `xp=np` (NumPy, default) or `xp=jnp` (JAX). The `AbstractFitInversion.use_jax` property tracks which backend is active. The `AnalysisImaging.__init__` has `use_jax: bool = True`. The conftest.py forces JAX backend initialization before tests run.
109
+
The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.
110
+
111
+
The `xp` parameter pattern is the single point of control:
112
+
-`xp=np` (default throughout) — pure NumPy path, no JAX dependency at runtime
113
+
-`xp=jnp` — JAX path, imports `jax` / `jax.numpy` locally inside the function
114
+
115
+
This means:
116
+
-**Unit tests** (`test_autogalaxy/`) always run on the NumPy path. No test should import JAX or pass `xp=jnp` unless it is explicitly testing the JAX path.
117
+
-**Integration tests** (in `autogalaxy_workspace_test/`) are where the JAX path is exercised, typically wrapped in `jax.jit` to test both correctness and compilation.
118
+
-`conftest.py` forces JAX backend initialisation before the test suite runs, but this only ensures JAX is available — it does not switch the default backend.
119
+
120
+
`AbstractFitInversion.use_jax` tracks whether a fit was constructed with JAX. `AnalysisImaging` has `use_jax: bool = True` to opt into the JAX path for model-fitting.
121
+
122
+
When adding a new function that should support JAX:
123
+
1. Default the parameter to `xp=np`
124
+
2. Guard any JAX imports with `if xp is not np:` and import `jax` / `jax.numpy` locally inside that branch
125
+
3. Add the NumPy implementation as the default path (finite-difference, `np.*` calls, etc.)
126
+
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
127
+
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
128
+
129
+
### JAX and autoarray wrappers at the `jax.jit` boundary
130
+
131
+
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
132
+
133
+
- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing)
134
+
-**Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`
135
+
136
+
Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary.
137
+
138
+
The solution is the **`if xp is np:` guard** in the function body:
This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output.
84
150
85
151
### Linear Light Profiles & Inversions
86
152
@@ -101,6 +167,10 @@ Default priors, visualization settings, and general config live in `autogalaxy/c
101
167
102
168
Both are mixin classes inherited by `LightProfile`, `MassProfile`, `Galaxy`, and `Galaxies`.
103
169
170
+
### Workspace Script Style
171
+
172
+
Scripts in `autogalaxy_workspace` and `autogalaxy_workspace_test` use `"""..."""` docstring blocks as prose commentary throughout — **not**`#` comments. Every script opens with a module-level docstring (title + underline + description), and each logical section of code is preceded by a `"""..."""` block with a `__Section Name__` header explaining what follows. See any script in `autogalaxy_workspace/scripts/` for examples of this style.
173
+
104
174
### Workspace (Examples & Notebooks)
105
175
106
176
The `autogalaxy_workspace` at `/mnt/c/Users/Jammy/Code/PyAutoJAX/autogalaxy_workspace` contains runnable examples and tutorials. Key locations:
@@ -125,3 +195,22 @@ When importing `autogalaxy as ag`:
Copy file name to clipboardExpand all lines: autogalaxy/config/general.yaml
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -1,5 +1,5 @@
1
1
fits:
2
-
flip_for_ds9: true
2
+
flip_for_ds9: false
3
3
psf:
4
4
use_fft_default: true # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution.
0 commit comments