Skip to content

Commit 2c50096

Browse files
author
NiekWielders
committed
Merge branch 'main' into feature/two_galaxies_lens
2 parents 57d2584 + 59174cf commit 2c50096

14 files changed

Lines changed: 593 additions & 45 deletions

File tree

CLAUDE.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct
7878

7979
These inherit from `AnalysisDataset``Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`.
8080

81+
### Decorator System (from autoarray)
82+
83+
Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**:
84+
85+
| Decorator | `Grid2D` input | `Grid2DIrregular` input |
86+
|---|---|---|
87+
| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` |
88+
| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` |
89+
| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` |
90+
91+
The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body.
92+
93+
The canonical stacking order is:
94+
```python
95+
@aa.grid_dec.to_array # outermost: wraps output
96+
@aa.grid_dec.transform # innermost: transforms grid
97+
def convergence_2d_from(self, grid, xp=np, **kwargs):
98+
y = grid.array[:, 0] # use .array to get raw numpy/jax array
99+
x = grid.array[:, 1]
100+
return ... # return raw array; decorator wraps it
101+
```
102+
103+
**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends.
104+
105+
See PyAutoArray's `CLAUDE.md` for full details on the decorator internals.
106+
81107
### JAX Support
82108

83109
The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.
@@ -100,6 +126,28 @@ When adding a new function that should support JAX:
100126
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
101127
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
102128

129+
### JAX and autoarray wrappers at the `jax.jit` boundary
130+
131+
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
132+
133+
- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing)
134+
- **Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`
135+
136+
Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary.
137+
138+
The solution is the **`if xp is np:` guard** in the function body:
139+
140+
```python
141+
def convergence_2d_via_hessian_from(self, grid, xp=np):
142+
convergence = 0.5 * (hessian_yy + hessian_xx)
143+
144+
if xp is np:
145+
return aa.ArrayIrregular(values=convergence) # numpy: wrapped
146+
return convergence # jax: raw jax.Array
147+
```
148+
149+
This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output.
150+
103151
### Linear Light Profiles & Inversions
104152

105153
`LightProfileLinear` subclasses do not take an `intensity` parameter—it is solved via a linear inversion (provided by `autoarray`). The `GalaxiesToInversion` class (`galaxy/to_inversion.py`) handles converting galaxies with linear profiles or pixelizations into the inversion objects needed by `autoarray`.
@@ -147,3 +195,22 @@ When importing `autogalaxy as ag`:
147195
- `ag.ps.*` – point sources
148196
- `ag.Galaxy`, `ag.Galaxies`
149197
- `ag.FitImaging`, `ag.AnalysisImaging`, `ag.SimulatorImaging`
198+
199+
## Line Endings — Always Unix (LF)
200+
201+
All files in this project **must use Unix line endings (LF, `\n`)**. Windows/DOS line endings (CRLF, `\r\n`) will break Python files on HPC systems.
202+
203+
**When writing or editing any file**, always produce Unix line endings. Never write `\r\n` line endings.
204+
205+
After creating or copying files, verify and convert if needed:
206+
207+
```bash
208+
# Check for DOS line endings
209+
file autogalaxy/galaxy/galaxy.py # should say "ASCII text", not "CRLF"
210+
211+
# Convert all Python files in the project
212+
find . -type f -name "*.py" | xargs dos2unix
213+
```
214+
215+
Prefer simple shell commands.
216+
Avoid chaining with && or pipes.

