22from typing import Optional , Tuple
33
44import autoarray as aa
5- import autolens as al
6-
75
86from autogalaxy .profiles .light .decorators import (
97 check_operated_only ,
108)
119from autogalaxy .profiles .light .standard .shapelets .abstract import AbstractShapelet
1210
1311import jax .numpy as jnp
14- from jax import lax
1512from jax .scipy .special import gammaln
1613
17- def genlaguerre_jax_recurrence (n , alpha , x ):
18- """
19- Generalized (associated) Laguerre polynomial $L_n^{(\a lpha)}(x)$
20- calculated using the three-term recurrence relation in pure JAX.
21-
22- Optimized for JAX via `lax.fori_loop` for loop unrolling.
23-
24- .. math::
25- (k+1) L_{k+1}^{(\a lpha)}(x) = (2k + 1 + \a lpha - x) L_k^{(\a lpha)}(x) - (k + \a lpha) L_{k-1}^{(\a lpha)}(x)
26-
27- Parameters
28- ----------
29- n : int
30- Degree of the polynomial. MUST be a non-negative static Python integer
31- for optimal JAX compilation.
32- alpha : Union[float, JAXArray]
33- Parameter $\a lpha > -1$.
34- x : JAXArray
35- Input array (points at which to evaluate).
36-
37- Returns
38- -------
39- L : JAXArray
40- Generalized Laguerre polynomial evaluated at x.
41- """
42-
43- # --- 0. Input Validation (Requires static Python int n) ---
44- if not isinstance (n , int ) or n < 0 :
45- raise ValueError (f"Degree n must be a non-negative Python integer (static), got { n } ." )
46-
47- # --- 1. Base Cases ---
48- L0 = jnp .ones_like (x )
49- if n == 0 :
50- return L0
51-
52- L1 = 1 + alpha - x
53- if n == 1 :
54- return L1
55-
56- # --- 2. JAX Recurrence Calculation ---
57-
58- def body (k , state ):
59- # state = (L_{k-1}, L_k)
60- L_nm1 , L_n = state
61-
62- # Recurrence relation:
63- # L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1)
64- L_np1 = ((2 * k + 1 + alpha - x ) * L_n - (k + alpha ) * L_nm1 ) / (k + 1 )
65-
66- # Return new state: (L_k, L_{k+1})
67- return (L_n , L_np1 )
68-
69- # fori_loop(start, stop, body, init_state)
70- _ , res_Ln = lax .fori_loop (1 , n , body , (L0 , L1 ))
71- return res_Ln
72-
73- def genlaguerre_jax_summation (n , alpha , x ):
14+ def genlaguerre_jax (n , alpha , x ):
7415 """
7516 Generalized (associated) Laguerre polynomial L_n^alpha(x)
7617 calculated using the explicit summation formula, optimized for JAX vectorization.
@@ -202,10 +143,8 @@ def image_2d_from(
202143 image
203144 The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid.
204145 """
205- # from scipy.special import genlaguerre
206146 from jax .scipy .special import factorial
207147
208- # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m))
209148 grid = aa .util .geometry .transform_grid_2d_to_reference_frame (
210149 grid_2d = grid .array , centre = self .centre , angle = self .phi , xp = xp
211150 )
@@ -223,7 +162,7 @@ def image_2d_from(
223162
224163 m_abs = abs (self .m )
225164 n_laguerre = (self .n - m_abs ) // 2
226- laguerre_vals = genlaguerre_jax_summation (n = n_laguerre , alpha = m_abs , x = rsq )
165+ laguerre_vals = genlaguerre_jax (n = n_laguerre , alpha = m_abs , x = rsq )
227166
228167 radial = rsq ** (xp .abs (self .m ) / 2.0 ) * xp .exp (- rsq / 2.0 ) * laguerre_vals
229168
0 commit comments