11import numpy as np
22import jax .numpy as jnp
3- import jax . scipy . special as jsp
3+
44
55from typing import Tuple
66
@@ -195,9 +195,6 @@ def axis_ratio(self, xp=np):
195195 return xp .where (axis_ratio < 0.9999 , axis_ratio , 0.9999 )
196196
197197 def zeta_from (self , grid : aa .type .Grid2DLike , xp = np ):
198-
199- from scipy .special import wofz
200-
201198 q = self .axis_ratio (xp )
202199 q2 = q ** 2.0
203200
@@ -214,65 +211,19 @@ def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
214211
215212 exp_term = xp .exp (- (xs ** 2 ) * (1.0 - q2 ) - ys ** 2 * (1.0 / q2 - 1.0 ))
216213
217- core = - 1j * (wofz (z1 ) - exp_term * wofz (z2 ))
214+ if xp == np :
215+ from scipy .special import wofz
216+
217+ core = - 1j * (wofz (z1 ) - exp_term * wofz (z2 ))
218+
219+ if xp == jnp :
220+ import jax .scipy .special as jsp
221+
222+ core = - 1j * (xp .exp (- z1 * z1 ) * jsp .erfc (- 1j * z1 ) - exp_term * xp .exp (- z2 * z2 ) * jsp .erfc (- 1j * z2 ))
218223
219224 # symmetry: zeta(x, -y) = conj(zeta(x, y))
220225 return xp .where (y >= 0 , core , xp .conj (core ))
221226
222- # q = self.axis_ratio(xp)
223- # q2 = q ** 2.0
224- # ind_pos_y = grid.array[:, 0] >= 0
225- # shape_grid = xp.shape(grid)
226- # output_grid = xp.zeros((shape_grid[0]), dtype=xp.complex128)
227- # scale_factor = q / (self.sigma * xp.sqrt(2.0 * (1.0 - q2)))
228- #
229- # xs_0 = grid.array[:, 1][ind_pos_y] * scale_factor
230- # ys_0 = grid.array[:, 0][ind_pos_y] * scale_factor
231- # xs_1 = grid.array[:, 1][~ind_pos_y] * scale_factor
232- # ys_1 = -grid.array[:, 0][~ind_pos_y] * scale_factor
233- #
234- # if xp == np:
235- # from scipy.special import wofz
236- #
237- # output_grid[ind_pos_y] = -1j * (
238- # wofz(xs_0 + 1j * ys_0)
239- # - np.exp(-(xs_0 ** 2.0) * (1.0 - q2) - ys_0 * ys_0 * (1.0 / q2 - 1.0))
240- # * wofz(q * xs_0 + 1j * ys_0 / q)
241- # )
242- #
243- # output_grid[~ind_pos_y] = np.conj(
244- # -1j
245- # * (
246- # wofz(xs_1 + 1j * ys_1)
247- # - np.exp(-(xs_1 ** 2.0) * (1.0 - q2) - ys_1 * ys_1 * (1.0 / q2 - 1.0))
248- # * wofz(q * xs_1 + 1j * ys_1 / q)
249- # )
250- # )
251- #
252- # if xp == jnp:
253- # z1 = xs_0 + 1j * ys_0
254- # z2 = q * xs_0 + 1j * ys_0 / q
255- # output_grid[ind_pos_y] = -1j * (
256- # xp.exp(- z1 * z1) * jsp.erfc(- 1j * z1)
257- # - xp.exp(-(xs_0**2.0) * (1.0 - q2) - ys_0 * ys_0 * (1.0 / q2 - 1.0))
258- # * xp.exp(- z2 * z2) * jsp.erfc(- 1j * z2)
259- # )
260- #
261- # z1 = xs_1 + 1j * ys_1
262- # z2 = q * xs_1 + 1j * ys_1 / q
263- # output_grid[~ind_pos_y] = xp.conj(
264- # -1j
265- # * (
266- # xp.exp(- z1 * z1) * jsp.erfc(- 1j * z1)
267- # - xp.exp(-(xs_1**2.0) * (1.0 - q2) - ys_1 * ys_1 * (1.0 / q2 - 1.0))
268- # * xp.exp(- z2 * z2) * jsp.erfc(- 1j * z2)
269- # )
270- # )
271- #
272- # return output_grid
273-
274- # def wofz(self, z, xp=np):
275- # return xp.exp(- z * z) * jsp.erfc(- 1j * z)
276227
277228 # def wofz(self, z, xp=np):
278229 # """
0 commit comments