autogalaxy/config/priors/mass/dark/cnfw.yaml

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,78 @@
1+
cNFW:
2+
centre_0:
3+
type: Gaussian
4+
mean: 0.0
5+
sigma: 0.1
6+
width_modifier:
7+
type: Absolute
8+
value: 0.05
9+
limits:
10+
lower: -inf
11+
upper: inf
12+
centre_1:
13+
type: Gaussian
14+
mean: 0.0
15+
sigma: 0.1
16+
width_modifier:
17+
type: Absolute
18+
value: 0.05
19+
limits:
20+
lower: -inf
21+
upper: inf
22+
ell_comps_0:
23+
type: TruncatedGaussian
24+
mean: 0.0
25+
sigma: 0.3
26+
lower_limit: -1.0
27+
upper_limit: 1.0
28+
width_modifier:
29+
type: Absolute
30+
value: 0.2
31+
limits:
32+
lower: -1.0
33+
upper: 1.0
34+
ell_comps_1:
35+
type: TruncatedGaussian
36+
mean: 0.0
37+
sigma: 0.3
38+
lower_limit: -1.0
39+
upper_limit: 1.0
40+
width_modifier:
41+
type: Absolute
42+
value: 0.2
43+
limits:
44+
lower: -1.0
45+
upper: 1.0
46+
kappa_s:
47+
type: Uniform
48+
lower_limit: 0.0
49+
upper_limit: 1.0
50+
width_modifier:
51+
type: Relative
52+
value: 0.2
53+
limits:
54+
lower: 0.0
55+
upper: inf
56+
scale_radius:
57+
type: Uniform
58+
lower_limit: 0.0
59+
upper_limit: 30.0
60+
width_modifier:
61+
type: Relative
62+
value: 0.2
63+
limits:
64+
lower: 0.0
65+
upper: inf
66+
core_radius:
67+
type: Uniform
68+
lower_limit: 0.0
69+
upper_limit: 15.0
70+
width_modifier:
71+
type: Relative
72+
value: 0.2
73+
limits:
74+
lower: 0.0
75+
upper: inf
176
cNFWSph:
277
centre_0:
378
type: Gaussian

autogalaxy/config/priors/mass/dark/cnfw_mcr.yaml

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,88 @@
1+
cNFWMCRLudlow:
2+
centre_0:
3+
type: Gaussian
4+
mean: 0.0
5+
sigma: 0.1
6+
width_modifier:
7+
type: Absolute
8+
value: 0.05
9+
limits:
10+
lower: -inf
11+
upper: inf
12+
centre_1:
13+
type: Gaussian
14+
mean: 0.0
15+
sigma: 0.1
16+
width_modifier:
17+
type: Absolute
18+
value: 0.05
19+
limits:
20+
lower: -inf
21+
upper: inf
22+
ell_comps_0:
23+
type: TruncatedGaussian
24+
mean: 0.0
25+
sigma: 0.3
26+
lower_limit: -1.0
27+
upper_limit: 1.0
28+
width_modifier:
29+
type: Absolute
30+
value: 0.2
31+
limits:
32+
lower: -1.0
33+
upper: 1.0
34+
ell_comps_1:
35+
type: TruncatedGaussian
36+
mean: 0.0
37+
sigma: 0.3
38+
lower_limit: -1.0
39+
upper_limit: 1.0
40+
width_modifier:
41+
type: Absolute
42+
value: 0.2
43+
limits:
44+
lower: -1.0
45+
upper: 1.0
46+
mass_at_200:
47+
type: LogUniform
48+
lower_limit: 100000000.0
49+
upper_limit: 1000000000000000.0
50+
width_modifier:
51+
type: Relative
52+
value: 0.5
53+
limits:
54+
lower: 0.0
55+
upper: inf
56+
f_c:
57+
type: Uniform
58+
lower_limit: 0.0001
59+
upper_limit: 0.5
60+
width_modifier:
61+
type: Relative
62+
value: 0.2
63+
limits:
64+
lower: 0.0001
65+
upper: inf
66+
redshift_object:
67+
type: Uniform
68+
lower_limit: 0.0
69+
upper_limit: 1.0
70+
width_modifier:
71+
type: Relative
72+
value: 0.5
73+
limits:
74+
lower: 0.0
75+
upper: inf
76+
redshift_source:
77+
type: Uniform
78+
lower_limit: 0.0
79+
upper_limit: 1.0
80+
width_modifier:
81+
type: Relative
82+
value: 0.5
83+
limits:
84+
lower: 0.0
85+
upper: inf
186
cNFWMCRLudlowSph:
287
centre_0:
388
type: Gaussian

