Skip to content

Commit e0c85d3

Browse files
committed
Merge branch 'main' into feature/ellipse_utils
2 parents a820ae6 + afa1e8c commit e0c85d3

60 files changed

Lines changed: 2828 additions & 1805 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CLAUDE.md

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,75 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct
7878

7979
These inherit from `AnalysisDataset``Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`.
8080

81+
### Decorator System (from autoarray)
82+
83+
Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**:
84+
85+
| Decorator | `Grid2D` input | `Grid2DIrregular` input |
86+
|---|---|---|
87+
| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` |
88+
| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` |
89+
| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` |
90+
91+
The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body.
92+
93+
The canonical stacking order is:
94+
```python
95+
@aa.grid_dec.to_array # outermost: wraps output
96+
@aa.grid_dec.transform # innermost: transforms grid
97+
def convergence_2d_from(self, grid, xp=np, **kwargs):
98+
y = grid.array[:, 0] # use .array to get raw numpy/jax array
99+
x = grid.array[:, 1]
100+
return ... # return raw array; decorator wraps it
101+
```
102+
103+
**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends.
104+
105+
See PyAutoArray's `CLAUDE.md` for full details on the decorator internals.
106+
81107
### JAX Support
82108

83-
JAX is integrated via the `xp` parameter pattern throughout the codebase. Fit classes accept `xp=np` (NumPy, default) or `xp=jnp` (JAX). The `AbstractFitInversion.use_jax` property tracks which backend is active. The `AnalysisImaging.__init__` has `use_jax: bool = True`. The conftest.py forces JAX backend initialization before tests run.
109+
The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.
110+
111+
The `xp` parameter pattern is the single point of control:
112+
- `xp=np` (default throughout) — pure NumPy path, no JAX dependency at runtime
113+
- `xp=jnp` — JAX path, imports `jax` / `jax.numpy` locally inside the function
114+
115+
This means:
116+
- **Unit tests** (`test_autogalaxy/`) always run on the NumPy path. No test should import JAX or pass `xp=jnp` unless it is explicitly testing the JAX path.
117+
- **Integration tests** (in `autogalaxy_workspace_test/`) are where the JAX path is exercised, typically wrapped in `jax.jit` to test both correctness and compilation.
118+
- `conftest.py` forces JAX backend initialisation before the test suite runs, but this only ensures JAX is available — it does not switch the default backend.
119+
120+
`AbstractFitInversion.use_jax` tracks whether a fit was constructed with JAX. `AnalysisImaging` has `use_jax: bool = True` to opt into the JAX path for model-fitting.
121+
122+
When adding a new function that should support JAX:
123+
1. Default the parameter to `xp=np`
124+
2. Guard any JAX imports with `if xp is not np:` and import `jax` / `jax.numpy` locally inside that branch
125+
3. Add the NumPy implementation as the default path (finite-difference, `np.*` calls, etc.)
126+
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
127+
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`
128+
129+
### JAX and autoarray wrappers at the `jax.jit` boundary
130+
131+
Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:
132+
133+
- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing)
134+
- **Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`
135+
136+
Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary.
137+
138+
The solution is the **`if xp is np:` guard** in the function body:
139+
140+
```python
141+
def convergence_2d_via_hessian_from(self, grid, xp=np):
142+
convergence = 0.5 * (hessian_yy + hessian_xx)
143+
144+
if xp is np:
145+
return aa.ArrayIrregular(values=convergence) # numpy: wrapped
146+
return convergence # jax: raw jax.Array
147+
```
148+
149+
This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output.
84150

85151
### Linear Light Profiles & Inversions
86152

@@ -101,6 +167,10 @@ Default priors, visualization settings, and general config live in `autogalaxy/c
101167

102168
Both are mixin classes inherited by `LightProfile`, `MassProfile`, `Galaxy`, and `Galaxies`.
103169

170+
### Workspace Script Style
171+
172+
Scripts in `autogalaxy_workspace` and `autogalaxy_workspace_test` use `"""..."""` docstring blocks as prose commentary throughout — **not** `#` comments. Every script opens with a module-level docstring (title + underline + description), and each logical section of code is preceded by a `"""..."""` block with a `__Section Name__` header explaining what follows. See any script in `autogalaxy_workspace/scripts/` for examples of this style.
173+
104174
### Workspace (Examples & Notebooks)
105175

