|
1 | 1 | # from .jax_utils import w_f_approx |
2 | 2 | import numpy as np |
3 | 3 |
|
4 | | -from autogalaxy.profiles.mass.stellar.gaussian import Gaussian |
5 | 4 |
|
| 5 | +def wofz(z, xp=np): |
| 6 | + """ |
| 7 | + JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z) |
| 8 | + Based on the Poppe–Wijers / Zaghloul–Ali rational approximations. |
| 9 | + Valid for all complex z. JIT + autodiff safe. |
| 10 | + """ |
| 11 | + |
| 12 | + z = xp.asarray(z, dtype=xp.complex128) |
| 13 | + x = xp.real(z) |
| 14 | + y = xp.imag(z) |
| 15 | + |
| 16 | + r2 = x * x + y * y |
| 17 | + y2 = y * y |
| 18 | + z2 = z * z |
| 19 | + |
| 20 | + sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64) |
| 21 | + inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64) |
| 22 | + |
| 23 | + # ---------- Large-|z| continued fraction ---------- |
| 24 | + r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64) |
| 25 | + |
| 26 | + t = z |
| 27 | + for c in r1_s1: |
| 28 | + t = z - c / t |
| 29 | + |
| 30 | + w_large = 1j * inv_sqrt_pi / t |
| 31 | + |
| 32 | + # ---------- Region 5 ---------- |
| 33 | + U5 = xp.asarray( |
| 34 | + [1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31], dtype=xp.float64 |
| 35 | + ) |
| 36 | + V5 = xp.asarray( |
| 37 | + [1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6], |
| 38 | + dtype=xp.float64, |
| 39 | + ) |
| 40 | + |
| 41 | + t = inv_sqrt_pi |
| 42 | + for u in U5: |
| 43 | + t = u + z2 * t |
| 44 | + |
| 45 | + s = xp.asarray(1.0, dtype=xp.float64) |
| 46 | + for v in V5: |
| 47 | + s = v + z2 * s |
| 48 | + |
| 49 | + w5 = xp.exp(-z2) + 1j * z * t / s |
| 50 | + |
| 51 | + # ---------- Region 6 ---------- |
| 52 | + U6 = xp.asarray( |
| 53 | + [5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793], |
| 54 | + dtype=xp.float64, |
| 55 | + ) |
| 56 | + V6 = xp.asarray( |
| 57 | + [ |
| 58 | + 10.479857, |
| 59 | + 53.992907, |
| 60 | + 170.35400, |
| 61 | + 348.70392, |
| 62 | + 457.33448, |
| 63 | + 352.73063, |
| 64 | + 122.60793, |
| 65 | + ], |
| 66 | + dtype=xp.float64, |
| 67 | + ) |
| 68 | + |
| 69 | + t = inv_sqrt_pi |
| 70 | + for u in U6: |
| 71 | + t = u - 1j * z * t |
| 72 | + |
| 73 | + s = xp.asarray(1.0, dtype=xp.float64) |
| 74 | + for v in V6: |
| 75 | + s = v - 1j * z * s |
| 76 | + |
| 77 | + w6 = t / s |
| 78 | + |
| 79 | + # ---------- Region logic ---------- |
| 80 | + reg1 = (r2 >= 62.0) | ((r2 >= 30.0) & (r2 < 62.0) & (y2 >= 1e-13)) |
| 81 | + reg2 = ((r2 >= 30) & (r2 < 62) & (y2 < 1e-13)) | ( |
| 82 | + (r2 >= 2.5) & (r2 < 30) & (y2 < 0.072) |
| 83 | + ) |
| 84 | + |
| 85 | + w = w6 |
| 86 | + w = xp.where(reg2, w5, w) |
| 87 | + w = xp.where(reg1, w_large, w) |
| 88 | + |
| 89 | + return w |
6 | 90 |
|
7 | 91 | class MassProfileMGE: |
8 | 92 | """ |
@@ -39,8 +123,8 @@ def zeta_from(grid, amps, sigmas, axis_ratio, xp=np): |
39 | 123 |
|
40 | 124 | # process as one big vectorized calculation |
41 | 125 | # could try `jax.lax.scan` instead if this is too much memory |
42 | | - w = Gaussian.wofz(inv_sigma_ * z, xp=xp) |
43 | | - wq = Gaussian.wofz(inv_sigma_ * zq, xp=xp) |
| 126 | + w = wofz(inv_sigma_ * z, xp=xp) |
| 127 | + wq = wofz(inv_sigma_ * zq, xp=xp) |
44 | 128 | exp_factor = xp.exp(inv_sigma_**2 * expv) |
45 | 129 |
|
46 | 130 | sigma_func_real = w.imag - exp_factor * wq.imag |
|
0 commit comments