Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,17 @@ def from_fits(
passed into the calculations performed in the `inversion` module.
"""

from autoarray.util.dataset_util import cap_array_2d_for_small_datasets

data = Array2D.from_fits(
file_path=data_path, hdu=data_hdu, pixel_scales=pixel_scales
)
data, pixel_scales = cap_array_2d_for_small_datasets(data, pixel_scales)

noise_map = Array2D.from_fits(
file_path=noise_map_path, hdu=noise_map_hdu, pixel_scales=pixel_scales
)
noise_map, pixel_scales = cap_array_2d_for_small_datasets(noise_map, pixel_scales)

if psf_path is not None:
kernel = Array2D.from_fits(
Expand Down
48 changes: 48 additions & 0 deletions autoarray/util/dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,54 @@
from pathlib import Path


SMALL_DATASETS_SHAPE_NATIVE = (15, 15)
SMALL_DATASETS_PIXEL_SCALES = 0.6


def cap_array_2d_for_small_datasets(array_2d, pixel_scales):
    """
    Center-crop a 2D autoarray to the small-datasets cap when
    ``PYAUTO_SMALL_DATASETS=1`` is active.

    Returns ``(array_2d, pixel_scales)`` unchanged in any of these cases:

    - ``PYAUTO_SMALL_DATASETS`` is not set to ``"1"``.
    - ``array_2d.shape_native`` is already at-or-below the cap (15, 15).

    When the env var is set and either dimension exceeds the cap, returns a
    new ``Array2D`` center-cropped so each dimension is at most 15 (a
    dimension already within the cap is left at its original extent), with
    ``pixel_scales`` overridden to 0.6 — matching the convention used by
    ``Mask2D.circular`` and ``Grid2D.uniform`` so the loaded dataset stays
    shape-consistent with masks and grids built under the same env var.

    The same env var is honoured for shape construction in
    ``Mask2D.circular`` and ``Grid2D.uniform`` (and by ``should_simulate``
    for on-disk regeneration). This helper closes the gap for FITS loaders
    that read pre-committed datasets larger than the cap, which would
    otherwise broadcast-mismatch against capped masks/grids.

    Center-cropping (rather than downsampling/resampling) is intentional:
    smoke-mode tests don't require numerical correctness, and the simpler
    op avoids interpolation artifacts and a scipy dependency at this layer.
    """
    if os.environ.get("PYAUTO_SMALL_DATASETS") != "1":
        return array_2d, pixel_scales

    h, w = array_2d.shape_native
    cap_h, cap_w = SMALL_DATASETS_SHAPE_NATIVE
    if h <= cap_h and w <= cap_w:
        return array_2d, pixel_scales

    # Local import avoids a circular dependency between util and structures.
    from autoarray.structures.arrays.uniform_2d import Array2D

    # Clamp each dimension independently: for a shape like (100, 10) the
    # small axis must not be cropped at all, otherwise (w - cap_w) // 2 goes
    # negative and the slice wraps around, mangling the result shape.
    new_h = min(h, cap_h)
    new_w = min(w, cap_w)
    h0 = (h - new_h) // 2
    w0 = (w - new_w) // 2
    cropped = array_2d.native.array[h0:h0 + new_h, w0:w0 + new_w]
    return (
        Array2D.no_mask(values=cropped, pixel_scales=SMALL_DATASETS_PIXEL_SCALES),
        SMALL_DATASETS_PIXEL_SCALES,
    )


def should_simulate(dataset_path):
"""
Returns True if the dataset at ``dataset_path`` needs to be simulated.
Expand Down
71 changes: 71 additions & 0 deletions test_autoarray/dataset/imaging/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,77 @@ def test__from_fits__all_data_in_one_fits_file_multiple_hdus__loads_data_psf_noi
assert dataset.noise_map.mask.pixel_scales == (0.1, 0.1)


def test__from_fits__small_datasets_env_caps_data_and_noise_map(
    test_data_path, monkeypatch
):
    """When PYAUTO_SMALL_DATASETS=1, Imaging.from_fits center-crops data and
    noise_map to (15, 15) at pixel_scales=0.6 so they stay shape-consistent
    with masks built via Mask2D.circular under the same env var. PSF is left
    alone."""
    from astropy.io import fits

    root = Path(test_data_path)

    # Write oversized data/noise-map files plus a small PSF to disk.
    fixtures = [
        ("data_30x30.fits", np.ones((30, 30), dtype=np.float64)),
        ("noise_map_30x30.fits", 2.0 * np.ones((30, 30), dtype=np.float64)),
        ("psf_5x5.fits", (1.0 / 25.0) * np.ones((5, 5), dtype=np.float64)),
    ]
    for name, values in fixtures:
        fits.writeto(root / name, data=values, overwrite=True)

    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    dataset = aa.Imaging.from_fits(
        pixel_scales=0.08,
        data_path=root / "data_30x30.fits",
        psf_path=root / "psf_5x5.fits",
        noise_map_path=root / "noise_map_30x30.fits",
    )

    assert dataset.data.shape_native == (15, 15)
    assert dataset.noise_map.shape_native == (15, 15)
    assert dataset.pixel_scales == (0.6, 0.6)
    assert dataset.psf.kernel.shape_native == (5, 5)


def test__from_fits__small_datasets_env_unset__shape_unchanged(
    test_data_path, monkeypatch
):
    """Sanity: with the env var unset, from_fits returns the on-disk shape
    unchanged, even for files larger than the cap."""
    from astropy.io import fits

    root = Path(test_data_path)

    # Oversized on-disk fixtures; no PSF needed for this check.
    for name, values in [
        ("data_30x30.fits", np.ones((30, 30), dtype=np.float64)),
        ("noise_map_30x30.fits", 2.0 * np.ones((30, 30), dtype=np.float64)),
    ]:
        fits.writeto(root / name, data=values, overwrite=True)

    monkeypatch.delenv("PYAUTO_SMALL_DATASETS", raising=False)

    dataset = aa.Imaging.from_fits(
        pixel_scales=0.08,
        data_path=root / "data_30x30.fits",
        noise_map_path=root / "noise_map_30x30.fits",
    )

    assert dataset.data.shape_native == (30, 30)
    assert dataset.noise_map.shape_native == (30, 30)
    assert dataset.pixel_scales == (0.08, 0.08)


def test__output_to_fits__round_trips_data_psf_noise_map_correctly(
imaging_7x7, test_data_path
):
Expand Down
74 changes: 74 additions & 0 deletions test_autoarray/util/test_dataset_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np

import autoarray as aa

from autoarray.util.dataset_util import (
cap_array_2d_for_small_datasets,
SMALL_DATASETS_SHAPE_NATIVE,
SMALL_DATASETS_PIXEL_SCALES,
)


def _array_2d(shape, fill=1.0, pixel_scales=0.08):
    """Build a uniform ``Array2D`` of ``shape`` filled with ``fill``."""
    values = np.full(shape, fill, dtype=np.float64)
    return aa.Array2D.no_mask(values=values, pixel_scales=pixel_scales)


def test__env_unset__returns_inputs_unchanged(monkeypatch):
    # Without the env var, the cap is a no-op even for very large arrays.
    monkeypatch.delenv("PYAUTO_SMALL_DATASETS", raising=False)

    original = _array_2d((150, 150))
    capped, scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert capped is original
    assert scales == 0.08


def test__env_set__shape_already_at_cap__returns_inputs_unchanged(monkeypatch):
    # An array exactly at the cap passes through untouched.
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    original = _array_2d(SMALL_DATASETS_SHAPE_NATIVE, pixel_scales=0.08)
    capped, scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert capped is original
    assert scales == 0.08


def test__env_set__shape_below_cap__returns_inputs_unchanged(monkeypatch):
    # An array already smaller than the cap passes through untouched.
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    original = _array_2d((10, 10), pixel_scales=0.08)
    capped, scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert capped is original
    assert scales == 0.08


def test__env_set__shape_above_cap__center_crops_and_overrides_pixel_scales(
    monkeypatch,
):
    # Oversized arrays are center-cropped to the cap and the pixel scale is
    # replaced with the small-datasets convention value.
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    values = np.arange(150 * 150, dtype=float).reshape(150, 150)
    original = aa.Array2D.no_mask(values=values, pixel_scales=0.08)

    capped, scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert capped is not original
    assert capped.shape_native == SMALL_DATASETS_SHAPE_NATIVE
    assert scales == SMALL_DATASETS_PIXEL_SCALES

    offset = (150 - 15) // 2
    expected = values[offset:offset + 15, offset:offset + 15]
    assert np.array_equal(capped.native.array, expected)


def test__env_set__non_square_above_cap__center_crops_to_15x15(monkeypatch):
    # A non-square array with both dimensions above the cap crops to (15, 15).
    monkeypatch.setenv("PYAUTO_SMALL_DATASETS", "1")

    original = _array_2d((100, 50))
    capped, scales = cap_array_2d_for_small_datasets(original, 0.08)

    assert capped.shape_native == SMALL_DATASETS_SHAPE_NATIVE
    assert scales == SMALL_DATASETS_PIXEL_SCALES
Loading