Skip to content

Commit d5560e6

Browse files
Jammy2211 and claude
authored and committed
feat: honor PYAUTO_SMALL_DATASETS in Imaging.from_fits
Mask2D.circular and Grid2D.uniform already cap to (15, 15) at 0.6"/px under PYAUTO_SMALL_DATASETS=1, but Imaging.from_fits did not — it just loaded whatever was on disk. Any caller that paired from_fits(150x150 fixture) with Mask2D.circular(shape_native=dataset.shape_native) under the env var crashed with a (150,150) vs (15,15) broadcast error on apply_mask. Add a center-crop hook in Imaging.from_fits that mirrors the existing caps: data and noise_map exceeding (15, 15) are center-cropped and pixel_scales is overridden to 0.6. The PSF is left alone (PSFs are usually already small and capping them changes shape semantics). A new utility cap_array_2d_for_small_datasets in autoarray/util/dataset_util.py implements the cap and is reusable by other from_fits loaders in follow-up PRs. No-op when env unset OR when on-disk shape is already at-or-below the cap, so the simulator -> from_fits round-trip is unchanged. Closes Cluster E from the 2026-05-07 release-prep triage. The workspace-side env_vars.yaml override shipped earlier (PR #80 in autolens_workspace_test) becomes redundant after this lands but is left in place as belt-and-suspenders. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 26b1c03 commit d5560e6

4 files changed

Lines changed: 197 additions & 0 deletions

File tree

