@@ -217,85 +217,129 @@ def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
217217 core = - 1j * (wofz (z1 ) - exp_term * wofz (z2 ))
218218
219219 if xp == jnp :
220- import jax .scipy .special as jsp
220+ # import jax.scipy.special as jsp
221221
222- core = - 1j * (xp . exp ( - z1 * z1 ) * jsp . erfc ( - 1j * z1 ) - exp_term * xp . exp ( - z2 * z2 ) * jsp . erfc ( - 1j * z2 ))
222+ core = - 1j * (self . wofz ( z1 , xp = xp ) - exp_term * self . wofz ( z2 , xp = xp ))
223223
224224 # symmetry: zeta(x, -y) = conj(zeta(x, y))
225225 return xp .where (y >= 0 , core , xp .conj (core ))
226226
227227
228- # def wofz(self, z, xp=np):
229- # """
230- # JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z)
231- # Based on the Poppe–Wijers / Zaghloul–Ali rational approximations.
232- # Valid for all complex z. JIT + autodiff safe.
233- # """
234- #
235- # # y = grid.array[:, 0]
236- # # x = grid.array[:, 1]
237- # # z = x + 1j * y
238- #
239- # z = xp.asarray(z, dtype=xp.complex128)
240- # x = xp.real(z)
241- # y = xp.imag(z)
242- #
243- # r2 = x * x + y * y
244- # y2 = y * y
245- # z2 = z * z
246- # sqrt_pi = xp.sqrt(xp.pi)
247- #
248- # # --- Region 1: |z|^2 >= 3.8e4 ---
249- # w1 = 1j / (z * sqrt_pi)
250- #
251- # # --- Region 2: 3.8e4 > |z|^2 >= 256 ---
252- # w2 = 1j * z / (sqrt_pi * (z2 - 0.5))
253- #
254- # # --- Region 3: 256 > |z|^2 >= 62 ---
255- # w3 = 1j * (z2 - 1.0) / (z * sqrt_pi * (z2 - 1.5))
256- #
257- # # --- Region 4: 62 > |z|^2 >= 30 and y^2 >= 1e-13 ---
258- # w4 = 1j * z * (z2 - 2.5) / (sqrt_pi * (z2 * (z2 - 3.0) + 0.75))
259- #
260- # # --- Region 5: special small-imaginary case ---
261- # U5 = xp.array([1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31], dtype=xp.float64)
262- # V5 = xp.array([1.841439, 61.57037, 364.2191, 2186.181,
263- # 9022.228, 24322.84, 32066.6], dtype=xp.float64)
264- #
265- # # Horner form in z^2
266- # num5 = sqrt_pi
267- # for k in range(0, 6):
268- # num5 = num5 * z2 + U5[k]
269- #
270- # den5 = 1.0
271- # for k in range(0, 7):
272- # den5 = den5 * z2 + V5[k]
273- #
274- # w5 = xp.exp(-z2) + 1j * z * num5 / den5
275- #
276- # # --- Region 6: remaining small-|z| region ---
277- # U6 = xp.array([5.9126262, 30.180142, 93.15558,
278- # 181.92853, 214.38239, 122.60793], dtype=xp.float64)
279- # V6 = xp.array([10.479857, 53.992907, 170.35400,
280- # 348.70392, 457.33448, 352.73063, 122.60793], dtype=xp.float64)
281- #
282- # num6 = sqrt_pi
283- # for k in range(0, 6):
284- # num6 = num6 * (-1j * z) + U6[k]
285- #
286- # den6 = 1
287- # for k in range(1, 7):
288- # den6 = den6 * (-1j * z) + V6[k]
289- #
290- # w6 = num6 / den6
291- #
292- # # --- Combine regions using pure array logic ---
293- # w = w6
294- # w = xp.where((r2 >= 2.5) & (y2 < 0.072) & (r2 < 30), w5, w)
295- # w = xp.where((r2 >= 30) & (r2 < 62) & (y2 < 1e-13), w5, w)
296- # w = xp.where((r2 >= 30) & (r2 < 62) & (y2 >= 1e-13), w4, w)
297- # w = xp.where((r2 >= 62) & (r2 < 256), w3, w)
298- # w = xp.where((r2 >= 256) & (r2 < 3.8e4), w2, w)
299- # w = xp.where(r2 >= 3.8e4, w1, w)
300- #
301- # return w
228+ def wofz (self , z , xp = np ):
229+ """
230+ JAX-compatible Faddeeva function w(z) = exp(-z^2) * erfc(-i z)
231+ Based on the Poppe–Wijers / Zaghloul–Ali rational approximations.
232+ Valid for all complex z. JIT + autodiff safe.
233+ """
234+
235+ # y = grid.array[:, 0]
236+ # x = grid.array[:, 1]
237+ # z = x + 1j * y
238+
239+ z = xp .complex128 (xp .asarray (z ))
240+ x = xp .real (z )
241+ y = xp .imag (z )
242+
243+ r2 = x * x + y * y
244+ y2 = y * y
245+ z2 = z * z
246+ sqrt_pi = xp .sqrt (xp .pi )
247+
248+ # --- Region 1: |z|^2 >= 3.8e4 ---
249+ w1 = 1j / (z * sqrt_pi )
250+
251+ # --- Region 2: 3.8e4 > |z|^2 >= 256 ---
252+ w2 = 1j * z / (sqrt_pi * (z2 - 0.5 ))
253+
254+ # --- Region 3: 256 > |z|^2 >= 62 ---
255+ w3 = 1j * (z2 - 1.0 ) / (z * sqrt_pi * (z2 - 1.5 ))
256+
257+ # --- Region 4: 62 > |z|^2 >= 30 and y^2 >= 1e-13 ---
258+ w4 = 1j * z * (z2 - 2.5 ) / (sqrt_pi * (z2 * (z2 - 3.0 ) + 0.75 ))
259+
260+ # --- Region 5: special small-imaginary case ---
261+ U5 = xp .float64 (xp .array ([1.320522 , 35.7668 , 219.031 , 1540.787 , 3321.990 , 36183.31 ]))
262+ V5 = xp .float64 (xp .array ([1.841439 , 61.57037 , 364.2191 , 2186.181 ,
263+ 9022.228 , 24322.84 , 32066.6 ]))
264+
265+ t = sqrt_pi
266+ t = U5 [0 ] + z2 * t
267+ t = U5 [1 ] + z2 * t
268+ t = U5 [2 ] + z2 * t
269+ t = U5 [3 ] + z2 * t
270+ t = U5 [4 ] + z2 * t
271+ num5 = U5 [5 ] + z2 * t
272+
273+ s = 1.0
274+ s = V5 [0 ] + z2 * s
275+ s = V5 [1 ] + z2 * s
276+ s = V5 [2 ] + z2 * s
277+ s = V5 [3 ] + z2 * s
278+ s = V5 [4 ] + z2 * s
279+ s = V5 [5 ] + z2 * s
280+ den5 = V5 [6 ] + z2 * s
281+
282+ #num5 = (U5[5] + z2 * (U5[4] + z2 * (U5[3] + z2 * (U5[2] + z2 * (U5[1] + z2 * (U5[0] + z2 * sqrt_pi))))))
283+
284+ #den5 = (V5[6] + z2 * (V5[5] + z2 * (V5[4] + z2 * (V5[3] + z2 * (V5[2] + z2 * (V5[1] + z2 * (V5[0] + z2)))))))
285+
286+ # Horner form in z^2
287+ # num5 = sqrt_pi
288+ # for k in range(0, 6):
289+ # num5 = num5 * z2 + U5[k]
290+ #
291+ # den5 = 1.0
292+ # for k in range(0, 7):
293+ # den5 = den5 * z2 + V5[k]
294+
295+ w5 = xp .exp (- z2 ) + 1j * z * num5 / den5
296+
297+ # --- Region 6: remaining small-|z| region ---
298+ U6 = xp .float64 (xp .array ([5.9126262 , 30.180142 , 93.15558 ,
299+ 181.92853 , 214.38239 , 122.60793 ]))
300+ V6 = xp .float64 (xp .array ([10.479857 , 53.992907 , 170.35400 ,
301+ 348.70392 , 457.33448 , 352.73063 , 122.60793 ]))
302+
303+ t = sqrt_pi
304+ t = U6 [0 ] - 1j * z * t
305+ t = U6 [1 ] - 1j * z * t
306+ t = U6 [2 ] - 1j * z * t
307+ t = U6 [3 ] - 1j * z * t
308+ t = U6 [4 ] - 1j * z * t
309+ num6 = U6 [5 ] - 1j * z * t
310+
311+ s = 1.0
312+ s = V6 [0 ] - 1j * z * s
313+ s = V6 [1 ] - 1j * z * s
314+ s = V6 [2 ] - 1j * z * s
315+ s = V6 [3 ] - 1j * z * s
316+ s = V6 [4 ] - 1j * z * s
317+ s = V6 [5 ] - 1j * z * s
318+ den6 = V6 [6 ] - 1j * z * s
319+
320+ # num6 = (U6[5] - 1j * z * (U6[4] - 1j * z * (U6[3] - 1j * z * (U6[2] - 1j * z * (U6[1] - 1j * z *
321+ # (U6[0] - 1j * z * sqrt_pi))))))
322+ #
323+ # den6 = (V6[6] - 1j * z * (V6[5] - 1j * z * (V6[4] - 1j * z * (V6[3] - 1j * z * (V6[2] - 1j * z * (V6[1] - 1j * z *
324+ # (V6[0] - 1j * z)))))))
325+
326+ #num6 = sqrt_pi
327+ # for k in range(0, 6):
328+ # num6 = num6 * (-1j * z) + U6[k]
329+ #
330+ # den6 = 1
331+ # for k in range(1, 7):
332+ # den6 = den6 * (-1j * z) + V6[k]
333+
334+ w6 = num6 / den6
335+
336+ # --- Combine regions using pure array logic ---
337+ w = w6
338+ w = xp .where ((r2 >= 2.5 ) & (y2 < 0.072 ) & (r2 < 30 ), w5 , w )
339+ w = xp .where ((r2 >= 30 ) & (r2 < 62 ) & (y2 < 1e-13 ), w5 , w )
340+ w = xp .where ((r2 >= 30 ) & (r2 < 62 ) & (y2 >= 1e-13 ), w4 , w )
341+ w = xp .where ((r2 >= 62 ) & (r2 < 256 ), w3 , w )
342+ w = xp .where ((r2 >= 256 ) & (r2 < 3.8e4 ), w2 , w )
343+ w = xp .where (r2 >= 3.8e4 , w1 , w )
344+
345+ return w
0 commit comments