Skip to content

Commit 314d8cd

Browse files
Jammy2211claude
authored andcommitted
feat: scan-based multi-plane ray-tracing for substructure simulations
Adds traced_grids_via_scan using jax.lax.scan over redshift planes, replacing the Python-loop approach for O(1000) halos. Includes helpers precompute_scaling_matrix and galaxies_to_halo_arrays to convert from the existing Galaxy/LOSSampler API to padded array inputs. Ref: PyAutoLens#542 prompt 2 of 4. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent db7fdc9 commit 314d8cd

1 file changed

Lines changed: 122 additions & 0 deletions

File tree

autolens/lens/substructure_util.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import numpy as np
2+
3+
import autogalaxy as ag
4+
5+
6+
def precompute_scaling_matrix(plane_redshifts, cosmology=None):
7+
import jax.numpy as jnp
8+
9+
cosmology = cosmology or ag.cosmo.Planck15()
10+
n = len(plane_redshifts)
11+
z_final = plane_redshifts[-1]
12+
mat = np.zeros((n, n))
13+
14+
for i in range(n):
15+
for j in range(i):
16+
mat[i, j] = float(
17+
cosmology.scaling_factor_between_redshifts_from(
18+
redshift_0=plane_redshifts[j],
19+
redshift_1=plane_redshifts[i],
20+
redshift_final=z_final,
21+
)
22+
)
23+
24+
return jnp.array(mat)
25+
26+
27+
def galaxies_to_halo_arrays(galaxies, plane_redshifts, max_n, profile_cls):
28+
import jax.numpy as jnp
29+
30+
n_planes = len(plane_redshifts)
31+
32+
if profile_cls is ag.mp.cNFWSph:
33+
n_params = 5
34+
def extract(prof):
35+
return [
36+
prof.centre[0], prof.centre[1],
37+
prof.kappa_s, prof.scale_radius, prof.core_radius,
38+
]
39+
else:
40+
n_params = 5
41+
def extract(prof):
42+
return [
43+
prof.centre[0], prof.centre[1],
44+
prof.kappa_s, prof.scale_radius, prof.truncation_radius,
45+
]
46+
47+
params = np.zeros((n_planes, max_n, n_params))
48+
mask = np.zeros((n_planes, max_n), dtype=bool)
49+
sheet_kappas = np.zeros(n_planes)
50+
51+
z_to_plane = {}
52+
for i, z in enumerate(plane_redshifts):
53+
z_to_plane[round(float(z), 8)] = i
54+
55+
for g in galaxies:
56+
z_key = round(float(g.redshift), 8)
57+
plane_i = z_to_plane.get(z_key)
58+
if plane_i is None:
59+
continue
60+
61+
if hasattr(g, "mass_sheet"):
62+
sheet_kappas[plane_i] = float(g.mass_sheet.kappa)
63+
elif hasattr(g, "mass") and isinstance(g.mass, profile_cls):
64+
slot = int(mask[plane_i].sum())
65+
if slot < max_n:
66+
params[plane_i, slot] = extract(g.mass)
67+
mask[plane_i, slot] = True
68+
69+
return jnp.array(params), jnp.array(mask), jnp.array(sheet_kappas)
70+
71+
72+
def traced_grids_via_scan(
73+
grid,
74+
halo_params,
75+
halo_mask,
76+
scaling_matrix,
77+
macro_deflections_fn,
78+
macro_plane_mask,
79+
sheet_kappas,
80+
halo_profile_cls,
81+
):
82+
import jax
83+
import jax.numpy as jnp
84+
85+
n_planes = halo_params.shape[0]
86+
n_grid = grid.shape[0]
87+
88+
init_defl_buffer = jnp.zeros((n_planes, n_grid, 2))
89+
90+
def scan_step(carry, plane_inputs):
91+
grid_0, defl_buffer, plane_idx = carry
92+
halo_p, halo_m, scaling_row, is_macro, sheet_kappa = plane_inputs
93+
94+
scaled = jnp.einsum("p,pmd->md", scaling_row, defl_buffer)
95+
current_grid = grid_0 - scaled
96+
97+
halo_defl = halo_profile_cls.vmapped_deflections_from(
98+
current_grid, halo_p, halo_m
99+
)
100+
101+
macro_defl = macro_deflections_fn(current_grid)
102+
macro_defl = is_macro * macro_defl
103+
104+
sheet_defl = sheet_kappa * current_grid
105+
106+
total_defl = halo_defl + macro_defl + sheet_defl
107+
defl_buffer = defl_buffer.at[plane_idx].set(total_defl)
108+
109+
return (grid_0, defl_buffer, plane_idx + 1), current_grid
110+
111+
plane_stack = (
112+
halo_params,
113+
halo_mask,
114+
scaling_matrix,
115+
macro_plane_mask,
116+
sheet_kappas,
117+
)
118+
119+
init_carry = (grid, init_defl_buffer, 0)
120+
_, traced_grids = jax.lax.scan(scan_step, init_carry, plane_stack)
121+
122+
return traced_grids

0 commit comments

Comments
 (0)