Skip to content

Commit 2a099c0

Browse files
author
Niek Wielders
committed
added if statement jnp
1 parent d8ba32c commit 2a099c0

2 files changed

Lines changed: 46 additions & 59 deletions

File tree

autogalaxy/profiles/mass/stellar/gaussian.py

Lines changed: 10 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import jax.numpy as jnp
3-
import jax.scipy.special as jsp
3+
44

55
from typing import Tuple
66

@@ -195,9 +195,6 @@ def axis_ratio(self, xp=np):
195195
return xp.where(axis_ratio < 0.9999, axis_ratio, 0.9999)
196196

197197
def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
198-
199-
from scipy.special import wofz
200-
201198
q = self.axis_ratio(xp)
202199
q2 = q ** 2.0
203200

@@ -214,65 +211,19 @@ def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
214211

215212
exp_term = xp.exp(-(xs ** 2) * (1.0 - q2) - ys ** 2 * (1.0 / q2 - 1.0))
216213

217-
core = -1j * (wofz(z1) - exp_term * wofz(z2))
214+
if xp == np:
215+
from scipy.special import wofz
216+
217+
core = -1j * (wofz(z1) - exp_term * wofz(z2))
218+
219+
if xp == jnp:
220+
import jax.scipy.special as jsp
221+
222+
core = -1j * (xp.exp(- z1 * z1) * jsp.erfc(- 1j * z1) - exp_term * xp.exp(- z2 * z2) * jsp.erfc(- 1j * z2))
218223

219224
# symmetry: zeta(x, -y) = conj(zeta(x, y))
220225
return xp.where(y >= 0, core, xp.conj(core))
221226

222-
# q = self.axis_ratio(xp)
223-
# q2 = q ** 2.0
224-
# ind_pos_y = grid.array[:, 0] >= 0
225-
# shape_grid = xp.shape(grid)
226-
# output_grid = xp.zeros((shape_grid[0]), dtype=xp.complex128)
227-
# scale_factor = q / (self.sigma * xp.sqrt(2.0 * (1.0 - q2)))
228-
#
229-
# xs_0 = grid.array[:, 1][ind_pos_y] * scale_factor
230-
# ys_0 = grid.array[:, 0][ind_pos_y] * scale_factor
231-
# xs_1 = grid.array[:, 1][~ind_pos_y] * scale_factor
232-
# ys_1 = -grid.array[:, 0][~ind_pos_y] * scale_factor
233-
#
234-
# if xp == np:
235-
# from scipy.special import wofz
236-
#
237-
# output_grid[ind_pos_y] = -1j * (
238-
# wofz(xs_0 + 1j * ys_0)
239-
# - np.exp(-(xs_0 ** 2.0) * (1.0 - q2) - ys_0 * ys_0 * (1.0 / q2 - 1.0))
240-
# * wofz(q * xs_0 + 1j * ys_0 / q)
241-
# )
242-
#
243-
# output_grid[~ind_pos_y] = np.conj(
244-
# -1j
245-
# * (
246-
# wofz(xs_1 + 1j * ys_1)
247-
# - np.exp(-(xs_1 ** 2.0) * (1.0 - q2) - ys_1 * ys_1 * (1.0 / q2 - 1.0))
248-
# * wofz(q * xs_1 + 1j * ys_1 / q)
249-
# )
250-
# )
251-
#
252-
# if xp == jnp:
253-
# z1 = xs_0 + 1j * ys_0
254-
# z2 = q * xs_0 + 1j * ys_0 / q
255-
# output_grid[ind_pos_y] = -1j * (
256-
# xp.exp(- z1 * z1) * jsp.erfc(- 1j * z1)
257-
# - xp.exp(-(xs_0**2.0) * (1.0 - q2) - ys_0 * ys_0 * (1.0 / q2 - 1.0))
258-
# * xp.exp(- z2 * z2) * jsp.erfc(- 1j * z2)
259-
# )
260-
#
261-
# z1 = xs_1 + 1j * ys_1
262-
# z2 = q * xs_1 + 1j * ys_1 / q
263-
# output_grid[~ind_pos_y] = xp.conj(
264-
# -1j
265-
# * (
266-
# xp.exp(- z1 * z1) * jsp.erfc(- 1j * z1)
267-
# - xp.exp(-(xs_1**2.0) * (1.0 - q2) - ys_1 * ys_1 * (1.0 / q2 - 1.0))
268-
# * xp.exp(- z2 * z2) * jsp.erfc(- 1j * z2)
269-
# )
270-
# )
271-
#
272-
# return output_grid
273-
274-
# def wofz(self, z, xp=np):
275-
# return xp.exp(- z * z) * jsp.erfc(- 1j * z)
276227

277228
# def wofz(self, z, xp=np):
278229
# """
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import jax.numpy as jnp
2+
3+
import autogalaxy as ag
4+
5+
grid = ag.Grid2DIrregular([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [2.0, 4.0]])
6+
7+
mp = ag.mp.Gaussian(
8+
centre=(0.0, 0.0),
9+
ell_comps=(0.0, 0.05263),
10+
intensity=1.0,
11+
sigma=3.0,
12+
mass_to_light_ratio=1.0,
13+
)
14+
15+
deflections = mp.deflections_2d_via_analytic_from(
16+
grid=ag.Grid2DIrregular([[1.0, 0.0]]),
17+
xp=jnp
18+
)
19+
20+
print(deflections[0, 0])
21+
print(deflections[0, 1])
22+
23+
mp = ag.mp.Gaussian(
24+
centre=(0.0, 0.0),
25+
ell_comps=(0.0, 0.111111),
26+
intensity=1.0,
27+
sigma=5.0,
28+
mass_to_light_ratio=1.0,
29+
)
30+
31+
deflections = mp.deflections_2d_via_analytic_from(
32+
grid=ag.Grid2DIrregular([[0.5, 0.2]])
33+
)
34+
35+
print(deflections[0, 0])
36+
print(deflections[0, 1])

0 commit comments

Comments
 (0)