Skip to content

Commit 66b7a5d

Browse files
author
NiekWielders
committed
revamped whole structure, added deflections (for spherical case so far)
1 parent 0172ae7 commit 66b7a5d

2 files changed

Lines changed: 163 additions & 95 deletions

File tree

autogalaxy/profiles/mass/abstract/gaussian_utils.py

Whitespace-only changes.
Lines changed: 163 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,32 @@
11
import numpy as np
22

3+
from typing import Tuple, Sequence, Callable
34

4-
class MassProfileMGE:
5+
import autoarray as aa
6+
from autogalaxy.profiles.geometry_profiles import EllProfile
7+
8+
9+
class MassProfileMGE(EllProfile):
510
"""
611
This class speeds up deflection angle calculations of certain mass profiles by decompositing them into many
712
Gaussians.
813
914
This follows the method of Shajib 2019 - https://academic.oup.com/mnras/article/488/1/1387/5526256
1015
"""
1116

12-
def __init__(self):
13-
self.count = 0
14-
self.sigma_calc = 0
15-
self.z = 0
16-
self.zq = 0
17-
self.expv = 0
18-
19-
@staticmethod
20-
def zeta_from(grid, amps, sigmas, axis_ratio):
21-
"""
22-
The key part to compute the deflection angle of each Gaussian.
23-
Because of my optimization, there are blocks looking weird and indirect. What I'm doing here
24-
is trying to avoid big matrix operation to save time.
25-
I think there are still spaces we can optimize.
26-
27-
It seems when using w_f_approx, it gives some errors if y < 0. So when computing for places
28-
where y < 0, we first compute the value at - y, and then change its sign.
29-
"""
30-
31-
output_grid_final = np.zeros(grid.shape[0], dtype="complex128")
32-
33-
q2 = axis_ratio**2.0
34-
35-
scale_factor = axis_ratio / (sigmas[0] * np.sqrt(2.0 * (1.0 - q2)))
36-
37-
xs = np.array((grid.array[:, 1] * scale_factor).copy())
38-
ys = np.array((grid.array[:, 0] * scale_factor).copy())
39-
40-
ys_minus = ys < 0.0
41-
ys[ys_minus] *= -1
42-
z = xs + 1j * ys
43-
zq = axis_ratio * xs + 1j * ys / axis_ratio
44-
45-
expv = -(xs**2.0) * (1.0 - q2) - ys**2.0 * (1.0 / q2 - 1.0)
46-
47-
for i in range(len(sigmas)):
48-
if i > 0:
49-
z /= sigmas[i] / sigmas[i - 1]
50-
zq /= sigmas[i] / sigmas[i - 1]
51-
expv /= (sigmas[i] / sigmas[i - 1]) ** 2.0
52-
53-
output_grid = -1j * (w_f_approx(z) - np.exp(expv) * w_f_approx(zq))
54-
55-
output_grid[ys_minus] = np.conj(output_grid[ys_minus])
56-
57-
output_grid_final += (amps[i] * sigmas[i]) * output_grid
17+
def __init__(
18+
self,
19+
func: Callable,
20+
sigmas: Sequence[float],
21+
func_terms: int = 28,
22+
centre: Tuple[float, float] = (0.0, 0.0),
23+
ell_comps: Tuple[float, float] = (0.0, 0.0),
24+
):
25+
super().__init__(centre=centre, ell_comps=ell_comps)
26+
self.func = func
27+
self.sigmas = sigmas
28+
self.func_terms = func_terms
5829

59-
return output_grid_final
6030

6131
@staticmethod
6232
def kesi(p, xp=np):
@@ -66,6 +36,7 @@ def kesi(p, xp=np):
6636
n_list = xp.arange(0, 2 * p + 1, 1)
6737
return (2.0 * p * xp.log(10) / 3.0 + 2.0 * xp.pi * n_list * 1j) ** (0.5)
6838

