Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions autogalaxy/profiles/light/linear/shapelets/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ def __init__(
The m order of the shapelets basis function in the x-direction.
centre
The (y,x) arc-second coordinates of the profile (shapelet) centre.
ell_comps
The first and second ellipticity components of the elliptical coordinate system.
q
The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0.
phi
The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the
positive x-axis.
intensity
Overall intensity normalisation of the light profile (units are dimensionless and derived from the data
the light profile's image is compared too, which is expected to be electrons per second).
Expand All @@ -53,6 +56,7 @@ def __init__(
n: int,
m: int,
centre: Tuple[float, float] = (0.0, 0.0),
phi: float = 0.0,
beta: float = 1.0,
):
"""
Expand All @@ -74,8 +78,11 @@ def __init__(
The order of the shapelets basis function in the x-direction.
centre
The (y,x) arc-second coordinates of the profile (shapelet) centre.
phi
The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the
positive x-axis.
beta
The characteristic length scale of the shapelet basis function, defined in arc-seconds.
"""

super().__init__(n=n, m=m, centre=centre, ell_comps=(0.0, 0.0), beta=beta)
super().__init__(n=n, m=m, centre=centre, beta=beta)
84 changes: 51 additions & 33 deletions autogalaxy/profiles/light/standard/shapelets/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,34 @@
from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet


def hermite_phys(n: int, x, xp=np):
"""
Physicists' Hermite polynomial H_n(x), compatible with NumPy and JAX via `xp`.

Recurrence:
H_0(x) = 1
H_1(x) = 2x
H_{n+1}(x) = 2x H_n(x) - 2n H_{n-1}(x)
"""
if n < 0:
raise ValueError("n must be >= 0")

H0 = xp.ones_like(x)
if n == 0:
return H0

H1 = 2.0 * x
if n == 1:
return H1

Hnm1 = H0
Hn = H1
for k in range(1, n):
Hnp1 = 2.0 * x * Hn - 2.0 * float(k) * Hnm1
Hnm1, Hn = Hn, Hnp1
return Hn


class ShapeletCartesian(AbstractShapelet):
def __init__(
self,
Expand Down Expand Up @@ -71,47 +99,37 @@ def image_2d_from(
) -> np.ndarray:
"""
Returns the Cartesian Shapelet light profile's 2D image from a 2D grid of Cartesian (y,x) coordinates.

If the coordinates have not been transformed to the profile's geometry (e.g. translated to the
profile `centre`), this is performed automatically.

Parameters
----------
grid
The 2D (y, x) coordinates in the original reference frame of the grid.

Returns
-------
image
The image of the Cartesian Shapelet evaluated at every (y,x) coordinate on the transformed grid.
"""
from jax.scipy.special import factorial
from scipy.special import hermite

hermite_y = hermite(n=self.n_y)
hermite_x = hermite(n=self.n_x)
# factorial backend switch
if xp is np:
from scipy.special import factorial
else:
from jax.scipy.special import factorial

y = grid.array[:, 0]
x = grid.array[:, 1]

shapelet_y = hermite_y(y / self.beta)
shapelet_x = hermite_x(x / self.beta)

return (
shapelet_y
* shapelet_x
* xp.exp(-0.5 * (y**2 + x**2) / (self.beta**2))
/ self.beta
/ (
xp.sqrt(
2 ** (self.n_x + self.n_y)
* (xp.pi)
* factorial(self.n_y)
* factorial(self.n_x)
)
)
# Apply axis-ratio stretching (minor axis)
q = self.axis_ratio(xp)
y_ell = y / q
x_ell = x

# Evaluate Hermite polynomials (JAX-safe)
shapelet_y = hermite_phys(self.n_y, y_ell / self.beta, xp=xp)
shapelet_x = hermite_phys(self.n_x, x_ell / self.beta, xp=xp)

gaussian = xp.exp(-0.5 * (x_ell**2 + y_ell**2) / (self.beta**2))

norm = self.beta * xp.sqrt(
(2.0 ** (self.n_x + self.n_y))
* xp.pi
* factorial(self.n_y)
* factorial(self.n_x)
)

return self._intensity * (shapelet_y * shapelet_x * gaussian) / norm


class ShapeletCartesianSph(ShapeletCartesian):
def __init__(
Expand Down
127 changes: 109 additions & 18 deletions autogalaxy/profiles/light/standard/shapelets/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,72 @@

import autoarray as aa


from autogalaxy.profiles.light.decorators import (
check_operated_only,
)
from autogalaxy import convert
from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet


def genlaguerre_jax(n, alpha, x):
"""
Generalized (associated) Laguerre polynomial L_n^alpha(x)
calculated using the explicit summation formula, optimized for JAX vectorization.

Parameters:
n (int): Degree of the polynomial (static Python integer).
alpha (Numeric): Parameter alpha > -1.
x (Array): Input array (evaluation points).
"""
import jax.numpy as jnp
from jax.scipy.special import gammaln

# 0. Input Validation (Requires static Python int n)
if not isinstance(n, int) or n < 0:
# Use Python's math.isnan/isinf check if n is float, otherwise type error
raise ValueError(
f"Degree n must be a non-negative Python integer (static), got {n}."
)

# Base Case L0
if n == 0:
return jnp.ones_like(x)

# 1. Generate k values for summation range [0, 1, 2, ..., n]
k_values = jnp.arange(n + 1) # (n+1,)

# 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1))
x_expanded = jnp.expand_dims(x, axis=-1)
k_values_expanded = jnp.expand_dims(k_values, axis=0)

# --- A. Binomial Factor (BF) Calculation ---
# BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) )

log_N_plus_alpha_fact = gammaln(n + alpha + 1)

log_BF_k = (
log_N_plus_alpha_fact
- gammaln(n - k_values + 1) # log( (n-k)! )
- gammaln(alpha + k_values + 1) # log( (alpha+k)! )
)

BF_k = jnp.exp(log_BF_k) # Shape: (n+1,)

# --- B. Term Factor (TF) Calculation ---
# TF = (-x)^k / k!

# Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space
TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(
gammaln(k_values_expanded + 1)
)
# TF_k Shape: (M, n+1)

# --- C. Final Summation ---
# Sum over the last axis (axis=1), which corresponds to k
# BF_k broadcasts over the M dimension of TF_k
return jnp.sum(BF_k * TF_k, axis=1)


class ShapeletPolar(AbstractShapelet):
def __init__(
self,
Expand Down Expand Up @@ -39,17 +98,20 @@ def __init__(
The m order of the shapelets basis function in the x-direction.
centre
The (y,x) arc-second coordinates of the profile (shapelet) centre.
ell_comps
The first and second ellipticity components of the elliptical coordinate system.
q
The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0.
phi
The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the
positive x-axis.
intensity
Overall intensity normalisation of the light profile (units are dimensionless and derived from the data
the light profile's image is compared too, which is expected to be electrons per second).
beta
The characteristic length scale of the shapelet basis function, defined in arc-seconds.
"""

self.n = n
self.m = m
self.n = int(n)
self.m = int(m)

super().__init__(
centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity
Expand Down Expand Up @@ -86,10 +148,13 @@ def image_2d_from(
image
The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid.
"""
from scipy.special import genlaguerre
from jax.scipy.special import factorial
if xp is np:

laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m))
from scipy.special import factorial

else:

from jax.scipy.special import factorial

const = (
((-1) ** ((self.n - xp.abs(self.m)) // 2))
Expand All @@ -100,19 +165,43 @@ def image_2d_from(
/ self.beta
/ xp.sqrt(xp.pi)
)
y = grid.array[:, 0]
x = grid.array[:, 1]

rsq = (x**2 + (y / self.axis_ratio(xp)) ** 2) / self.beta**2
theta = xp.arctan2(y, x)

rsq = (grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2) / self.beta**2
theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0])
radial = rsq ** (abs(self.m / 2.0)) * xp.exp(-rsq / 2.0) * laguerre(rsq)
m_abs = abs(self.m)
n_laguerre = (self.n - m_abs) // 2

if xp is np:

from scipy.special import genlaguerre

laguerre = genlaguerre(
n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)
)
laguerre_vals = laguerre(rsq)

if self.m == 0:
azimuthal = 1
elif self.m > 0:
azimuthal = xp.sin((-1) * self.m * theta)
else:
azimuthal = xp.cos((-1) * self.m * theta)

return const * radial * azimuthal
laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq)

radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals

m = self.m

azimuthal = xp.where(
m == 0,
xp.ones_like(theta),
xp.where(
m > 0,
xp.sin(-m * theta),
xp.cos(-m * theta),
),
)

return self._intensity * const * radial * azimuthal


class ShapeletPolarSph(ShapeletPolar):
Expand Down Expand Up @@ -143,6 +232,9 @@ def __init__(
The order of the shapelets basis function in the x-direction.
centre
The (y,x) arc-second coordinates of the profile (shapelet) centre.
phi
The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the
positive x-axis.
intensity
Overall intensity normalisation of the light profile (units are dimensionless and derived from the data
the light profile's image is compared too, which is expected to be electrons per second).
Expand All @@ -154,7 +246,6 @@ def __init__(
n=n,
m=m,
centre=centre,
ell_comps=(0.0, 0.0),
intensity=intensity,
beta=beta,
)
Loading
Loading