Skip to content

Commit cd5c119

Browse files
author
NiekWielders
committed
passed xp everywhere, linked to mge instead mge_numpy
1 parent b906575 commit cd5c119

5 files changed

Lines changed: 35 additions & 29 deletions

File tree

autogalaxy/profiles/mass/abstract/mge_jax.py renamed to autogalaxy/profiles/mass/abstract/mge.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# from .jax_utils import w_f_approx
2+
import numpy as np
3+
4+
from autogalaxy.profiles.mass.stellar.gaussian import Gaussian
25

36

47
class MassProfileMGE:
@@ -13,7 +16,7 @@ def __init__(self):
1316
pass
1417

1518
@staticmethod
16-
def zeta_from(grid, amps, sigmas, axis_ratio):
19+
def zeta_from(grid, amps, sigmas, axis_ratio, xp=np):
1720
"""
1821
The key part to compute the deflection angle of each Gaussian.
1922
"""
@@ -36,8 +39,8 @@ def zeta_from(grid, amps, sigmas, axis_ratio):
3639

3740
# process as one big vectorized calculation
3841
# could try `jax.lax.scan` instead if this is too much memory
39-
w = w_f_approx(inv_sigma_ * z)
40-
wq = w_f_approx(inv_sigma_ * zq)
42+
w = Gaussian.wofz(inv_sigma_ * z, xp=xp)
43+
wq = Gaussian.wofz(inv_sigma_ * zq, xp=xp)
4144
exp_factor = xp.exp(inv_sigma_**2 * expv)
4245

4346
sigma_func_real = w.imag - exp_factor * wq.imag
@@ -47,15 +50,15 @@ def zeta_from(grid, amps, sigmas, axis_ratio):
4750
return output_grid.sum(axis=0)
4851

4952
@staticmethod
50-
def kesi(p):
53+
def kesi(p, xp=np):
5154
"""
5255
see Eq.(6) of 1906.08263
5356
"""
5457
n_list = xp.arange(0, 2 * p + 1, 1)
5558
return (2.0 * p * xp.log(10) / 3.0 + 2.0 * xp.pi * n_list * 1j) ** (0.5)
5659

5760
@staticmethod
58-
def eta(p):
61+
def eta(p, xp=np):
5962
"""
6063
see Eq.(6) of 1906.00263
6164
"""
@@ -76,7 +79,7 @@ def decompose_convergence_via_mge(self):
7679
raise NotImplementedError()
7780

