Skip to content

Commit c711bce

Browse files
author
Hengkai_Pmo_StarB
committed
clear the redundant comments
1 parent 846ca38 commit c711bce

1 file changed

Lines changed: 2 additions & 63 deletions

File tree

  • autogalaxy/profiles/light/standard/shapelets

autogalaxy/profiles/light/standard/shapelets/polar.py

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,75 +2,16 @@
22
from typing import Optional, Tuple
33

44
import autoarray as aa
5-
import autolens as al
6-
75

86
from autogalaxy.profiles.light.decorators import (
97
check_operated_only,
108
)
119
from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet
1210

1311
import jax.numpy as jnp
14-
from jax import lax
1512
from jax.scipy.special import gammaln
1613

17-
def genlaguerre_jax_recurrence(n, alpha, x):
18-
"""
19-
Generalized (associated) Laguerre polynomial $L_n^{(\alpha)}(x)$
20-
calculated using the three-term recurrence relation in pure JAX.
21-
22-
Optimized for JAX via `lax.fori_loop` for loop unrolling.
23-
24-
.. math::
25-
(k+1) L_{k+1}^{(\alpha)}(x) = (2k + 1 + \alpha - x) L_k^{(\alpha)}(x) - (k + \alpha) L_{k-1}^{(\alpha)}(x)
26-
27-
Parameters
28-
----------
29-
n : int
30-
Degree of the polynomial. MUST be a non-negative static Python integer
31-
for optimal JAX compilation.
32-
alpha : Union[float, JAXArray]
33-
Parameter $\alpha > -1$.
34-
x : JAXArray
35-
Input array (points at which to evaluate).
36-
37-
Returns
38-
-------
39-
L : JAXArray
40-
Generalized Laguerre polynomial evaluated at x.
41-
"""
42-
43-
# --- 0. Input Validation (Requires static Python int n) ---
44-
if not isinstance(n, int) or n < 0:
45-
raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.")
46-
47-
# --- 1. Base Cases ---
48-
L0 = jnp.ones_like(x)
49-
if n == 0:
50-
return L0
51-
52-
L1 = 1 + alpha - x
53-
if n == 1:
54-
return L1
55-
56-
# --- 2. JAX Recurrence Calculation ---
57-
58-
def body(k, state):
59-
# state = (L_{k-1}, L_k)
60-
L_nm1, L_n = state
61-
62-
# Recurrence relation:
63-
# L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1)
64-
L_np1 = ((2 * k + 1 + alpha - x) * L_n - (k + alpha) * L_nm1) / (k + 1)
65-
66-
# Return new state: (L_k, L_{k+1})
67-
return (L_n, L_np1)
68-
69-
# fori_loop(start, stop, body, init_state)
70-
_, res_Ln = lax.fori_loop(1, n, body, (L0, L1))
71-
return res_Ln
72-
73-
def genlaguerre_jax_summation(n, alpha, x):
14+
def genlaguerre_jax(n, alpha, x):
7415
"""
7516
Generalized (associated) Laguerre polynomial L_n^alpha(x)
7617
calculated using the explicit summation formula, optimized for JAX vectorization.
@@ -202,10 +143,8 @@ def image_2d_from(
202143
image
203144
The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid.
204145
"""
205-
# from scipy.special import genlaguerre
206146
from jax.scipy.special import factorial
207147

208-
# laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m))
209148
grid = aa.util.geometry.transform_grid_2d_to_reference_frame(
210149
grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp
211150
)
@@ -223,7 +162,7 @@ def image_2d_from(
223162

224163
m_abs = abs(self.m)
225164
n_laguerre = (self.n - m_abs) // 2
226-
laguerre_vals = genlaguerre_jax_summation(n=n_laguerre, alpha=m_abs, x=rsq)
165+
laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq)
227166

228167
radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals
229168

0 commit comments

Comments
 (0)