From 10d0eccc4fa76feba8fa1d05629bb56a6542713e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 24 May 2026 16:38:24 +0100 Subject: [PATCH] feat: SimulatorImaging.via_galaxies_from auto-defaults xp from parent use_jax Changes via_galaxies_from to accept xp=None and fall back to self._xp inherited from aa.SimulatorImaging. When the simulator is constructed with use_jax=True, via_galaxies_from now routes xp=jnp through galaxies.padded_image_2d_from and via_image_from without the user passing xp explicitly. Depends on Jammy2211/PyAutoArray PR adding the use_jax constructor flag + _xp property to the parent aa.SimulatorImaging. Part of Phase 2 PR 2 of z_features/jax_user_intro.md. Issue: PyAutoArray#334 Co-Authored-By: Claude Opus 4.7 (1M context) --- autogalaxy/imaging/simulator.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/autogalaxy/imaging/simulator.py b/autogalaxy/imaging/simulator.py index df762288..7674ea11 100644 --- a/autogalaxy/imaging/simulator.py +++ b/autogalaxy/imaging/simulator.py @@ -20,7 +20,7 @@ class SimulatorImaging(aa.SimulatorImaging): def via_galaxies_from( - self, galaxies: List[Galaxy], grid: aa.type.Grid2DLike + self, galaxies: List[Galaxy], grid: aa.type.Grid2DLike, xp=None ) -> aa.Imaging: """ Simulate an `Imaging` dataset from an input list of `Galaxy` objects and a 2D grid of (y,x) coordinates. @@ -48,6 +48,9 @@ def via_galaxies_from( to generate the image of the galaxies. """ + if xp is None: + xp = self._xp + galaxies = Galaxies(galaxies=galaxies) for galaxy in galaxies: @@ -59,14 +62,14 @@ def via_galaxies_from( ) image = galaxies.padded_image_2d_from( - grid=grid, psf_shape_2d=self.psf.kernel.shape_native + grid=grid, psf_shape_2d=self.psf.kernel.shape_native, xp=xp ) over_sample_size = grid.over_sample_size.resized_from( new_shape=image.shape_native, mask_pad_value=1 ) - dataset = self.via_image_from(image=image, over_sample_size=over_sample_size) + dataset = self.via_image_from(image=image, over_sample_size=over_sample_size, xp=xp) return dataset.trimmed_after_convolution_from( kernel_shape=self.psf.kernel.shape_native