Skip to content

Commit 3582e18

Browse files
authored
Merge pull request #392 from Jammy2211/feature/jax_decorator_bypass
Add CLAUDE.md documenting architecture, decorator system, and JAX rules
2 parents 342f6fb + 482f05d commit 3582e18

1 file changed

Lines changed: 135 additions & 0 deletions

File tree

CLAUDE.md

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Commands
6+
7+
### Install
8+
```bash
9+
pip install -e ".[dev]"
10+
```
11+
12+
### Run Tests
13+
```bash
14+
# All tests
15+
python -m pytest test_autolens/
16+
17+
# Single test file
18+
python -m pytest test_autolens/lens/test_tracer.py
19+
20+
# With output
21+
python -m pytest test_autolens/imaging/test_fit_imaging.py -s
22+
```
23+
24+
### Formatting
25+
```bash
26+
black autolens/
27+
```
28+
29+
## Architecture
30+
31+
**PyAutoLens** is the gravitational lensing layer built on top of PyAutoGalaxy. It adds multi-plane ray-tracing, the `Tracer` object, and lensing-specific fit classes. It depends on:
32+
- **`autogalaxy`** — galaxy morphology, mass/light profiles, single-plane fitting
33+
- **`autoarray`** — low-level data structures (grids, masks, arrays, datasets, inversions)
34+
- **`autofit`** — non-linear search / model-fitting framework
35+
36+
### Core Class Hierarchy
37+
38+
```
39+
Tracer (lens/tracer.py)
40+
└── List[List[Galaxy]] — galaxies grouped by redshift plane
41+
├── ray-traces from source to lens to observer
42+
├── delegates to autogalaxy Galaxy/Galaxies for per-plane operations
43+
└── returns lensed images, deflection maps, convergence, magnification
44+
```
45+
46+
### Dataset Types and Fit Classes
47+
48+
| Dataset | Fit class | Analysis class |
49+
|---|---|---|
50+
| `aa.Imaging` | `FitImaging` | `AnalysisImaging` |
51+
| `aa.Interferometer` | `FitInterferometer` | `AnalysisInterferometer` |
52+
| Point source | `FitPointDataset` | `AnalysisPoint` |
53+
54+
All inherit from the corresponding `autogalaxy` base classes (`ag.FitImaging`, etc.) and extend them with multi-plane lensing via the `Tracer`.
55+
56+
### Key Directories
57+
58+
```
59+
autolens/
60+
lens/ Tracer, ray-tracing, multi-plane deflection logic
61+
imaging/ FitImaging, AnalysisImaging
62+
interferometer/ FitInterferometer, AnalysisInterferometer
63+
point/ Point-source datasets, fits, and analysis
64+
quantity/ FitQuantity for arbitrary lensing quantities
65+
analysis/ Shared analysis base classes, adapt images
66+
aggregator/ Scraping results from autofit output directories
67+
plot/ Visualisation (Plotter classes for all data types)
68+
```
69+
70+
## Decorator System (from autoarray)
71+
72+
PyAutoLens inherits the same decorator conventions as PyAutoGalaxy. Mass and light profile methods that take a grid and return an array/grid/vector are decorated with:
73+
74+
| Decorator | `Grid2D`| `Grid2DIrregular`|
75+
|---|---|---|
76+
| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` |
77+
| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` |
78+
| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` |
79+
80+
The `@aa.grid_dec.transform` decorator (always innermost) transforms the grid to the profile's reference frame. Standard stacking:
81+
82+
```python
83+
@aa.grid_dec.to_array
84+
@aa.grid_dec.transform
85+
def convergence_2d_from(self, grid, xp=np, **kwargs):
86+
y = grid.array[:, 0] # .array extracts raw numpy/jax array
87+
x = grid.array[:, 1]
88+
return ... # raw array — decorator wraps it
89+
```
90+
91+
The function body must return a **raw array**. Use `grid.array[:, 0]` (not `grid[:, 0]`) to access coordinates safely for both numpy and jax backends.
92+
93+
See PyAutoArray's `CLAUDE.md` for full decorator internals.
94+
95+
## JAX Support
96+
97+
The `xp` parameter pattern controls the backend:
98+
- `xp=np` (default) — pure NumPy, no JAX dependency
99+
- `xp=jnp` — JAX path; `jax`/`jax.numpy` imported locally inside the function only
100+
101+
### JAX and the `jax.jit` boundary
102+
103+
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. They can be constructed inside a JIT trace, but **cannot be returned** as the output of a `jax.jit`-compiled function.
104+
105+
Functions intended to be called directly inside `jax.jit` must guard autoarray wrapping with `if xp is np:`:
106+
107+
```python
108+
def convergence_2d_via_hessian_from(self, grid, xp=np):
109+
convergence = 0.5 * (hessian_yy + hessian_xx)
110+
111+
if xp is np:
112+
return aa.ArrayIrregular(values=convergence) # numpy: wrapped
113+
return convergence # jax: raw jax.Array
114+
```
115+
116+
Functions that are only called as intermediate steps (e.g. `deflections_yx_2d_from`) do not need this guard — they are consumed by downstream Python before the JIT boundary.
117+
118+
### `LensCalc` (autogalaxy)
119+
120+
The hessian-derived lensing quantities (`convergence_2d_via_hessian_from`, `shear_yx_2d_via_hessian_from`, `magnification_2d_via_hessian_from`, `magnification_2d_from`, `tangential_eigen_value_from`, `radial_eigen_value_from`) all implement the `if xp is np:` guard in `autogalaxy/operate/lens_calc.py` and return raw `jax.Array` on the JAX path, making them safe to call inside `jax.jit`.
121+
122+
## Namespace Conventions
123+
124+
When importing `autolens as al`:
125+
- `al.mp.*` — mass profiles (re-exported from autogalaxy)
126+
- `al.lp.*` — light profiles (re-exported from autogalaxy)
127+
- `al.Galaxy`, `al.Galaxies`
128+
- `al.Tracer`
129+
- `al.FitImaging`, `al.AnalysisImaging`, `al.SimulatorImaging`
130+
- `al.FitInterferometer`, `al.AnalysisInterferometer`
131+
- `al.FitPointDataset`, `al.AnalysisPoint`
132+
133+
## Line Endings — Always Unix (LF)
134+
135+
All files **must use Unix line endings (LF, `\n`)**. Never write `\r\n` line endings.

0 commit comments

Comments
 (0)