Skip to content

Commit 85f29c9

Browse files
author
NiekWielders
committed
no circular imports
1 parent cd5c119 commit 85f29c9

1 file changed

Lines changed: 87 additions & 3 deletions

File tree

  • autogalaxy/profiles/mass/abstract

autogalaxy/profiles/mass/abstract/mge.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,92 @@
11
# from .jax_utils import w_f_approx
22
import numpy as np
33

4-
from autogalaxy.profiles.mass.stellar.gaussian import Gaussian
54

5+
def wofz(z, xp=np):
6+
"""
7+
JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z)
8+
Based on the Poppe–Wijers / Zaghloul–Ali rational approximations.
9+
Valid for all complex z. JIT + autodiff safe.
10+
"""
11+
12+
z = xp.asarray(z, dtype=xp.complex128)
13+
x = xp.real(z)
14+
y = xp.imag(z)
15+
16+
r2 = x * x + y * y
17+
y2 = y * y
18+
z2 = z * z
19+
20+
sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64)
21+
inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64)
22+
23+
# ---------- Large-|z| continued fraction ----------
24+
r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64)
25+
26+
t = z
27+
for c in r1_s1:
28+
t = z - c / t
29+
30+
w_large = 1j * inv_sqrt_pi / t
31+
32+
# ---------- Region 5 ----------
33+
U5 = xp.asarray(
34+
[1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31], dtype=xp.float64
35+
)
36+
V5 = xp.asarray(
37+
[1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6],
38+
dtype=xp.float64,
39+
)
40+
41+
t = inv_sqrt_pi
42+
for u in U5:
43+
t = u + z2 * t
44+
45+
s = xp.asarray(1.0, dtype=xp.float64)
46+
for v in V5:
47+
s = v + z2 * s
48+
49+
w5 = xp.exp(-z2) + 1j * z * t / s
50+
51+
# ---------- Region 6 ----------
52+
U6 = xp.asarray(
53+
[5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793],
54+
dtype=xp.float64,
55+
)
56+
V6 = xp.asarray(
57+
[
58+
10.479857,
59+
53.992907,
60+
170.35400,
61+
348.70392,
62+
457.33448,
63+
352.73063,
64+
122.60793,
65+
],
66+
dtype=xp.float64,
67+
)
68+
69+
t = inv_sqrt_pi
70+
for u in U6:
71+
t = u - 1j * z * t
72+
73+
s = xp.asarray(1.0, dtype=xp.float64)
74+
for v in V6:
75+
s = v - 1j * z * s
76+
77+
w6 = t / s
78+
79+
# ---------- Region logic ----------
80+
reg1 = (r2 >= 62.0) | ((r2 >= 30.0) & (r2 < 62.0) & (y2 >= 1e-13))
81+
reg2 = ((r2 >= 30) & (r2 < 62) & (y2 < 1e-13)) | (
82+
(r2 >= 2.5) & (r2 < 30) & (y2 < 0.072)
83+
)
84+
85+
w = w6
86+
w = xp.where(reg2, w5, w)
87+
w = xp.where(reg1, w_large, w)
88+
89+
return w
690

791
class MassProfileMGE:
892
"""
@@ -39,8 +123,8 @@ def zeta_from(grid, amps, sigmas, axis_ratio, xp=np):
39123

40124
# process as one big vectorized calculation
41125
# could try `jax.lax.scan` instead if this is too much memory
42-
w = Gaussian.wofz(inv_sigma_ * z, xp=xp)
43-
wq = Gaussian.wofz(inv_sigma_ * zq, xp=xp)
126+
w = wofz(inv_sigma_ * z, xp=xp)
127+
wq = wofz(inv_sigma_ * zq, xp=xp)
44128
exp_factor = xp.exp(inv_sigma_**2 * expv)
45129

46130
sigma_func_real = w.imag - exp_factor * wq.imag

0 commit comments

Comments
 (0)