@@ -113,8 +113,6 @@ def __init__(
113113 grid=grid
114114 )
115115 """
116- from astropy import units
117-
118116 super ().__init__ (centre = centre , ell_comps = (0.0 , 0.0 ))
119117
120118 self .m = int (m )
@@ -123,10 +121,40 @@ def __init__(
123121 self .slope = slope
124122
125123 self .multipole_comps = multipole_comps
126- self .k_m , self .angle_m = convert .multipole_k_m_and_phi_m_from (
127- multipole_comps = multipole_comps , m = m
124+
125+ def k_m_and_angle_m_from (self , xp = np ) -> Tuple [float , float ]:
126+ """
127+ Return the multipole normalization ``k_m`` and orientation angle ``angle_m``.
128+
129+ The multipole normalization and angle are computed from the multipole component
130+ parameters ``(epsilon_1, epsilon_2)`` using
131+ :func:`convert.multipole_k_m_and_phi_m_from`. The returned angle is converted
132+ from degrees to radians.
133+
134+ The numerical backend can be selected via the ``xp`` argument, allowing this
135+ method to be used with both NumPy and JAX (e.g. inside ``jax.jit``-compiled
136+ code).
137+
138+ Parameters
139+ ----------
140+ xp
141+ Numerical backend module, typically ``numpy`` or ``jax.numpy``.
142+
143+ Returns
144+ -------
145+ k_m
146+ The multipole normalization.
147+ angle_m
148+ The multipole orientation angle in radians.
149+ """
150+ from astropy import units
151+
152+ k_m , angle_m = convert .multipole_k_m_and_phi_m_from (
153+ multipole_comps = self .multipole_comps , m = self .m , xp = xp
128154 )
129- self .angle_m *= units .deg .to (units .rad )
155+ angle_m *= units .deg .to (units .rad )
156+
157+ return k_m , angle_m
130158
131159 def get_shape_angle (
132160 self ,
@@ -198,15 +226,17 @@ def deflections_yx_2d_from(
198226 """
199227 radial_grid , polar_angle_grid = radial_and_angle_grid_from (grid = grid , xp = xp )
200228
229+ k_m , angle_m = self .k_m_and_angle_m_from (xp = xp )
230+
201231 a_r = (
202232 - (
203233 (3.0 - self .slope )
204234 * self .einstein_radius ** (self .slope - 1.0 )
205235 * radial_grid ** (2.0 - self .slope )
206236 )
207237 / (self .m ** 2.0 - (3.0 - self .slope ) ** 2.0 )
208- * self . k_m
209- * xp .cos (self .m * (polar_angle_grid - self . angle_m ))
238+ * k_m
239+ * xp .cos (self .m * (polar_angle_grid - angle_m ))
210240 )
211241
212242 a_angle = (
@@ -216,8 +246,8 @@ def deflections_yx_2d_from(
216246 * radial_grid ** (2.0 - self .slope )
217247 )
218248 / (self .m ** 2.0 - (3.0 - self .slope ) ** 2.0 )
219- * self . k_m
220- * xp .sin (self .m * (polar_angle_grid - self . angle_m ))
249+ * k_m
250+ * xp .sin (self .m * (polar_angle_grid - angle_m ))
221251 )
222252
223253 return xp .stack (
@@ -242,13 +272,14 @@ def convergence_2d_from(
242272 The grid of (y,x) arc-second coordinates the convergence is computed on.
243273 """
244274 r , angle = radial_and_angle_grid_from (grid = grid , xp = xp )
275+ k_m , angle_m = self .k_m_and_angle_m_from (xp = xp )
245276
246277 return (
247278 1.0
248279 / 2.0
249280 * (self .einstein_radius / r ) ** (self .slope - 1 )
250- * self . k_m
251- * xp .cos (self .m * (angle - self . angle_m ))
281+ * k_m
282+ * xp .cos (self .m * (angle - angle_m ))
252283 )
253284
254285 @aa .grid_dec .to_array
0 commit comments