7881
def _decompose_convergence_via_mge(
79-
self, func, radii_min, radii_max, func_terms=28, func_gaussians=20
82+
self, func, radii_min, radii_max, func_terms=28, func_gaussians=20, xp=np
8083
):
8184
"""
8285
@@ -117,7 +120,7 @@ def convergence_2d_via_mge_from(self, grid_radii):
117120
raise NotImplementedError()
118121

119122
def _convergence_2d_via_mge_from(
120-
self, grid_radii, func_terms=28, func_gaussians=20
123+
self, grid_radii, func_terms=28, func_gaussians=20, xp=np
121124
):
122125
"""Calculate the projected convergence at a given set of arc-second gridded coordinates.
123126
@@ -137,7 +140,7 @@ def _convergence_2d_via_mge_from(
137140
return convergence.sum(axis=0)
138141

139142
def _deflections_2d_via_mge_from(
140-
self, grid, sigmas_factor=1.0, func_terms=28, func_gaussians=20
143+
self, grid, sigmas_factor=1.0, func_terms=28, func_gaussians=20, xp=np
141144
):
142145
axis_ratio = xp.min(xp.array([self.axis_ratio(xp), 0.9999]))
143146

@@ -147,7 +150,7 @@ def _deflections_2d_via_mge_from(
147150
sigmas *= sigmas_factor
148151

149152
angle = self.zeta_from(
150-
grid=grid, amps=amps, sigmas=sigmas, axis_ratio=axis_ratio
153+
grid=grid, amps=amps, sigmas=sigmas, axis_ratio=axis_ratio, xp=xp
151154
)
152155

153156
angle *= xp.sqrt((2.0 * xp.pi) / (1.0 - axis_ratio**2.0))

autogalaxy/profiles/mass/dark/abstract.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from autogalaxy.profiles.mass.abstract.abstract import MassProfile
77
from autogalaxy.cosmology.model import LensingCosmology
8-
from autogalaxy.profiles.mass.abstract.mge_numpy import (
8+
from autogalaxy.profiles.mass.abstract.mge import (
99
MassProfileMGE,
1010
)
1111

@@ -85,9 +85,9 @@ def convergence_2d_via_mge_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs)
8585

8686
elliptical_radii = self.elliptical_radii_grid_from(grid=grid, xp=xp, **kwargs)
8787

88-
return self._convergence_2d_via_mge_from(grid_radii=elliptical_radii)
88+
return self._convergence_2d_via_mge_from(grid_radii=elliptical_radii, xp=xp)
8989

90-
def tabulate_integral(self, grid, tabulate_bins, **kwargs):
90+
def tabulate_integral(self, grid, tabulate_bins, xp=np, **kwargs):
9191
"""Tabulate an integral over the convergence of deflection potential of a mass profile. This is used in \
9292
the gNFW profile classes to speed up the integration procedure.
9393
@@ -99,15 +99,15 @@ def tabulate_integral(self, grid, tabulate_bins, **kwargs):
9999
The number of bins to tabulate the inner integral of this profile.
100100
"""
101101
eta_min = 1.0e-4
102-
eta_max = 1.05 * np.max(self.elliptical_radii_grid_from(grid=grid, **kwargs))
102+
eta_max = 1.05 * xp.max(self.elliptical_radii_grid_from(grid=grid, xp=xp, **kwargs))
103103

104-
minimum_log_eta = np.log10(eta_min)
105-
maximum_log_eta = np.log10(eta_max)
104+
minimum_log_eta = xp.log10(eta_min)
105+
maximum_log_eta = xp.log10(eta_max)
106106
bin_size = (maximum_log_eta - minimum_log_eta) / (tabulate_bins - 1)
107107

108108
return eta_min, eta_max, minimum_log_eta, maximum_log_eta, bin_size
109109

110-
def decompose_convergence_via_mge(self, **kwargs):
110+
def decompose_convergence_via_mge(self, xp=np, **kwargs):
111111
rho_at_scale_radius = (
112112
self.kappa_s / self.scale_radius
113113
) # density parameter of 3D gNFW
@@ -124,9 +124,9 @@ def gnfw_3d(r):
124124
)
125125

126126
amplitude_list, sigma_list = self._decompose_convergence_via_mge(
127-
func=gnfw_3d, radii_min=radii_min, radii_max=radii_max
127+
func=gnfw_3d, radii_min=radii_min, radii_max=radii_max, xp=xp
128128
)
129-
amplitude_list *= np.sqrt(2.0 * np.pi) * sigma_list
129+
amplitude_list *= xp.sqrt(2.0 * xp.pi) * sigma_list
130130
return amplitude_list, sigma_list
131131

132132
def coord_func_f(self, grid_radius: np.ndarray, xp=np) -> np.ndarray:

autogalaxy/profiles/mass/stellar/sersic.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import autoarray as aa
66

77
from autogalaxy.profiles.mass.abstract.abstract import MassProfile
8-
from autogalaxy.profiles.mass.abstract.mge_numpy import (
8+
from autogalaxy.profiles.mass.abstract.mge import (
99
MassProfileMGE,
1010
)
1111
from autogalaxy.profiles.mass.abstract.cse import (
@@ -126,7 +126,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
126126
@aa.grid_dec.to_vector_yx
127127
@aa.grid_dec.transform
128128
def deflections_2d_via_mge_from(
129-
self, grid: aa.type.Grid2DLike, func_terms=28, func_gaussians=20, **kwargs
129+
self, grid: aa.type.Grid2DLike, func_terms=28, func_gaussians=20, xp=np, **kwargs
130130
):
131131
"""
132132
Calculate the projected 2D deflection angles from a grid of (y,x) arc second coordinates, by computing and
@@ -146,6 +146,7 @@ def deflections_2d_via_mge_from(
146146
sigmas_factor=np.sqrt(self.axis_ratio()),
147147
func_terms=func_terms,
148148
func_gaussians=func_gaussians,
149+
xp=xp
149150
)
150151

151152
@aa.grid_dec.to_vector_yx
@@ -209,6 +210,7 @@ def convergence_2d_via_mge_from(
209210
grid_radii=eccentric_radii,
210211
func_terms=func_terms,
211212
func_gaussians=func_gaussians,
213+
xp=xp
212214
)
213215

214216
@aa.over_sample
@@ -238,7 +240,7 @@ def convergence_func(self, grid_radius: float) -> float:
238240

239241
@aa.grid_dec.to_array
240242
def potential_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
241-
return np.zeros(shape=grid.shape[0])
243+
return xp.zeros(shape=grid.shape[0])
242244

243245
def image_2d_via_radii_from(self, radius: np.ndarray):
244246
"""
@@ -258,7 +260,7 @@ def image_2d_via_radii_from(self, radius: np.ndarray):
258260
)
259261

260262
def decompose_convergence_via_mge(
261-
self, func_terms=28, func_gaussians=20
263+
self, func_terms=28, func_gaussians=20, xp=np
262264
) -> Tuple[List, List]:
263265
radii_min = self.effective_radius / 100.0
264266
radii_max = self.effective_radius * 20.0
@@ -267,7 +269,7 @@ def sersic_2d(r):
267269
return (
268270
self.mass_to_light_ratio
269271
* self.intensity
270-
* np.exp(
272+
* xp.exp(
271273
-self.sersic_constant
272274
* (((r / self.effective_radius) ** (1.0 / self.sersic_index)) - 1.0)
273275
)
@@ -279,6 +281,7 @@ def sersic_2d(r):
279281
radii_max=radii_max,
280282
func_terms=func_terms,
281283
func_gaussians=func_gaussians,
284+
xp=xp
282285
)
283286

284287
def decompose_convergence_via_cse(

autogalaxy/profiles/mass/stellar/sersic_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def image_2d_via_radii_from(self, grid_radii: np.ndarray, xp=np):
103103
),
104104
)
105105

106-
def decompose_convergence_via_mge(self):
106+
def decompose_convergence_via_mge(self, xp=np):
107107
radii_min = self.effective_radius / 50.0
108108
radii_max = self.effective_radius * 20.0
109109

@@ -113,7 +113,7 @@ def core_sersic_2D(r):
113113
* self.intensity_prime()
114114
* (1.0 + (self.radius_break / r) ** self.alpha)
115115
** (self.gamma / self.alpha)
116-
* np.exp(
116+
* xp.exp(
117117
-self.sersic_constant
118118
* (
119119
(r**self.alpha + self.radius_break**self.alpha)
@@ -124,7 +124,7 @@ def core_sersic_2D(r):
124124
)
125125

126126
return self._decompose_convergence_via_mge(
127-
func=core_sersic_2D, radii_min=radii_min, radii_max=radii_max
127+
func=core_sersic_2D, radii_min=radii_min, radii_max=radii_max, xp=xp
128128
)
129129

130130
def intensity_prime(self, xp=np):

autogalaxy/profiles/mass/stellar/sersic_gradient.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def convergence_func(self, grid_radius: float) -> float:
150150
* self.image_2d_via_radii_from(grid_radius)
151151
)
152152

153-
def decompose_convergence_via_mge(self, **kwargs):
153+
def decompose_convergence_via_mge(self, xp=np, **kwargs):
154154
radii_min = self.effective_radius / 100.0
155155
radii_max = self.effective_radius * 20.0
156156

@@ -162,14 +162,14 @@ def sersic_gradient_2D(r):
162162
((self.axis_ratio() * r) / self.effective_radius)
163163
** -self.mass_to_light_gradient
164164
)
165-
* np.exp(
165+
* xp.exp(
166166
-self.sersic_constant
167167
* (((r / self.effective_radius) ** (1.0 / self.sersic_index)) - 1.0)
168168
)
169169
)
170170

171171
return self._decompose_convergence_via_mge(
172-
func=sersic_gradient_2D, radii_min=radii_min, radii_max=radii_max
172+
func=sersic_gradient_2D, radii_min=radii_min, radii_max=radii_max, xp=xp
173173
)
174174

175175
def decompose_convergence_via_cse(

0 commit comments

Comments
 (0)