Skip to content

Commit e4e70cc

Browse files
Jammy2211claude
authored andcommitted
feat: SimulatorInterferometer.via_tracer_from auto-defaults xp from parent
Changes via_tracer_from / via_galaxies_from on autolens's SimulatorInterferometer override to default xp=None and fall back to self._xp inherited from aa.SimulatorInterferometer. When the simulator is constructed with use_jax=True, both methods route xp=jnp through tracer.image_2d_from / galaxies.image_2d_from and via_image_from without the user passing xp explicitly. Depends on Jammy2211/PyAutoArray PR adding use_jax constructor flag + _xp property to the parent aa.SimulatorInterferometer. Part of Phase 2 PR 3 of z_features/jax_user_intro.md. Issue: PyAutoArray#334 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0ef58b7 commit e4e70cc

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

autolens/interferometer/simulator.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
class SimulatorInterferometer(aa.SimulatorInterferometer):
24-
def via_tracer_from(self, tracer, grid):
24+
def via_tracer_from(self, tracer, grid, xp=None):
2525
"""
2626
Returns a realistic simulated image by applying effects to a plain simulated image.
2727
@@ -42,11 +42,14 @@ def via_tracer_from(self, tracer, grid):
4242
A seed for random noise_maps generation
4343
"""
4444

45-
image = tracer.image_2d_from(grid=grid)
45+
if xp is None:
46+
xp = self._xp
4647

47-
return self.via_image_from(image=image)
48+
image = tracer.image_2d_from(grid=grid, xp=xp)
49+
50+
return self.via_image_from(image=image, xp=xp)
4851

49-
def via_galaxies_from(self, galaxies, grid):
52+
def via_galaxies_from(self, galaxies, grid, xp=None):
5053
"""Simulate imaging data for this data, as follows:
5154
5255
1) Setup the image-plane grid of the Imaging arrays, which defines the coordinates used for the ray-tracing.
@@ -64,7 +67,7 @@ def via_galaxies_from(self, galaxies, grid):
6467

6568
tracer = Tracer(galaxies=galaxies)
6669

67-
return self.via_tracer_from(tracer=tracer, grid=grid)
70+
return self.via_tracer_from(tracer=tracer, grid=grid, xp=xp)
6871

6972
def via_deflections_and_galaxies_from(
7073
self, deflections: aa.VectorYX2D, galaxies: List[ag.Galaxy]

0 commit comments

Comments
 (0)