11import numpy as np
22
3+ from typing import Tuple , Sequence , Callable
34
4- class MassProfileMGE :
5+ import autoarray as aa
6+ from autogalaxy .profiles .geometry_profiles import EllProfile
7+
8+
9+ class MassProfileMGE (EllProfile ):
510 """
611 This class speeds up deflection angle calculations of certain mass profiles by decompositing them into many
712 Gaussians.
813
914 This follows the method of Shajib 2019 - https://academic.oup.com/mnras/article/488/1/1387/5526256
1015 """
1116
12- def __init__ (self ):
13- self .count = 0
14- self .sigma_calc = 0
15- self .z = 0
16- self .zq = 0
17- self .expv = 0
18-
19- @staticmethod
20- def zeta_from (grid , amps , sigmas , axis_ratio ):
21- """
22- The key part to compute the deflection angle of each Gaussian.
23- Because of my optimization, there are blocks looking weird and indirect. What I'm doing here
24- is trying to avoid big matrix operation to save time.
25- I think there are still spaces we can optimize.
26-
27- It seems when using w_f_approx, it gives some errors if y < 0. So when computing for places
28- where y < 0, we first compute the value at - y, and then change its sign.
29- """
30-
31- output_grid_final = np .zeros (grid .shape [0 ], dtype = "complex128" )
32-
33- q2 = axis_ratio ** 2.0
34-
35- scale_factor = axis_ratio / (sigmas [0 ] * np .sqrt (2.0 * (1.0 - q2 )))
36-
37- xs = np .array ((grid .array [:, 1 ] * scale_factor ).copy ())
38- ys = np .array ((grid .array [:, 0 ] * scale_factor ).copy ())
39-
40- ys_minus = ys < 0.0
41- ys [ys_minus ] *= - 1
42- z = xs + 1j * ys
43- zq = axis_ratio * xs + 1j * ys / axis_ratio
44-
45- expv = - (xs ** 2.0 ) * (1.0 - q2 ) - ys ** 2.0 * (1.0 / q2 - 1.0 )
46-
47- for i in range (len (sigmas )):
48- if i > 0 :
49- z /= sigmas [i ] / sigmas [i - 1 ]
50- zq /= sigmas [i ] / sigmas [i - 1 ]
51- expv /= (sigmas [i ] / sigmas [i - 1 ]) ** 2.0
52-
53- output_grid = - 1j * (w_f_approx (z ) - np .exp (expv ) * w_f_approx (zq ))
54-
55- output_grid [ys_minus ] = np .conj (output_grid [ys_minus ])
56-
57- output_grid_final += (amps [i ] * sigmas [i ]) * output_grid
17+ def __init__ (
18+ self ,
19+ func : Callable ,
20+ sigmas : Sequence [float ],
21+ func_terms : int = 28 ,
22+ centre : Tuple [float , float ] = (0.0 , 0.0 ),
23+ ell_comps : Tuple [float , float ] = (0.0 , 0.0 ),
24+ ):
25+ super ().__init__ (centre = centre , ell_comps = ell_comps )
26+ self .func = func
27+ self .sigmas = sigmas
28+ self .func_terms = func_terms
5829
59- return output_grid_final
6030
6131 @staticmethod
6232 def kesi (p , xp = np ):
@@ -66,6 +36,7 @@ def kesi(p, xp=np):
6636 n_list = xp .arange (0 , 2 * p + 1 , 1 )
6737 return (2.0 * p * xp .log (10 ) / 3.0 + 2.0 * xp .pi * n_list * 1j ) ** (0.5 )
6838
39+
6940 @staticmethod
7041 def eta (p , xp = np ):
7142 """
@@ -85,8 +56,96 @@ def eta(p, xp=np):
8556 return eta_list
8657
8758
59+ @staticmethod
60+ def wofz (z , xp = np ):
61+ """
62+ JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z)
63+ Based on the Poppe–Wijers / Zaghloul–Ali rational approximations.
64+ Valid for all complex z. JIT + autodiff safe.
65+ """
66+
67+ z = xp .asarray (z , dtype = xp .complex128 )
68+ x = xp .real (z )
69+ y = xp .imag (z )
70+
71+ r2 = x * x + y * y
72+ y2 = y * y
73+ z2 = z * z
74+
75+ sqrt_pi = xp .asarray (xp .sqrt (xp .pi ), dtype = xp .float64 )
76+ inv_sqrt_pi = xp .asarray (1.0 / sqrt_pi , dtype = xp .float64 )
77+
78+ # ---------- Large-|z| continued fraction ----------
79+ r1_s1 = xp .asarray ([2.5 , 2.0 , 1.5 , 1.0 , 0.5 ], dtype = xp .float64 )
80+
81+ t = z
82+ for c in r1_s1 :
83+ t = z - c / t
84+
85+ w_large = 1j * inv_sqrt_pi / t
86+
87+ # ---------- Region 5 ----------
88+ U5 = xp .asarray (
89+ [1.320522 , 35.7668 , 219.031 , 1540.787 , 3321.990 , 36183.31 ], dtype = xp .float64
90+ )
91+ V5 = xp .asarray (
92+ [1.841439 , 61.57037 , 364.2191 , 2186.181 , 9022.228 , 24322.84 , 32066.6 ],
93+ dtype = xp .float64 ,
94+ )
95+
96+ t = inv_sqrt_pi
97+ for u in U5 :
98+ t = u + z2 * t
99+
100+ s = xp .asarray (1.0 , dtype = xp .float64 )
101+ for v in V5 :
102+ s = v + z2 * s
103+
104+ w5 = xp .exp (- z2 ) + 1j * z * t / s
105+
106+ # ---------- Region 6 ----------
107+ U6 = xp .asarray (
108+ [5.9126262 , 30.180142 , 93.15558 , 181.92853 , 214.38239 , 122.60793 ],
109+ dtype = xp .float64 ,
110+ )
111+ V6 = xp .asarray (
112+ [
113+ 10.479857 ,
114+ 53.992907 ,
115+ 170.35400 ,
116+ 348.70392 ,
117+ 457.33448 ,
118+ 352.73063 ,
119+ 122.60793 ,
120+ ],
121+ dtype = xp .float64 ,
122+ )
123+
124+ t = inv_sqrt_pi
125+ for u in U6 :
126+ t = u - 1j * z * t
127+
128+ s = xp .asarray (1.0 , dtype = xp .float64 )
129+ for v in V6 :
130+ s = v - 1j * z * s
131+
132+ w6 = t / s
133+
134+ # ---------- Region logic ----------
135+ reg1 = (r2 >= 62.0 ) | ((r2 >= 30.0 ) & (r2 < 62.0 ) & (y2 >= 1e-13 ))
136+ reg2 = ((r2 >= 30 ) & (r2 < 62 ) & (y2 < 1e-13 )) | (
137+ (r2 >= 2.5 ) & (r2 < 30 ) & (y2 < 0.072 )
138+ )
139+
140+ w = w6
141+ w = xp .where (reg2 , w5 , w )
142+ w = xp .where (reg1 , w_large , w )
143+
144+ return w
145+
146+
88147 def decompose_convergence_via_mge (
89- self , func , radii_min , radii_max , func_terms = 28 , func_gaussians = 20 , xp = np
148+ self , xp = np
90149 ):
91150 """
92151
@@ -104,17 +163,18 @@ def decompose_convergence_via_mge(
104163 Returns
105164 -------
106165 """
107- kesis = self .kesi (func_terms , xp = xp ) # kesi in Eq.(6) of 1906.08263
108- etas = self .eta (func_terms , xp = xp ) # eta in Eqr.(6) of 1906.08263
166+ kesis = self .kesi (self . func_terms , xp = xp ) # kesi in Eq.(6) of 1906.08263
167+ etas = self .eta (self . func_terms , xp = xp ) # eta in Eqr.(6) of 1906.08263
109168
110- # sigma is sampled from logspace between these radii.
169+ sigmas = xp . array ( self . sigmas )
111170
112- log_sigmas = xp .linspace (xp .log (radii_min ), xp .log (radii_max ), func_gaussians )
171+ #log_sigmas = xp.linspace(xp.log(radii_min), xp.log(radii_max), func_gaussians)
172+ log_sigmas = xp .log (sigmas )
113173 d_log_sigma = log_sigmas [1 ] - log_sigmas [0 ]
114- sigma_list = xp .exp (log_sigmas )
174+ # sigma_list = xp.exp(log_sigmas)
115175
116176 f_sigma = xp .sum (
117- etas * xp .real (func (sigma_list .reshape (- 1 , 1 ) * kesis )), axis = 1
177+ etas * xp .real (self . func (sigmas .reshape (- 1 , 1 ) * kesis )), axis = 1
118178 )
119179
120180 amplitude_list = f_sigma * d_log_sigma / xp .sqrt (2.0 * xp .pi )
@@ -125,59 +185,67 @@ def decompose_convergence_via_mge(
125185 amplitude_list = amplitude_list .at [0 ].multiply (0.5 )
126186 amplitude_list = amplitude_list .at [- 1 ].multiply (0.5 )
127187
128- return amplitude_list , sigma_list
188+ return amplitude_list , sigmas
129189
130- def convergence_2d_via_mge_from (self , grid_radii ):
131- raise NotImplementedError ()
132190
133- def _convergence_2d_via_mge_from (self , grid_radii , ** kwargs ):
134- """Calculate the projected convergence at a given set of arc-second gridded coordinates.
135-
136- Parameters
137- ----------
138- grid
139- The grid of (y,x) arc-second coordinates the convergence is computed on.
191+ @aa .grid_dec .to_vector_yx
192+ @aa .grid_dec .transform
193+ def _deflections_2d_via_mge_from (
194+ self , grid : aa .type .Grid2DLike , xp = np , ** kwargs ,
195+ ):
196+ amps , sigmas = self .decompose_convergence_via_mge (xp = xp )
140197
141- """
198+ deflection_angles = (
199+ amps [:, None ]
200+ * sigmas [:, None ]
201+ * xp .sqrt ((2.0 * xp .pi ) / (1.0 - self .axis_ratio (xp )** 2.0 ))
202+ * self .zeta_from (grid = grid , xp = xp )
203+ )
142204
143- self .count = 0
144- self .sigma_calc = 0
145- self .z = 0
146- self .zq = 0
147- self .expv = 0
205+ # Add Gaussian profiles
206+ deflections = xp .sum (deflection_angles , axis = 0 )
148207
149- amps , sigmas = self .decompose_convergence_via_mge ()
208+ return self .rotated_grid_from_reference_frame_from (
209+ xp .multiply (
210+ 1.0 , xp .vstack ((- 1.0 * xp .imag (deflections ), xp .real (deflections ))).T
211+ ),
212+ xp = xp ,
213+ )
150214
151- convergence = 0.0
215+ def axis_ratio (self , xp = np ):
216+ axis_ratio = super ().axis_ratio (xp = xp )
217+ return xp .where (axis_ratio < 0.9999 , axis_ratio , 0.9999 )
152218
153- for i in range (len (sigmas )):
154- convergence += self .convergence_func_gaussian (
155- grid_radii = grid_radii .array , sigma = sigmas [i ], intensity = amps [i ]
156- )
157- return convergence
158219
159- def convergence_func_gaussian (self , grid_radii , sigma , intensity ):
160- return np .multiply (
161- intensity , np .exp (- 0.5 * np .square (np .divide (grid_radii , sigma )))
162- )
220+ def zeta_from (self , grid : aa .type .Grid2DLike , xp = np ):
221+ q = xp .asarray (self .axis_ratio (xp ), dtype = xp .float64 )
222+ q2 = q * q
163223
164- def _deflections_2d_via_mge_from (
165- self , grid , sigmas_factor = 1.0 , func_terms = None , func_gaussians = None
166- ):
167- axis_ratio = np .array (self .axis_ratio ())
224+ y = xp .asarray (grid .array [:, 0 ], dtype = xp .float64 )
225+ x = xp .asarray (grid .array [:, 1 ], dtype = xp .float64 )
168226
169- if axis_ratio > 0.9999 :
170- axis_ratio = 0.9999
227+ ind_pos_y = y >= 0
171228
172- amps , sigmas = self .decompose_convergence_via_mge ()
173- sigmas *= sigmas_factor
229+ sigmas = xp .asarray (self .sigmas , dtype = xp .float64 )[:, None ] # (S,1)
174230
175- angle = self . zeta_from (
176- grid = grid , amps = amps , sigmas = sigmas , axis_ratio = axis_ratio
231+ scale = q / (
232+ sigmas * xp . sqrt ( xp . asarray ( 2.0 , dtype = xp . float64 ) * ( 1.0 - q2 ))
177233 )
178234
179- angle *= np .sqrt ((2.0 * np .pi ) / (1.0 - axis_ratio ** 2.0 ))
235+ xs = x [None , :] * scale
236+ ys = xp .abs (y )[None , :] * scale
180237
181- return self .rotated_grid_from_reference_frame_from (
182- np .vstack ((- angle .imag , angle .real )).T
238+ z1 = xs + 1j * ys
239+ z2 = q * xs + 1j * ys / q
240+
241+ exp_term = xp .exp (
242+ - (xs * xs ) * (1.0 - q2 )
243+ - (ys * ys ) * (1.0 / q2 - 1.0 )
183244 )
245+
246+ core = - 1j * (
247+ self .wofz (z1 , xp = xp )
248+ - exp_term * self .wofz (z2 , xp = xp )
249+ )
250+
251+ return xp .where (ind_pos_y [None , :], core , xp .conj (core ))
0 commit comments