autoarray/dataset/imaging/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,17 @@ def from_fits(
203203
passed into the calculations performed in the `inversion` module.
204204
"""
205205

206+
from autoarray.util.dataset_util import cap_array_2d_for_small_datasets
207+
206208
data = Array2D.from_fits(
207209
file_path=data_path, hdu=data_hdu, pixel_scales=pixel_scales
208210
)
211+
data, pixel_scales = cap_array_2d_for_small_datasets(data, pixel_scales)
209212

210213
noise_map = Array2D.from_fits(
211214
file_path=noise_map_path, hdu=noise_map_hdu, pixel_scales=pixel_scales
212215
)
216+
noise_map, pixel_scales = cap_array_2d_for_small_datasets(noise_map, pixel_scales)
213217

214218
if psf_path is not None:
215219
kernel = Array2D.from_fits(

autoarray/util/dataset_util.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,54 @@
33
from pathlib import Path
44

55

6+
# Shape and pixel scale that a 2D array is reduced to when the
# ``PYAUTO_SMALL_DATASETS=1`` smoke-test mode is active.
SMALL_DATASETS_SHAPE_NATIVE = (15, 15)
SMALL_DATASETS_PIXEL_SCALES = 0.6


def _center_crop_slice(size, cap):
    """Return the slice that center-crops an axis of length ``size`` to at most ``cap``."""
    # Clamp at zero: for an axis already within the cap ``(size - cap) // 2``
    # is negative, and a negative slice start would index from the end of the
    # axis and yield an off-centre, truncated window.
    start = max((size - cap) // 2, 0)
    return slice(start, start + min(size, cap))


def cap_array_2d_for_small_datasets(array_2d, pixel_scales):
    """
    Center-crop a 2D autoarray to the small-datasets cap when
    ``PYAUTO_SMALL_DATASETS=1`` is active.

    Returns ``(array_2d, pixel_scales)`` unchanged in any of these cases:

    - ``PYAUTO_SMALL_DATASETS`` is not set to ``"1"``.
    - ``array_2d.shape_native`` is already at-or-below the cap (15, 15) in
      both dimensions.

    When the env var is set and the input shape exceeds the cap, returns a
    new ``Array2D`` center-cropped to the cap with ``pixel_scales`` overridden
    to 0.6 — matching the convention used by ``Mask2D.circular`` and
    ``Grid2D.uniform`` so the loaded dataset stays shape-consistent with masks
    and grids built under the same env var.

    Each axis is cropped independently. If only one dimension exceeds the
    cap, that axis is center-cropped and the other is kept whole (e.g.
    ``(100, 10)`` becomes ``(15, 10)``); previously the within-cap axis was
    sliced with a negative start and lost data.

    Center-cropping (rather than downsampling/resampling) is intentional:
    smoke-mode tests don't require numerical correctness, and the simpler
    op avoids interpolation artifacts and a scipy dependency at this layer.

    Parameters
    ----------
    array_2d
        Object exposing ``shape_native`` (a ``(height, width)`` pair) and
        ``native.array`` (the 2D values that are sliced).
    pixel_scales
        The pixel scales associated with ``array_2d``; passed through
        untouched on the no-op paths.

    Returns
    -------
    tuple
        ``(array_2d, pixel_scales)`` — either the inputs themselves, or the
        cropped ``Array2D`` together with ``SMALL_DATASETS_PIXEL_SCALES``.
    """
    if os.environ.get("PYAUTO_SMALL_DATASETS") != "1":
        return array_2d, pixel_scales

    h, w = array_2d.shape_native
    cap_h, cap_w = SMALL_DATASETS_SHAPE_NATIVE
    if h <= cap_h and w <= cap_w:
        return array_2d, pixel_scales

    # Function-scope import kept as in the original implementation.
    # NOTE(review): presumably this avoids a circular import — confirm.
    from autoarray.structures.arrays.uniform_2d import Array2D

    cropped = array_2d.native.array[
        _center_crop_slice(h, cap_h), _center_crop_slice(w, cap_w)
    ]
    return (
        Array2D.no_mask(values=cropped, pixel_scales=SMALL_DATASETS_PIXEL_SCALES),
        SMALL_DATASETS_PIXEL_SCALES,
    )
52+
53+
654
def should_simulate(dataset_path):
755
"""
856
Returns True if the dataset at ``dataset_path`` needs to be simulated.

test_autoarray/dataset/imaging/test_dataset.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,77 @@ def test__from_fits__all_data_in_one_fits_file_multiple_hdus__loads_data_psf_noi
162162
assert dataset.noise_map.mask.pixel_scales == (0.1, 0.1)
163163

164164

165+
def test__from_fits__small_datasets_env_caps_data_and_noise_map(
    test_data_path, monkeypatch
):
    """With PYAUTO_SMALL_DATASETS=1, Imaging.from_fits must hand back data and
    a noise_map of shape (15, 15) at pixel_scales=0.6 from 30x30 files on
    disk, while the 5x5 PSF kernel keeps its shape."""
    from astropy.io import fits

    directory = Path(test_data_path)
    data_file = directory / "data_30x30.fits"
    noise_map_file = directory / "noise_map_30x30.fits"
    psf_file = directory / "psf_5x5.fits"

    # Write the three fixtures: a 30x30 image of ones, a 30x30 noise map of
    # twos and a 5x5 PSF whose values are all 1/25.
    fits.writeto(
        data_file, data=np.ones((30, 30), dtype=np.float64), overwrite=True
    )
    fits.writeto(
        noise_map_file,
        data=2.0 * np.ones((30, 30), dtype=np.float64),
        overwrite=True,
    )
    fits.writeto(
        psf_file,
        data=(1.0 / 25.0) * np.ones((5, 5), dtype=np.float64),
        overwrite=True,
    )

    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    dataset = aa.Imaging.from_fits(
        pixel_scales=0.08,
        data_path=data_file,
        psf_path=psf_file,
        noise_map_path=noise_map_file,
    )

    capped_shape = (15, 15)
    assert dataset.data.shape_native == capped_shape
    assert dataset.noise_map.shape_native == capped_shape
    assert dataset.pixel_scales == (0.6, 0.6)
    assert dataset.psf.kernel.shape_native == (5, 5)
204+
205+
def test__from_fits__small_datasets_env_unset__shape_unchanged(
    test_data_path, monkeypatch
):
    """Without PYAUTO_SMALL_DATASETS in the environment, Imaging.from_fits
    must keep the 30x30 on-disk shape and the requested pixel scale."""
    from astropy.io import fits

    directory = Path(test_data_path)
    data_file = directory / "data_30x30.fits"
    noise_map_file = directory / "noise_map_30x30.fits"

    # 30x30 fixtures: an image of ones and a noise map of twos.
    fits.writeto(
        data_file, data=np.ones((30, 30), dtype=np.float64), overwrite=True
    )
    fits.writeto(
        noise_map_file,
        data=2.0 * np.ones((30, 30), dtype=np.float64),
        overwrite=True,
    )

    monkeypatch.delenv("PYAUTO_SMALL_DATASETS", raising=False)

    dataset = aa.Imaging.from_fits(
        pixel_scales=0.08,
        data_path=data_file,
        noise_map_path=noise_map_file,
    )

    on_disk_shape = (30, 30)
    assert dataset.data.shape_native == on_disk_shape
    assert dataset.noise_map.shape_native == on_disk_shape
    assert dataset.pixel_scales == (0.08, 0.08)
234+
235+
165236
def test__output_to_fits__round_trips_data_psf_noise_map_correctly(
166237
imaging_7x7, test_data_path
167238
):
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import numpy as np
2+
3+
import autoarray as aa
4+
5+
from autoarray.util.dataset_util import (
6+
cap_array_2d_for_small_datasets,
7+
SMALL_DATASETS_SHAPE_NATIVE,
8+
SMALL_DATASETS_PIXEL_SCALES,
9+
)
10+
11+
12+
def _array_2d(shape, fill=1.0, pixel_scales=0.08):
    """Build an unmasked ``aa.Array2D`` of the given shape with every value
    equal to ``fill``."""
    values = fill * np.ones(shape)
    return aa.Array2D.no_mask(values=values, pixel_scales=pixel_scales)
16+
17+
18+
def test__env_unset__returns_inputs_unchanged(monkeypatch):
    """With the env var removed, a 150x150 array and its pixel scale come back
    untouched (the very same array object)."""
    monkeypatch.delenv("PYAUTO_SMALL_DATASETS", raising=False)

    original = _array_2d((150, 150))
    returned, returned_scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert returned is original
    assert returned_scales == 0.08
26+
27+
28+
def test__env_set__shape_already_at_cap__returns_inputs_unchanged(monkeypatch):
    """With the env var set, an array whose shape equals the cap is returned
    as-is together with its original pixel scale."""
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    original = _array_2d(SMALL_DATASETS_SHAPE_NATIVE, pixel_scales=0.08)
    returned, returned_scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert returned is original
    assert returned_scales == 0.08
36+
37+
38+
def test__env_set__shape_below_cap__returns_inputs_unchanged(monkeypatch):
    """With the env var set, a 10x10 array (smaller than the cap) is returned
    as-is together with its original pixel scale."""
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    original = _array_2d((10, 10), pixel_scales=0.08)
    returned, returned_scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert returned is original
    assert returned_scales == 0.08
46+
47+
48+
def test__env_set__shape_above_cap__center_crops_and_overrides_pixel_scales(
    monkeypatch,
):
    """With the env var set, a 150x150 array is replaced by a new array of the
    cap shape holding the central 15x15 window, and the returned pixel scale
    is the small-datasets value."""
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    # Distinct value per pixel so the crop position can be checked exactly.
    raw = np.arange(150 * 150, dtype=float).reshape(150, 150)
    array = aa.Array2D.no_mask(values=raw, pixel_scales=0.08)

    result, pixel_scales = cap_array_2d_for_small_datasets(array, 0.08)

    assert result is not array
    assert result.shape_native == SMALL_DATASETS_SHAPE_NATIVE
    assert pixel_scales == SMALL_DATASETS_PIXEL_SCALES

    start = (150 - 15) // 2
    window = slice(start, start + 15)
    expected = raw[window, window]
    matches = result.native.array == expected
    assert matches.all()
65+
66+
67+
def test__env_set__non_square_above_cap__center_crops_to_15x15(monkeypatch):
    """With the env var set, a 100x50 array (both sides above the cap) comes
    back at the cap shape with the small-datasets pixel scale."""
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    rectangular = _array_2d((100, 50))
    returned, returned_scales = cap_array_2d_for_small_datasets(rectangular, 0.08)

    assert returned.shape_native == SMALL_DATASETS_SHAPE_NATIVE
    assert returned_scales == SMALL_DATASETS_PIXEL_SCALES

0 commit comments

Comments
 (0)