106176
The `autogalaxy_workspace` at `/mnt/c/Users/Jammy/Code/PyAutoJAX/autogalaxy_workspace` contains runnable examples and tutorials. Key locations:
@@ -125,3 +195,22 @@ When importing `autogalaxy as ag`:
125195
- `ag.ps.*` – point sources
126196
- `ag.Galaxy`, `ag.Galaxies`
127197
- `ag.FitImaging`, `ag.AnalysisImaging`, `ag.SimulatorImaging`
198+
199+
## Line Endings — Always Unix (LF)
200+
201+
All files in this project **must use Unix line endings (LF, `\n`)**. Windows/DOS line endings (CRLF, `\r\n`) will break Python files on HPC systems.
202+
203+
**When writing or editing any file**, always produce Unix line endings. Never write `\r\n` line endings.
204+
205+
After creating or copying files, verify and convert if needed:
206+
207+
```bash
208+
# Check for DOS line endings
209+
file autogalaxy/galaxy/galaxy.py # should say "ASCII text", not "CRLF"
210+
211+
# Convert all Python files in the project
212+
find . -type f -name "*.py" | xargs dos2unix
213+
```
214+
215+
Prefer simple shell commands.
216+
Avoid chaining with && or pipes.

autogalaxy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from .operate.image import OperateImage
7070
from .operate.image import OperateImageList
7171
from .operate.image import OperateImageGalaxies
72-
from .operate.deflections import OperateDeflections
72+
from .operate.lens_calc import LensCalc
7373
from .gui.scribbler import Scribbler
7474
from .imaging.fit_imaging import FitImaging
7575
from .imaging.model.analysis import AnalysisImaging

