From e4e70cc7743e4106bec352c148720d6b67ea1735 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 24 May 2026 16:56:55 +0100 Subject: [PATCH] feat: SimulatorInterferometer.via_tracer_from auto-defaults xp from parent Changes via_tracer_from / via_galaxies_from on autolens's SimulatorInterferometer override to default xp=None and fall back to self._xp inherited from aa.SimulatorInterferometer. When the simulator is constructed with use_jax=True, both methods route xp=jnp through tracer.image_2d_from / galaxies.image_2d_from and via_image_from without the user passing xp explicitly. Depends on Jammy2211/PyAutoArray PR adding use_jax constructor flag + _xp property to the parent aa.SimulatorInterferometer. Part of Phase 2 PR 3 of z_features/jax_user_intro.md. Issue: PyAutoArray#334 Co-Authored-By: Claude Opus 4.7 (1M context) --- autolens/interferometer/simulator.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/autolens/interferometer/simulator.py b/autolens/interferometer/simulator.py index 0385e1d58..d86367811 100644 --- a/autolens/interferometer/simulator.py +++ b/autolens/interferometer/simulator.py @@ -21,7 +21,7 @@ class SimulatorInterferometer(aa.SimulatorInterferometer): - def via_tracer_from(self, tracer, grid): + def via_tracer_from(self, tracer, grid, xp=None): """ Returns a realistic simulated image by applying effects to a plain simulated image. @@ -42,11 +42,14 @@ def via_tracer_from(self, tracer, grid): A seed for random noise_maps generation """ - image = tracer.image_2d_from(grid=grid) + if xp is None: + xp = self._xp - return self.via_image_from(image=image) + image = tracer.image_2d_from(grid=grid, xp=xp) + + return self.via_image_from(image=image, xp=xp) - def via_galaxies_from(self, galaxies, grid): + def via_galaxies_from(self, galaxies, grid, xp=None): """Simulate imaging data for this data, as follows: 1) Setup the image-plane grid of the Imaging arrays, which defines the coordinates used for the ray-tracing. @@ -64,7 +67,7 @@ def via_galaxies_from(self, galaxies, grid): tracer = Tracer(galaxies=galaxies) - return self.via_tracer_from(tracer=tracer, grid=grid) + return self.via_tracer_from(tracer=tracer, grid=grid, xp=xp) def via_deflections_and_galaxies_from( self, deflections: aa.VectorYX2D, galaxies: List[ag.Galaxy]