Skip to content

Commit 32a45cf

Browse files
authored
Merge pull request #263 from Jammy2211/feature/multipole_jax_fix
Feature/multipole jax fix
2 parents df6c6b3 + adb9c45 commit 32a45cf

1 file changed

Lines changed: 42 additions & 11 deletions

File tree

autogalaxy/profiles/mass/total/power_law_multipole.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ def __init__(
113113
grid=grid
114114
)
115115
"""
116-
from astropy import units
117-
118116
super().__init__(centre=centre, ell_comps=(0.0, 0.0))
119117

120118
self.m = int(m)
@@ -123,10 +121,40 @@ def __init__(
123121
self.slope = slope
124122

125123
self.multipole_comps = multipole_comps
126-
self.k_m, self.angle_m = convert.multipole_k_m_and_phi_m_from(
127-
multipole_comps=multipole_comps, m=m
124+
125+
def k_m_and_angle_m_from(self, xp=np) -> Tuple[float, float]:
126+
"""
127+
Return the multipole normalization ``k_m`` and orientation angle ``angle_m``.
128+
129+
The multipole normalization and angle are computed from the multipole component
130+
parameters ``(epsilon_1, epsilon_2)`` using
131+
:func:`convert.multipole_k_m_and_phi_m_from`. The returned angle is converted
132+
from degrees to radians.
133+
134+
The numerical backend can be selected via the ``xp`` argument, allowing this
135+
method to be used with both NumPy and JAX (e.g. inside ``jax.jit``-compiled
136+
code).
137+
138+
Parameters
139+
----------
140+
xp
141+
Numerical backend module, typically ``numpy`` or ``jax.numpy``.
142+
143+
Returns
144+
-------
145+
k_m
146+
The multipole normalization.
147+
angle_m
148+
The multipole orientation angle in radians.
149+
"""
150+
from astropy import units
151+
152+
k_m, angle_m = convert.multipole_k_m_and_phi_m_from(
153+
multipole_comps=self.multipole_comps, m=self.m, xp=xp
128154
)
129-
self.angle_m *= units.deg.to(units.rad)
155+
angle_m *= units.deg.to(units.rad)
156+
157+
return k_m, angle_m
130158

131159
def get_shape_angle(
132160
self,
@@ -198,15 +226,17 @@ def deflections_yx_2d_from(
198226
"""
199227
radial_grid, polar_angle_grid = radial_and_angle_grid_from(grid=grid, xp=xp)
200228

229+
k_m, angle_m = self.k_m_and_angle_m_from(xp=xp)
230+
201231
a_r = (
202232
-(
203233
(3.0 - self.slope)
204234
* self.einstein_radius ** (self.slope - 1.0)
205235
* radial_grid ** (2.0 - self.slope)
206236
)
207237
/ (self.m**2.0 - (3.0 - self.slope) ** 2.0)
208-
* self.k_m
209-
* xp.cos(self.m * (polar_angle_grid - self.angle_m))
238+
* k_m
239+
* xp.cos(self.m * (polar_angle_grid - angle_m))
210240
)
211241

212242
a_angle = (
@@ -216,8 +246,8 @@ def deflections_yx_2d_from(
216246
* radial_grid ** (2.0 - self.slope)
217247
)
218248
/ (self.m**2.0 - (3.0 - self.slope) ** 2.0)
219-
* self.k_m
220-
* xp.sin(self.m * (polar_angle_grid - self.angle_m))
249+
* k_m
250+
* xp.sin(self.m * (polar_angle_grid - angle_m))
221251
)
222252

223253
return xp.stack(
@@ -242,13 +272,14 @@ def convergence_2d_from(
242272
The grid of (y,x) arc-second coordinates the convergence is computed on.
243273
"""
244274
r, angle = radial_and_angle_grid_from(grid=grid, xp=xp)
275+
k_m, angle_m = self.k_m_and_angle_m_from(xp=xp)
245276

246277
return (
247278
1.0
248279
/ 2.0
249280
* (self.einstein_radius / r) ** (self.slope - 1)
250-
* self.k_m
251-
* xp.cos(self.m * (angle - self.angle_m))
281+
* k_m
282+
* xp.cos(self.m * (angle - angle_m))
252283
)
253284

254285
@aa.grid_dec.to_array

0 commit comments

Comments
 (0)