autogalaxy/analysis/model_util.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def mge_model_from(
1111
centre_prior_is_uniform: bool = True,
1212
centre: Tuple[float, float] = (0.0, 0.0),
1313
centre_fixed: Optional[Tuple[float, float]] = None,
14+
centre_sigma: float = 0.3,
1415
use_spherical: bool = False,
1516
) -> af.Collection:
1617
"""
@@ -89,8 +90,8 @@ def mge_model_from(
8990
lower_limit=centre[1] - 0.1, upper_limit=centre[1] + 0.1
9091
)
9192
else:
92-
centre_0 = af.GaussianPrior(mean=centre[0], sigma=0.3)
93-
centre_1 = af.GaussianPrior(mean=centre[1], sigma=0.3)
93+
centre_0 = af.GaussianPrior(mean=centre[0], sigma=centre_sigma)
94+
centre_1 = af.GaussianPrior(mean=centre[1], sigma=centre_sigma)
9495

9596
if use_spherical:
9697
model_cls = GaussianSph
@@ -129,6 +130,106 @@ def mge_model_from(
129130
)
130131

131132

133+
def mge_point_model_from(
134+
pixel_scales: float,
135+
total_gaussians: int = 10,
136+
centre: Tuple[float, float] = (0.0, 0.0),
137+
) -> af.Model:
138+
"""
139+
Construct a Multi-Gaussian Expansion (MGE) model for a compact or unresolved
140+
point-like component (e.g. a nuclear starburst, AGN, or unresolved bulge).
141+
142+
The model is composed of ``total_gaussians`` linear Gaussians whose sigma values
143+
are logarithmically spaced between 0.01 arcseconds and twice the pixel scale.
144+
All Gaussians share the same centre and ellipticity components, keeping the
145+
parameter count low while capturing a realistic PSF-convolved point source.
146+
147+
Parameters
148+
----------
149+
pixel_scales
150+
The pixel scale of the image in arcseconds per pixel. The maximum Gaussian
151+
width is set to ``2 * pixel_scales`` so that the model is compact relative to
152+
the resolution of the data.
153+
total_gaussians
154+
Number of Gaussian components in the basis.
155+
centre
156+
(y, x) centre of the point source in arc-seconds. A ±0.1 arcsecond uniform
157+
prior is placed on each coordinate.
158+
159+
Returns
160+
-------
161+
af.Model
162+
An ``autofit.Model`` wrapping a ``Basis`` of linear Gaussians.
163+
"""
164+
165+
from autogalaxy.profiles.light.linear import Gaussian
166+
from autogalaxy.profiles.basis import Basis
167+
168+
if total_gaussians < 1:
169+
raise ValueError(
170+
f"mge_point_model_from requires total_gaussians >= 1, got {total_gaussians}."
171+
)
172+
173+
if pixel_scales <= 0:
174+
raise ValueError(
175+
f"mge_point_model_from requires pixel_scales > 0, got {pixel_scales}."
176+
)
177+
178+
# Sigma values are logarithmically spaced between 0.01 arcsec (10**-2)
179+
# and twice the pixel scale, with a floor to avoid taking log10 of
180+
# very small or non-positive values.
181+
min_log10_sigma = -2.0 # corresponds to 0.01 arcsec
182+
max_sigma = max(2.0 * pixel_scales, 10**min_log10_sigma)
183+
max_log10_sigma = np.log10(max_sigma)
184+
185+
log10_sigma_list = np.linspace(min_log10_sigma, max_log10_sigma, total_gaussians)
186+
centre_0 = af.UniformPrior(lower_limit=centre[0] - 0.1, upper_limit=centre[0] + 0.1)
187+
centre_1 = af.UniformPrior(lower_limit=centre[1] - 0.1, upper_limit=centre[1] + 0.1)
188+
189+
gaussian_list = af.Collection(af.Model(Gaussian) for _ in range(total_gaussians))
190+
191+
for i, gaussian in enumerate(gaussian_list):
192+
gaussian.centre.centre_0 = centre_0
193+
gaussian.centre.centre_1 = centre_1
194+
gaussian.ell_comps = gaussian_list[0].ell_comps
195+
gaussian.sigma = 10 ** log10_sigma_list[i]
196+
197+
return af.Model(Basis, profile_list=gaussian_list)
198+
199+
200+
def hilbert_pixels_from_pixel_scale(pixel_scale: float) -> int:
201+
"""
202+
Return the number of Hilbert-curve pixels appropriate for a given image pixel scale.
203+
204+
The Hilbert pixel count controls the resolution of the Hilbert-curve ordering used
205+
in adaptive source-plane pixelizations. Finer pixel scales resolve smaller angular
206+
features and therefore benefit from a higher Hilbert resolution.
207+
208+
Parameters
209+
----------
210+
pixel_scale
211+
The pixel scale of the image in arcseconds per pixel.
212+
213+
Returns
214+
-------
215+
int
216+
The recommended number of Hilbert pixels.
217+
"""
218+
if not np.isfinite(pixel_scale) or pixel_scale <= 0:
219+
raise ValueError(
220+
f"hilbert_pixels_from_pixel_scale requires pixel_scale to be finite and > 0, got {pixel_scale}."
221+
)
222+
223+
if pixel_scale > 0.06:
224+
return 1000
225+
elif pixel_scale > 0.04:
226+
return 1250
227+
elif pixel_scale >= 0.03:
228+
return 1500
229+
else:
230+
return 1750
231+
232+
132233
def simulator_start_here_model_from():
133234

134235
from autogalaxy.profiles.light.snr import Sersic

autogalaxy/config/general.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
fits:
2-
flip_for_ds9: true
2+
flip_for_ds9: false
33
psf:
44
use_fft_default: true # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution.
55
updates:

autogalaxy/config/priors/mass/dark/cnfw.yaml

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,78 @@
1+
cNFW:
2+
centre_0:
3+
type: Gaussian
4+
mean: 0.0
5+
sigma: 0.1
6+
width_modifier:
7+
type: Absolute
8+
value: 0.05
9+
limits:
10+
lower: -inf
11+
upper: inf
12+
centre_1:
13+
type: Gaussian
14+
mean: 0.0
15+
sigma: 0.1
16+
width_modifier:
17+
type: Absolute
18+
value: 0.05
19+
limits:
20+
lower: -inf
21+
upper: inf
22+
ell_comps_0:
23+
type: TruncatedGaussian
24+
mean: 0.0
25+
sigma: 0.3
26+
lower_limit: -1.0
27+
upper_limit: 1.0
28+
width_modifier:
29+
type: Absolute
30+
value: 0.2
31+
limits:
32+
lower: -1.0
33+
upper: 1.0
34+
ell_comps_1:
35+
type: TruncatedGaussian
36+
mean: 0.0
37+
sigma: 0.3
38+
lower_limit: -1.0
39+
upper_limit: 1.0
40+
width_modifier:
41+
type: Absolute
42+
value: 0.2
43+
limits:
44+
lower: -1.0
45+
upper: 1.0
46+
kappa_s:
47+
type: Uniform
48+
lower_limit: 0.0
49+
upper_limit: 1.0
50+
width_modifier:
51+
type: Relative
52+
value: 0.2
53+
limits:
54+
lower: 0.0
55+
upper: inf
56+
scale_radius:
57+
type: Uniform
58+
lower_limit: 0.0
59+
upper_limit: 30.0
60+
width_modifier:
61+
type: Relative
62+
value: 0.2
63+
limits:
64+
lower: 0.0
65+
upper: inf
66+
core_radius:
67+
type: Uniform
68+
lower_limit: 0.0
69+
upper_limit: 15.0
70+
width_modifier:
71+
type: Relative
72+
value: 0.2
73+
limits:
74+
lower: 0.0
75+
upper: inf
176
cNFWSph:
277
centre_0:
378
type: Gaussian

0 commit comments

Comments
 (0)