Skip to content

Commit f982dd0

Browse files
Jammy2211Jammy2211claude
authored
feat: simulate_substructure end-to-end jittable simulator (#544)
Adds simulate_substructure to substructure_util.py — a pure jnp function chaining scan-based ray-tracing, source light evaluation, FFT PSF convolution, and Poisson noise via jax.random.PRNGKey. Supports prng_key=None to skip noise for deterministic comparison. Ref: PyAutoLens#542 prompt 3 of 4. Co-authored-by: Jammy2211 <JNightingale2211@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 07e0d24 commit f982dd0

1 file changed

Lines changed: 48 additions & 0 deletions

File tree

autolens/lens/substructure_util.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,51 @@ def scan_step(carry, plane_inputs):
120120
_, traced_grids = jax.lax.scan(scan_step, init_carry, plane_stack)
121121

122122
return traced_grids
123+
124+
125+
def simulate_substructure(
126+
grid,
127+
image_shape,
128+
halo_params,
129+
halo_mask,
130+
scaling_matrix,
131+
macro_deflections_fn,
132+
macro_plane_mask,
133+
sheet_kappas,
134+
source_image_fn,
135+
psf_kernel,
136+
exposure_time,
137+
background_sky_level,
138+
prng_key,
139+
halo_profile_cls,
140+
):
141+
import jax
142+
import jax.numpy as jnp
143+
144+
traced_grids = traced_grids_via_scan(
145+
grid=grid,
146+
halo_params=halo_params,
147+
halo_mask=halo_mask,
148+
scaling_matrix=scaling_matrix,
149+
macro_deflections_fn=macro_deflections_fn,
150+
macro_plane_mask=macro_plane_mask,
151+
sheet_kappas=sheet_kappas,
152+
halo_profile_cls=halo_profile_cls,
153+
)
154+
155+
source_grid = traced_grids[-1]
156+
image_1d = source_image_fn(source_grid)
157+
image_2d = image_1d.reshape(image_shape)
158+
159+
image_2d = jax.scipy.signal.fftconvolve(image_2d, psf_kernel, mode="same")
160+
161+
image_2d = image_2d + background_sky_level
162+
163+
if prng_key is not None:
164+
image_counts = image_2d * exposure_time
165+
noisy_counts = jax.random.poisson(prng_key, image_counts)
166+
image_2d = noisy_counts / exposure_time - background_sky_level
167+
else:
168+
image_2d = image_2d - background_sky_level
169+
170+
return image_2d

0 commit comments

Comments
 (0)