Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions scripts/imaging/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,112 @@
residuals = via_fft - via_real_space

print(f"Mapping Matrix Max residual = {residuals.max()}")


"""
__Mask Padding__

When the mask is close to the image edge and the PSF kernel footprint extends
beyond the boundary, the blurring mask is automatically padded. This test
verifies that the padded convolution produces the same result as an equivalent
centred configuration that requires no padding.

We simulate a compact lens+source on two grids:
- A large 51x51 grid with the model centred → no padding needed.
- A small 21x21 grid with the model offset near the edge → padding triggered.

Both should give identical log-likelihoods and chi_squared ≈ 0.
"""

import warnings

pixel_scales = 0.2

psf_pad = al.Convolver.from_gaussian(
shape_native=(11, 11), pixel_scales=pixel_scales, sigma=0.75, normalize=True
)

lens_centred = al.Galaxy(
redshift=0.5,
light=al.lp.Sersic(
centre=(0.0, 0.0), intensity=0.1,
effective_radius=0.3, sersic_index=2.0,
),
mass=al.mp.Isothermal(centre=(0.0, 0.0), einstein_radius=1.0),
)
source_centred = al.Galaxy(
redshift=1.0,
light=al.lp.Exponential(
centre=(0.0, 0.0), intensity=0.3, effective_radius=0.2,
),
)
tracer_centred = al.Tracer(galaxies=[lens_centred, source_centred])

sim_pad = al.SimulatorImaging(
exposure_time=300.0, psf=psf_pad, add_poisson_noise_to_data=False,
)

# --- Centred fit on a large grid: no padding needed ---
grid_large = al.Grid2D.uniform(
shape_native=(51, 51), pixel_scales=pixel_scales, over_sample_size=1,
)
dataset_centred = sim_pad.via_tracer_from(tracer=tracer_centred, grid=grid_large)
dataset_centred.noise_map = al.Array2D.ones(
shape_native=(51, 51), pixel_scales=pixel_scales,
)
mask_centred = al.Mask2D.circular(
shape_native=(51, 51), pixel_scales=pixel_scales,
radius=0.6, centre=(0.0, 0.0),
)
masked_centred = dataset_centred.apply_mask(mask=mask_centred)
fit_centred = al.FitImaging(dataset=masked_centred, tracer=tracer_centred)

# --- Off-centre fit on a small grid: triggers padding ---
offset = (0.0, 1.2)
lens_off = al.Galaxy(
redshift=0.5,
light=al.lp.Sersic(
centre=offset, intensity=0.1,
effective_radius=0.3, sersic_index=2.0,
),
mass=al.mp.Isothermal(centre=offset, einstein_radius=1.0),
)
source_off = al.Galaxy(
redshift=1.0,
light=al.lp.Exponential(
centre=offset, intensity=0.3, effective_radius=0.2,
),
)
tracer_off = al.Tracer(galaxies=[lens_off, source_off])

grid_small = al.Grid2D.uniform(
shape_native=(21, 21), pixel_scales=pixel_scales, over_sample_size=1,
)
dataset_off = sim_pad.via_tracer_from(tracer=tracer_off, grid=grid_small)
dataset_off.noise_map = al.Array2D.ones(
shape_native=(21, 21), pixel_scales=pixel_scales,
)
mask_off = al.Mask2D.circular(
shape_native=(21, 21), pixel_scales=pixel_scales,
radius=0.6, centre=offset,
)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
masked_off = dataset_off.apply_mask(mask=mask_off)
fit_off = al.FitImaging(dataset=masked_off, tracer=tracer_off)
padding_occurred = any("Mask padded" in str(x.message) for x in w)

assert padding_occurred, "Expected mask padding to be triggered for the off-centre mask"
assert fit_centred.chi_squared < 1e-4, f"Centred chi_squared too large: {fit_centred.chi_squared}"
assert fit_off.chi_squared < 1e-4, f"Off-centre chi_squared too large: {fit_off.chi_squared}"

likelihood_diff = abs(fit_centred.log_likelihood - fit_off.log_likelihood)
assert likelihood_diff < 1e-4, (
f"Padded and non-padded log-likelihoods differ by {likelihood_diff}"
)

print(f"\nMask padding test PASSED")
print(f" Centred log_likelihood = {fit_centred.log_likelihood:.8f} chi_squared = {fit_centred.chi_squared:.8f}")
print(f" Padded log_likelihood = {fit_off.log_likelihood:.8f} chi_squared = {fit_off.chi_squared:.8f}")
print(f" Difference = {likelihood_diff:.2e}")