autogalaxy/config/priors/mass/dark/cnfw_mcr_scatter.yaml

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,98 @@
1+
cNFWMCRScatterLudlow:
2+
centre_0:
3+
type: Gaussian
4+
mean: 0.0
5+
sigma: 0.1
6+
width_modifier:
7+
type: Absolute
8+
value: 0.05
9+
limits:
10+
lower: -inf
11+
upper: inf
12+
centre_1:
13+
type: Gaussian
14+
mean: 0.0
15+
sigma: 0.1
16+
width_modifier:
17+
type: Absolute
18+
value: 0.05
19+
limits:
20+
lower: -inf
21+
upper: inf
22+
ell_comps_0:
23+
type: TruncatedGaussian
24+
mean: 0.0
25+
sigma: 0.3
26+
lower_limit: -1.0
27+
upper_limit: 1.0
28+
width_modifier:
29+
type: Absolute
30+
value: 0.2
31+
limits:
32+
lower: -1.0
33+
upper: 1.0
34+
ell_comps_1:
35+
type: TruncatedGaussian
36+
mean: 0.0
37+
sigma: 0.3
38+
lower_limit: -1.0
39+
upper_limit: 1.0
40+
width_modifier:
41+
type: Absolute
42+
value: 0.2
43+
limits:
44+
lower: -1.0
45+
upper: 1.0
46+
mass_at_200:
47+
type: LogUniform
48+
lower_limit: 100000000.0
49+
upper_limit: 1000000000000000.0
50+
width_modifier:
51+
type: Relative
52+
value: 0.5
53+
limits:
54+
lower: 0.0
55+
upper: inf
56+
f_c:
57+
type: Uniform
58+
lower_limit: 0.0001
59+
upper_limit: 0.5
60+
width_modifier:
61+
type: Relative
62+
value: 0.2
63+
limits:
64+
lower: 0.0001
65+
upper: inf
66+
redshift_object:
67+
type: Uniform
68+
lower_limit: 0.0
69+
upper_limit: 1.0
70+
width_modifier:
71+
type: Relative
72+
value: 0.5
73+
limits:
74+
lower: 0.0
75+
upper: inf
76+
redshift_source:
77+
type: Uniform
78+
lower_limit: 0.0
79+
upper_limit: 1.0
80+
width_modifier:
81+
type: Relative
82+
value: 0.5
83+
limits:
84+
lower: 0.0
85+
upper: inf
86+
scatter_sigma:
87+
type: Gaussian
88+
mean: 0.0
89+
sigma: 3.0
90+
width_modifier:
91+
type: Absolute
92+
value: 1.0
93+
limits:
94+
lower: -inf
95+
upper: inf
196
cNFWMCRScatterLudlowSph:
297
centre_0:
398
type: Gaussian

autogalaxy/cosmology/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,27 @@ def kpc_per_arcsec_from(self, redshift: float, xp=np) -> float:
6262
"""
6363
return self.kpc_proper_per_arcsec(z=redshift, xp=xp)
6464

65+
def luminosity_distance(self, z: float, xp=np) -> float:
66+
"""
67+
Luminosity distance to redshift z in Mpc.
68+
69+
For a flat universe:
70+
71+
D_L(z) = (1 + z)^2 * D_A(0, z)
72+
73+
where D_A(0, z) is the angular diameter distance from Earth.
74+
75+
Returns Mpc, matching the convention of astropy.cosmology.FlatLambdaCDM.luminosity_distance(z).value.
76+
77+
Parameters
78+
----------
79+
z
80+
Redshift at which the luminosity distance is calculated.
81+
"""
82+
D_A_kpc = self.angular_diameter_distance_to_earth_in_kpc_from(redshift=z, xp=xp)
83+
D_A_Mpc = D_A_kpc / xp.asarray(1.0e3)
84+
return (xp.asarray(1.0) + xp.asarray(z)) ** 2 * D_A_Mpc
85+
6586
def angular_diameter_distance_to_earth_in_kpc_from(
6687
self, redshift: float, xp=np
6788
) -> float:

0 commit comments

Comments
 (0)