39+
6940
@staticmethod
7041
def eta(p, xp=np):
7142
"""
@@ -85,8 +56,96 @@ def eta(p, xp=np):
8556
return eta_list
8657

8758

59+
@staticmethod
60+
def wofz(z, xp=np):
61+
"""
62+
JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z)
63+
Based on the Poppe–Wijers / Zaghloul–Ali rational approximations.
64+
Valid for all complex z. JIT + autodiff safe.
65+
"""
66+
67+
z = xp.asarray(z, dtype=xp.complex128)
68+
x = xp.real(z)
69+
y = xp.imag(z)
70+
71+
r2 = x * x + y * y
72+
y2 = y * y
73+
z2 = z * z
74+
75+
sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64)
76+
inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64)
77+
78+
# ---------- Large-|z| continued fraction ----------
79+
r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64)
80+
81+
t = z
82+
for c in r1_s1:
83+
t = z - c / t
84+
85+
w_large = 1j * inv_sqrt_pi / t
86+
87+
# ---------- Region 5 ----------
88+
U5 = xp.asarray(
89+
[1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31], dtype=xp.float64
90+
)
91+
V5 = xp.asarray(
92+
[1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6],
93+
dtype=xp.float64,
94+
)
95+
96+
t = inv_sqrt_pi
97+
for u in U5:
98+
t = u + z2 * t
99+
100+
s = xp.asarray(1.0, dtype=xp.float64)
101+
for v in V5:
102+
s = v + z2 * s
103+
104+
w5 = xp.exp(-z2) + 1j * z * t / s
105+
106+
# ---------- Region 6 ----------
107+
U6 = xp.asarray(
108+
[5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793],
109+
dtype=xp.float64,
110+
)
111+
V6 = xp.asarray(
112+
[
113+
10.479857,
114+
53.992907,
115+
170.35400,
116+
348.70392,
117+
457.33448,
118+
352.73063,
119+
122.60793,
120+
],
121+
dtype=xp.float64,
122+
)
123+
124+
t = inv_sqrt_pi
125+
for u in U6:
126+
t = u - 1j * z * t
127+
128+
s = xp.asarray(1.0, dtype=xp.float64)
129+
for v in V6:
130+
s = v - 1j * z * s
131+
132+
w6 = t / s
133+
134+
# ---------- Region logic ----------
135+
reg1 = (r2 >= 62.0) | ((r2 >= 30.0) & (r2 < 62.0) & (y2 >= 1e-13))
136+
reg2 = ((r2 >= 30) & (r2 < 62) & (y2 < 1e-13)) | (
137+
(r2 >= 2.5) & (r2 < 30) & (y2 < 0.072)
138+
)
139+
140+
w = w6
141+
w = xp.where(reg2, w5, w)
142+
w = xp.where(reg1, w_large, w)
143+
144+
return w
145+
146+
88147
def decompose_convergence_via_mge(
89-
self, func, radii_min, radii_max, func_terms=28, func_gaussians=20, xp=np
148+
self, xp=np
90149
):
91150
"""
92151
@@ -104,17 +163,18 @@ def decompose_convergence_via_mge(
104163
Returns
105164
-------
106165
"""
107-
kesis = self.kesi(func_terms, xp=xp) # kesi in Eq.(6) of 1906.08263
108-
etas = self.eta(func_terms, xp=xp) # eta in Eqr.(6) of 1906.08263
166+
kesis = self.kesi(self.func_terms, xp=xp) # kesi in Eq.(6) of 1906.08263
167+
etas = self.eta(self.func_terms, xp=xp) # eta in Eqr.(6) of 1906.08263
109168

110-
# sigma is sampled from logspace between these radii.
169+
sigmas = xp.array(self.sigmas)
111170

112-
log_sigmas = xp.linspace(xp.log(radii_min), xp.log(radii_max), func_gaussians)
171+
#log_sigmas = xp.linspace(xp.log(radii_min), xp.log(radii_max), func_gaussians)
172+
log_sigmas = xp.log(sigmas)
113173
d_log_sigma = log_sigmas[1] - log_sigmas[0]
114-
sigma_list = xp.exp(log_sigmas)
174+
#sigma_list = xp.exp(log_sigmas)
115175

116176
f_sigma = xp.sum(
117-
etas * xp.real(func(sigma_list.reshape(-1, 1) * kesis)), axis=1
177+
etas * xp.real(self.func(sigmas.reshape(-1, 1) * kesis)), axis=1
118178
)
119179

120180
amplitude_list = f_sigma * d_log_sigma / xp.sqrt(2.0 * xp.pi)
@@ -125,59 +185,67 @@ def decompose_convergence_via_mge(
125185
amplitude_list = amplitude_list.at[0].multiply(0.5)
126186
amplitude_list = amplitude_list.at[-1].multiply(0.5)
127187

128-
return amplitude_list, sigma_list
188+
return amplitude_list, sigmas
129189

130-
def convergence_2d_via_mge_from(self, grid_radii):
131-
raise NotImplementedError()
132190

133-
def _convergence_2d_via_mge_from(self, grid_radii, **kwargs):
134-
"""Calculate the projected convergence at a given set of arc-second gridded coordinates.
135-
136-
Parameters
137-
----------
138-
grid
139-
The grid of (y,x) arc-second coordinates the convergence is computed on.
191+
@aa.grid_dec.to_vector_yx
192+
@aa.grid_dec.transform
193+
def _deflections_2d_via_mge_from(
194+
self, grid: aa.type.Grid2DLike, xp=np, **kwargs,
195+
):
196+
amps, sigmas = self.decompose_convergence_via_mge(xp=xp)
140197

141-
"""
198+
deflection_angles = (
199+
amps[:, None]
200+
* sigmas[:, None]
201+
* xp.sqrt((2.0 * xp.pi) / (1.0 - self.axis_ratio(xp)**2.0))
202+
* self.zeta_from(grid=grid, xp=xp)
203+
)
142204

143-
self.count = 0
144-
self.sigma_calc = 0
145-
self.z = 0
146-
self.zq = 0
147-
self.expv = 0
205+
# Add Gaussian profiles
206+
deflections = xp.sum(deflection_angles, axis=0)
148207

149-
amps, sigmas = self.decompose_convergence_via_mge()
208+
return self.rotated_grid_from_reference_frame_from(
209+
xp.multiply(
210+
1.0, xp.vstack((-1.0 * xp.imag(deflections), xp.real(deflections))).T
211+
),
212+
xp=xp,
213+
)
150214

151-
convergence = 0.0
215+
def axis_ratio(self, xp=np):
216+
axis_ratio = super().axis_ratio(xp=xp)
217+
return xp.where(axis_ratio < 0.9999, axis_ratio, 0.9999)
152218

153-
for i in range(len(sigmas)):
154-
convergence += self.convergence_func_gaussian(
155-
grid_radii=grid_radii.array, sigma=sigmas[i], intensity=amps[i]
156-
)
157-
return convergence
158219

159-
def convergence_func_gaussian(self, grid_radii, sigma, intensity):
160-
return np.multiply(
161-
intensity, np.exp(-0.5 * np.square(np.divide(grid_radii, sigma)))
162-
)
220+
def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
221+
q = xp.asarray(self.axis_ratio(xp), dtype=xp.float64)
222+
q2 = q * q
163223

164-
def _deflections_2d_via_mge_from(
165-
self, grid, sigmas_factor=1.0, func_terms=None, func_gaussians=None
166-
):
167-
axis_ratio = np.array(self.axis_ratio())
224+
y = xp.asarray(grid.array[:, 0], dtype=xp.float64)
225+
x = xp.asarray(grid.array[:, 1], dtype=xp.float64)
168226

169-
if axis_ratio > 0.9999:
170-
axis_ratio = 0.9999
227+
ind_pos_y = y >= 0
171228

172-
amps, sigmas = self.decompose_convergence_via_mge()
173-
sigmas *= sigmas_factor
229+
sigmas = xp.asarray(self.sigmas, dtype=xp.float64)[:, None] # (S,1)
174230

175-
angle = self.zeta_from(
176-
grid=grid, amps=amps, sigmas=sigmas, axis_ratio=axis_ratio
231+
scale = q / (
232+
sigmas * xp.sqrt(xp.asarray(2.0, dtype=xp.float64) * (1.0 - q2))
177233
)
178234

179-
angle *= np.sqrt((2.0 * np.pi) / (1.0 - axis_ratio**2.0))
235+
xs = x[None, :] * scale
236+
ys = xp.abs(y)[None, :] * scale
180237

181-
return self.rotated_grid_from_reference_frame_from(
182-
np.vstack((-angle.imag, angle.real)).T
238+
z1 = xs + 1j * ys
239+
z2 = q * xs + 1j * ys / q
240+
241+
exp_term = xp.exp(
242+
-(xs * xs) * (1.0 - q2)
243+
- (ys * ys) * (1.0 / q2 - 1.0)
183244
)
245+
246+
core = -1j * (
247+
self.wofz(z1, xp=xp)
248+
- exp_term * self.wofz(z2, xp=xp)
249+
)
250+
251+
return xp.where(ind_pos_y[None, :], core, xp.conj(core))

0 commit comments

Comments
 (0)