From cea025a59280298b2954be273af6aa4fe7e5cb5a Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Thu, 4 Dec 2025 11:18:08 +0800 Subject: [PATCH 01/17] fixed the convergence map calculation --- .../mass/total/dual_pseudo_isothermal_mass.py | 74 ++++++++++--------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py index bb6f98d46..42264afb4 100644 --- a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py +++ b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py @@ -287,6 +287,42 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): xp=xp, **kwargs, ) + + def _convergence(self, radii, xp=np): + + radsq = radii * radii + a = self.ra + + return ( + self.b0 + / 2 + * (1 / xp.sqrt(a**2 + radsq)) + ) + + @aa.grid_dec.to_array + @aa.grid_dec.transform + def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): + """ + Returns the two-dimensional projected convergence on a grid of (y,x) + arc-second coordinates. + + The `grid_2d_to_structure` decorator reshapes the ndarrays the convergence + is outputted on. See *aa.grid_2d_to_structure* for details. + + Parameters + ---------- + grid + The grid of (y,x) arc-second coordinates on which the convergence is computed. + """ + ellip = self._ellip(xp) + grid_radii = xp.sqrt( + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 + ) + # Compute the convergence and deflection of a *circular* profile + kappa = self._convergence(grid_radii,xp) + + return kappa + @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): @@ -410,22 +446,6 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): **kwargs, ) - def _deflection_angle(self, radii, xp=np): - """ - For a circularly symmetric dPIEPotential profile, computes the magnitude of the deflection at each radius. - """ - a, s = self.ra, self.rs - radii = xp.maximum(radii, 1e-8) - f = radii / (a + xp.sqrt(a**2 + radii**2)) - radii / ( - s + xp.sqrt(s**2 + radii**2) - ) - - # c.f. Eliasdottir '07 eq. A23 - # magnitude of deflection - # alpha = self.E0 * (s + a) / s * f - alpha = self.b0 * s / (s - a) * f - return alpha - def _convergence(self, radii, xp=np): radsq = radii * radii @@ -439,7 +459,7 @@ def _convergence(self, radii, xp=np): * (1 / xp.sqrt(a**2 + radsq) - 1 / xp.sqrt(s**2 + radsq)) ) - @aa.grid_dec.to_vector_yx + @aa.grid_dec.to_array @aa.grid_dec.transform def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ @@ -455,23 +475,11 @@ def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) grid_radii = xp.sqrt( - grid.array[:, 1] ** 2 * (1 - ellip) + grid.array[:, 0] ** 2 * (1 + ellip) + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 ) - - # Compute the convergence and deflection of a *circular* profile - kappa_circ = self._convergence(grid_radii, xp) - alpha_circ = self._deflection_angle(grid_radii, xp) - - asymm_term = ( - ellip * (1 - ellip) * grid.array[:, 1] ** 2 - - ellip * (1 + ellip) * grid.array[:, 0] ** 2 - ) / grid_radii**2 - - # convergence = 1/2 \nabla \alpha = 1/2 \nabla^2 potential - # The "asymm_term" is asymmetric on x and y, so averages out to - # zero over all space - return kappa_circ * (1 - asymm_term) + (alpha_circ / grid_radii) * asymm_term - + kappa = self._convergence(grid_radii,xp) + return kappa + @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): """ From 007ebdb4f0b194e0939e95869143869afaecc088 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sat, 6 Dec 2025 16:54:06 +0800 Subject: [PATCH 02/17] fixed 'xp' for _ci05, _ci05f and _mdci05 --- .../profiles/mass/total/dual_pseudo_isothermal_mass.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py index 42264afb4..8dae38b2a 100644 --- a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py +++ b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py @@ -275,7 +275,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) factor = self.b0 - zis = _ci05(x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra) + zis = _ci05(x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, xp=xp) # This is in axes aligned to the major/minor axis deflection_x = zis.real @@ -340,7 +340,7 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs ellip = self._ellip() hessian_xx, hessian_xy, hessian_yx, hessian_yy = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0 + x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0, xp=xp ) return hessian_yy, hessian_xy, hessian_yx, hessian_xx @@ -433,6 +433,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): eps=ellip, rcore=self.ra, rcut=self.rs, + xp=xp ) # This is in axes aligned to the major/minor axis @@ -498,10 +499,10 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs t05 = self.rs / (self.rs - self.ra) g05c_a, g05c_b, g05c_c, g05c_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0 + x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0, xp=xp ) g05cut_a, g05cut_b, g05cut_c, g05cut_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.rs, b0=self.b0 + x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.rs, b0=self.b0, xp=xp ) # Compute Hessian matrix components From 182eb57d21c85b89cc8aa423334a99640f275ab5 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sat, 6 Dec 2025 16:54:57 +0800 Subject: [PATCH 03/17] jax supported genlaguerre --- .../light/standard/shapelets/polar.py | 125 +++++++++++++++++- 1 file changed, 120 insertions(+), 5 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 1d6c98db1..881cfb331 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -9,6 +9,116 @@ ) from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet +import jax.numpy as jnp +from jax import lax +from jax.scipy.special import gammaln + +def genlaguerre_jax_recurrence(n, alpha, x): + """ + Generalized (associated) Laguerre polynomial $L_n^{(\alpha)}(x)$ + calculated using the three-term recurrence relation in pure JAX. + + Optimized for JAX via `lax.fori_loop` for loop unrolling. + + .. math:: + (k+1) L_{k+1}^{(\alpha)}(x) = (2k + 1 + \alpha - x) L_k^{(\alpha)}(x) - (k + \alpha) L_{k-1}^{(\alpha)}(x) + + Parameters + ---------- + n : int + Degree of the polynomial. MUST be a non-negative static Python integer + for optimal JAX compilation. + alpha : Union[float, JAXArray] + Parameter $\alpha > -1$. + x : JAXArray + Input array (points at which to evaluate). + + Returns + ------- + L : JAXArray + Generalized Laguerre polynomial evaluated at x. + """ + + # --- 0. Input Validation (Requires static Python int n) --- + if not isinstance(n, int) or n < 0: + raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + + # --- 1. Base Cases --- + L0 = jnp.ones_like(x) + if n == 0: + return L0 + + L1 = 1 + alpha - x + if n == 1: + return L1 + + # --- 2. JAX Recurrence Calculation --- + + def body(k, state): + # state = (L_{k-1}, L_k) + L_nm1, L_n = state + + # Recurrence relation: + # L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1) + L_np1 = ((2 * k + 1 + alpha - x) * L_n - (k + alpha) * L_nm1) / (k + 1) + + # Return new state: (L_k, L_{k+1}) + return (L_n, L_np1) + + # fori_loop(start, stop, body, init_state) + _, res_Ln = lax.fori_loop(1, n, body, (L0, L1)) + return res_Ln + +def genlaguerre_jax_summation(n, alpha, x): + """ + Generalized (associated) Laguerre polynomial L_n^alpha(x) + calculated using the explicit summation formula, optimized for JAX vectorization. + + Parameters: + n (int): Degree of the polynomial (static Python integer). + alpha (Numeric): Parameter alpha > -1. + x (Array): Input array (evaluation points). + """ + # 0. Input Validation (Requires static Python int n) + if not isinstance(n, int) or n < 0: + # Use Python's math.isnan/isinf check if n is float, otherwise type error + raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + + # Base Case L0 + if n == 0: + return jnp.ones_like(x) + + # 1. Generate k values for summation range [0, 1, 2, ..., n] + k_values = jnp.arange(n + 1) # (n+1,) + + # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) + x_expanded = jnp.expand_dims(x, axis=-1) + k_values_expanded = jnp.expand_dims(k_values, axis=0) + + # --- A. Binomial Factor (BF) Calculation --- + # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) + + log_N_plus_alpha_fact = gammaln(n + alpha + 1) + + log_BF_k = ( + log_N_plus_alpha_fact + - gammaln(n - k_values + 1) # log( (n-k)! ) + - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) + ) + + BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) + + # --- B. Term Factor (TF) Calculation --- + # TF = (-x)^k / k! + + # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space + TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) + # TF_k Shape: (M, n+1) + + # --- C. Final Summation --- + # Sum over the last axis (axis=1), which corresponds to k + # BF_k broadcasts over the M dimension of TF_k + return jnp.sum(BF_k * TF_k, axis=1) class ShapeletPolar(AbstractShapelet): def __init__( @@ -48,8 +158,8 @@ def __init__( The characteristic length scale of the shapelet basis function, defined in arc-seconds. """ - self.n = n - self.m = m + self.n = int(n) + self.m = int(m) super().__init__( centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity @@ -86,10 +196,10 @@ def image_2d_from( image The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - from scipy.special import genlaguerre + # from scipy.special import genlaguerre from jax.scipy.special import factorial - laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) const = ( ((-1) ** ((self.n - xp.abs(self.m)) // 2)) @@ -103,7 +213,12 @@ def image_2d_from( rsq = (grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2) / self.beta**2 theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) - radial = rsq ** (abs(self.m / 2.0)) * xp.exp(-rsq / 2.0) * laguerre(rsq) + + m_abs = abs(self.m) + n_laguerre = (self.n - m_abs) // 2 + laguerre_vals = genlaguerre_jax_summation(n=n_laguerre, alpha=m_abs, x=rsq) + + radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals if self.m == 0: azimuthal = 1 From 630d736571b39cf6ce86ea75127d5ba0b0c54f08 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sat, 6 Dec 2025 21:41:05 +0800 Subject: [PATCH 04/17] added analytical covergence_2d of dPIEMassSph --- .../mass/total/dual_pseudo_isothermal_mass.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py index 8dae38b2a..8a13266f3 100644 --- a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py +++ b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py @@ -619,6 +619,25 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): xp=xp, **kwargs, ) + + @aa.grid_dec.to_array + @aa.grid_dec.transform + def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): + """ + Returns the two dimensional projected convergence on a grid of (y,x) arc-second coordinates. + + The `grid_2d_to_structure` decorator reshapes the ndarrays the convergence is outputted on. See + *aa.grid_2d_to_structure* for a description of the output. + + Parameters + ---------- + grid + The grid of (y,x) arc-second coordinates the convergence is computed on. + """ + # already transformed to center on profile centre so this works + radsq = grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2 + + return self._convergence(xp.sqrt(radsq), xp) @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): From cd299a002b3581c34170946411bf84a5eaec6318 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sun, 7 Dec 2025 21:54:56 +0800 Subject: [PATCH 05/17] test shapelet with q and phi ranther ell_comps --- .../light/linear/shapelets/polar_q_phi.py | 90 ++++++ .../light/standard/shapelets/polar.py | 3 +- .../light/standard/shapelets/polar_q_phi.py | 288 ++++++++++++++++++ 3 files changed, 379 insertions(+), 2 deletions(-) create mode 100644 autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py create mode 100644 autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py diff --git a/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py b/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py new file mode 100644 index 000000000..fadd520c3 --- /dev/null +++ b/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py @@ -0,0 +1,90 @@ +from typing import Tuple + +from autogalaxy.profiles.light import standard as lp + +from autogalaxy.profiles.light.linear.abstract import LightProfileLinear + + +class ShapeletPolar(lp.ShapeletPolar, LightProfileLinear): + def __init__( + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + # ell_comps: Tuple[float, float] = (0.0, 0.0), + q: float = 1.0, + phi: float = 0.0, + beta: float = 1.0, + ): + """ + Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. + + Shapelets are defined according to: + + https://arxiv.org/abs/astro-ph/0105178 + + Shapelets are described in the context of strong lens modeling in: + + https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract + + Parameters + ---------- + n + The n order of the shapelets basis function. + m + The m order of the shapelets basis function in the x-direction. + centre + The (y,x) arc-second coordinates of the profile (shapelet) centre. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. + intensity + Overall intensity normalisation of the light profile (units are dimensionless and derived from the data + the light profile's image is compared too, which is expected to be electrons per second). + beta + The characteristic length scale of the shapelet basis function, defined in arc-seconds. + """ + + super().__init__( + n=n, m=m, centre=centre, q=q, phi=phi, beta=beta, intensity=1.0 + ) + + +class ShapeletPolarSph(ShapeletPolar): + def __init__( + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + phi: float = 0.0, + beta: float = 1.0, + ): + """ + Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. + + Shapelets are defined according to: + + https://arxiv.org/abs/astro-ph/0105178 + + Shapelets are described in the context of strong lens modeling in: + + https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract + + Parameters + ---------- + n_y + The order of the shapelets basis function in the y-direction. + n_x + The order of the shapelets basis function in the x-direction. + centre + The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. + beta + The characteristic length scale of the shapelet basis function, defined in arc-seconds. + """ + + super().__init__(n=n, m=m, centre=centre, q=1.0, phi=phi, beta=beta) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 881cfb331..627d78d7c 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -210,8 +210,7 @@ def image_2d_from( / self.beta / xp.sqrt(xp.pi) ) - - rsq = (grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2) / self.beta**2 + rsq = (grid.array[:, 0] ** 2 + (grid.array[:, 1]/self.axis_ratio(xp)) ** 2) / self.beta**2 theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) m_abs = abs(self.m) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py b/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py new file mode 100644 index 000000000..b6423da50 --- /dev/null +++ b/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py @@ -0,0 +1,288 @@ +import numpy as np +from typing import Optional, Tuple + +import autoarray as aa +import autolens as al + + +from autogalaxy.profiles.light.decorators import ( + check_operated_only, +) +from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet + +import jax.numpy as jnp +from jax import lax +from jax.scipy.special import gammaln + +def genlaguerre_jax_recurrence(n, alpha, x): + """ + Generalized (associated) Laguerre polynomial $L_n^{(\alpha)}(x)$ + calculated using the three-term recurrence relation in pure JAX. + + Optimized for JAX via `lax.fori_loop` for loop unrolling. + + .. math:: + (k+1) L_{k+1}^{(\alpha)}(x) = (2k + 1 + \alpha - x) L_k^{(\alpha)}(x) - (k + \alpha) L_{k-1}^{(\alpha)}(x) + + Parameters + ---------- + n : int + Degree of the polynomial. MUST be a non-negative static Python integer + for optimal JAX compilation. + alpha : Union[float, JAXArray] + Parameter $\alpha > -1$. + x : JAXArray + Input array (points at which to evaluate). + + Returns + ------- + L : JAXArray + Generalized Laguerre polynomial evaluated at x. + """ + + # --- 0. Input Validation (Requires static Python int n) --- + if not isinstance(n, int) or n < 0: + raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + + # --- 1. Base Cases --- + L0 = jnp.ones_like(x) + if n == 0: + return L0 + + L1 = 1 + alpha - x + if n == 1: + return L1 + + # --- 2. JAX Recurrence Calculation --- + + def body(k, state): + # state = (L_{k-1}, L_k) + L_nm1, L_n = state + + # Recurrence relation: + # L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1) + L_np1 = ((2 * k + 1 + alpha - x) * L_n - (k + alpha) * L_nm1) / (k + 1) + + # Return new state: (L_k, L_{k+1}) + return (L_n, L_np1) + + # fori_loop(start, stop, body, init_state) + _, res_Ln = lax.fori_loop(1, n, body, (L0, L1)) + return res_Ln + +def genlaguerre_jax_summation(n, alpha, x): + """ + Generalized (associated) Laguerre polynomial L_n^alpha(x) + calculated using the explicit summation formula, optimized for JAX vectorization. + + Parameters: + n (int): Degree of the polynomial (static Python integer). + alpha (Numeric): Parameter alpha > -1. + x (Array): Input array (evaluation points). + """ + # 0. Input Validation (Requires static Python int n) + if not isinstance(n, int) or n < 0: + # Use Python's math.isnan/isinf check if n is float, otherwise type error + raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + + # Base Case L0 + if n == 0: + return jnp.ones_like(x) + + # 1. Generate k values for summation range [0, 1, 2, ..., n] + k_values = jnp.arange(n + 1) # (n+1,) + + # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) + x_expanded = jnp.expand_dims(x, axis=-1) + k_values_expanded = jnp.expand_dims(k_values, axis=0) + + # --- A. Binomial Factor (BF) Calculation --- + # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) + + log_N_plus_alpha_fact = gammaln(n + alpha + 1) + + log_BF_k = ( + log_N_plus_alpha_fact + - gammaln(n - k_values + 1) # log( (n-k)! ) + - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) + ) + + BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) + + # --- B. Term Factor (TF) Calculation --- + # TF = (-x)^k / k! + + # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space + TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) + # TF_k Shape: (M, n+1) + + # --- C. Final Summation --- + # Sum over the last axis (axis=1), which corresponds to k + # BF_k broadcasts over the M dimension of TF_k + return jnp.sum(BF_k * TF_k, axis=1) + +class ShapeletPolar(AbstractShapelet): + def __init__( + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + q: float = 1.0, + phi: float = 0.0, + intensity: float = 1.0, + beta: float = 1.0, + ): + """ + Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. + + Shapelets are defined according to: + + https://arxiv.org/abs/astro-ph/0105178 + + Shapelets are described in the context of strong lens modeling in: + + https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract + + Parameters + ---------- + n + The n order of the shapelets basis function. + m + The m order of the shapelets basis function in the x-direction. + centre + The (y,x) arc-second coordinates of the profile (shapelet) centre. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. + intensity + Overall intensity normalisation of the light profile (units are dimensionless and derived from the data + the light profile's image is compared too, which is expected to be electrons per second). + beta + The characteristic length scale of the shapelet basis function, defined in arc-seconds. + """ + + self.n = int(n) + self.m = int(m) + self.phi = float(phi) + self.q = float(q) + + super().__init__( + centre=centre, beta=beta, ell_comps=al.convert.ell_comps_from(q,phi,np), intensity=intensity + ) + + @property + def coefficient_tag(self) -> str: + return f"n_{self.n}_m_{self.m}" + + @aa.over_sample + @aa.grid_dec.to_array + @check_operated_only + # @aa.grid_dec.transform + def image_2d_from( + self, + grid: aa.type.Grid2DLike, + xp=np, + operated_only: Optional[bool] = None, + **kwargs, + ) -> np.ndarray: + """ + Returns the Polar Shapelet light profile's 2D image from a 2D grid of Polar (y,x) coordinates. + + If the coordinates have not been transformed to the profile's geometry (e.g. translated to the + profile `centre`), this is performed automatically. + + Parameters + ---------- + grid + The 2D (y, x) coordinates in the original reference frame of the grid. + + Returns + ------- + image + The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. + """ + # from scipy.special import genlaguerre + from jax.scipy.special import factorial + + # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + grid = aa.util.geometry.transform_grid_2d_to_reference_frame( + grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp + ) + const = ( + ((-1) ** ((self.n - xp.abs(self.m)) // 2)) + * xp.sqrt( + factorial((self.n - xp.abs(self.m)) // 2) + / factorial((self.n + xp.abs(self.m)) // 2) + ) + / self.beta + / xp.sqrt(xp.pi) + ) + rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2 + theta = xp.arctan2(grid[:, 1], grid[:, 0]) + + m_abs = abs(self.m) + n_laguerre = (self.n - m_abs) // 2 + laguerre_vals = genlaguerre_jax_summation(n=n_laguerre, alpha=m_abs, x=rsq) + + radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals + + if self.m == 0: + azimuthal = 1 + elif self.m > 0: + azimuthal = xp.sin((-1) * self.m * theta) + else: + azimuthal = xp.cos((-1) * self.m * theta) + + return const * radial * azimuthal + + +class ShapeletPolarSph(ShapeletPolar): + def __init__( + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + phi: float = 0.0, + intensity: float = 1.0, + beta: float = 1.0, + ): + """ + Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. + + Shapelets are defined according to: + + https://arxiv.org/abs/astro-ph/0105178 + + Shapelets are described in the context of strong lens modeling in: + + https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract + + Parameters + ---------- + n_y + The order of the shapelets basis function in the y-direction. + n_x + The order of the shapelets basis function in the x-direction. + centre + The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. + intensity + Overall intensity normalisation of the light profile (units are dimensionless and derived from the data + the light profile's image is compared too, which is expected to be electrons per second). + beta + The characteristic length scale of the shapelet basis function, defined in arc-seconds. + """ + + super().__init__( + n=n, + m=m, + centre=centre, + q=1.0, + phi=phi, + intensity=intensity, + beta=beta, + ) From f589484aecd08115a87c1c092c8183dcddea8d53 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Fri, 12 Dec 2025 20:22:10 +0800 Subject: [PATCH 06/17] multiply self._intensity --- autogalaxy/profiles/light/standard/shapelets/polar.py | 2 +- autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 627d78d7c..5c7d06a40 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -226,7 +226,7 @@ def image_2d_from( else: azimuthal = xp.cos((-1) * self.m * theta) - return const * radial * azimuthal + return self._intensity * const * radial * azimuthal class ShapeletPolarSph(ShapeletPolar): diff --git a/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py b/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py index b6423da50..fe03de1ad 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py @@ -169,7 +169,7 @@ def __init__( self.q = float(q) super().__init__( - centre=centre, beta=beta, ell_comps=al.convert.ell_comps_from(q,phi,np), intensity=intensity + centre=centre, beta=beta, intensity=intensity ) @property @@ -235,7 +235,7 @@ def image_2d_from( else: azimuthal = xp.cos((-1) * self.m * theta) - return const * radial * azimuthal + return self._intensity * const * radial * azimuthal class ShapeletPolarSph(ShapeletPolar): From 846ca386e0c7272ecf05dadb42a49bf75bbf08d7 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Fri, 12 Dec 2025 20:51:47 +0800 Subject: [PATCH 07/17] ShapeletPolar with q,phi convention --- .../profiles/light/linear/shapelets/polar.py | 18 +- .../light/linear/shapelets/polar_q_phi.py | 90 ------ .../light/standard/shapelets/polar.py | 31 +- .../light/standard/shapelets/polar_q_phi.py | 288 ------------------ 4 files changed, 35 insertions(+), 392 deletions(-) delete mode 100644 autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py delete mode 100644 autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py diff --git a/autogalaxy/profiles/light/linear/shapelets/polar.py b/autogalaxy/profiles/light/linear/shapelets/polar.py index f3c34e7fa..8ad67d462 100644 --- a/autogalaxy/profiles/light/linear/shapelets/polar.py +++ b/autogalaxy/profiles/light/linear/shapelets/polar.py @@ -11,7 +11,8 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), - ell_comps: Tuple[float, float] = (0.0, 0.0), + q: float = 1.0, + phi: float = 0.0, beta: float = 1.0, ): """ @@ -33,8 +34,11 @@ def __init__( The m order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. - ell_comps - The first and second ellipticity components of the elliptical coordinate system. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -43,7 +47,7 @@ def __init__( """ super().__init__( - n=n, m=m, centre=centre, ell_comps=ell_comps, beta=beta, intensity=1.0 + n=n, m=m, centre=centre, q=q, phi=phi, beta=beta, intensity=1.0 ) @@ -53,6 +57,7 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), + phi: float = 0.0, beta: float = 1.0, ): """ @@ -74,8 +79,11 @@ def __init__( The order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. beta The characteristic length scale of the shapelet basis function, defined in arc-seconds. """ - super().__init__(n=n, m=m, centre=centre, ell_comps=(0.0, 0.0), beta=beta) + super().__init__(n=n, m=m, centre=centre, q=1.0, phi=phi, beta=beta) diff --git a/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py b/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py deleted file mode 100644 index fadd520c3..000000000 --- a/autogalaxy/profiles/light/linear/shapelets/polar_q_phi.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple - -from autogalaxy.profiles.light import standard as lp - -from autogalaxy.profiles.light.linear.abstract import LightProfileLinear - - -class ShapeletPolar(lp.ShapeletPolar, LightProfileLinear): - def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - # ell_comps: Tuple[float, float] = (0.0, 0.0), - q: float = 1.0, - phi: float = 0.0, - beta: float = 1.0, - ): - """ - Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. - - Shapelets are defined according to: - - https://arxiv.org/abs/astro-ph/0105178 - - Shapelets are described in the context of strong lens modeling in: - - https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract - - Parameters - ---------- - n - The n order of the shapelets basis function. - m - The m order of the shapelets basis function in the x-direction. - centre - The (y,x) arc-second coordinates of the profile (shapelet) centre. - q - The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. - phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the - positive x-axis. - intensity - Overall intensity normalisation of the light profile (units are dimensionless and derived from the data - the light profile's image is compared too, which is expected to be electrons per second). - beta - The characteristic length scale of the shapelet basis function, defined in arc-seconds. - """ - - super().__init__( - n=n, m=m, centre=centre, q=q, phi=phi, beta=beta, intensity=1.0 - ) - - -class ShapeletPolarSph(ShapeletPolar): - def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - phi: float = 0.0, - beta: float = 1.0, - ): - """ - Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. - - Shapelets are defined according to: - - https://arxiv.org/abs/astro-ph/0105178 - - Shapelets are described in the context of strong lens modeling in: - - https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract - - Parameters - ---------- - n_y - The order of the shapelets basis function in the y-direction. - n_x - The order of the shapelets basis function in the x-direction. - centre - The (y,x) arc-second coordinates of the profile (shapelet) centre. - phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the - positive x-axis. - beta - The characteristic length scale of the shapelet basis function, defined in arc-seconds. - """ - - super().__init__(n=n, m=m, centre=centre, q=1.0, phi=phi, beta=beta) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 5c7d06a40..de8beeca5 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple import autoarray as aa +import autolens as al from autogalaxy.profiles.light.decorators import ( @@ -126,7 +127,8 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), - ell_comps: Tuple[float, float] = (0.0, 0.0), + q: float = 1.0, + phi: float = 0.0, intensity: float = 1.0, beta: float = 1.0, ): @@ -149,8 +151,11 @@ def __init__( The m order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. - ell_comps - The first and second ellipticity components of the elliptical coordinate system. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -160,9 +165,11 @@ def __init__( self.n = int(n) self.m = int(m) + self.phi = float(phi) + self.q = float(q) super().__init__( - centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity + centre=centre, beta=beta, intensity=intensity ) @property @@ -172,7 +179,6 @@ def coefficient_tag(self) -> str: @aa.over_sample @aa.grid_dec.to_array @check_operated_only - @aa.grid_dec.transform def image_2d_from( self, grid: aa.type.Grid2DLike, @@ -200,7 +206,9 @@ def image_2d_from( from jax.scipy.special import factorial # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) - + grid = aa.util.geometry.transform_grid_2d_to_reference_frame( + grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp + ) const = ( ((-1) ** ((self.n - xp.abs(self.m)) // 2)) * xp.sqrt( @@ -210,8 +218,8 @@ def image_2d_from( / self.beta / xp.sqrt(xp.pi) ) - rsq = (grid.array[:, 0] ** 2 + (grid.array[:, 1]/self.axis_ratio(xp)) ** 2) / self.beta**2 - theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) + rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2 + theta = xp.arctan2(grid[:, 1], grid[:, 0]) m_abs = abs(self.m) n_laguerre = (self.n - m_abs) // 2 @@ -235,6 +243,7 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), + phi: float = 0.0, intensity: float = 1.0, beta: float = 1.0, ): @@ -257,6 +266,9 @@ def __init__( The order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -268,7 +280,8 @@ def __init__( n=n, m=m, centre=centre, - ell_comps=(0.0, 0.0), + q=1.0, + phi=phi, intensity=intensity, beta=beta, ) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py b/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py deleted file mode 100644 index fe03de1ad..000000000 --- a/autogalaxy/profiles/light/standard/shapelets/polar_q_phi.py +++ /dev/null @@ -1,288 +0,0 @@ -import numpy as np -from typing import Optional, Tuple - -import autoarray as aa -import autolens as al - - -from autogalaxy.profiles.light.decorators import ( - check_operated_only, -) -from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet - -import jax.numpy as jnp -from jax import lax -from jax.scipy.special import gammaln - -def genlaguerre_jax_recurrence(n, alpha, x): - """ - Generalized (associated) Laguerre polynomial $L_n^{(\alpha)}(x)$ - calculated using the three-term recurrence relation in pure JAX. - - Optimized for JAX via `lax.fori_loop` for loop unrolling. - - .. math:: - (k+1) L_{k+1}^{(\alpha)}(x) = (2k + 1 + \alpha - x) L_k^{(\alpha)}(x) - (k + \alpha) L_{k-1}^{(\alpha)}(x) - - Parameters - ---------- - n : int - Degree of the polynomial. MUST be a non-negative static Python integer - for optimal JAX compilation. - alpha : Union[float, JAXArray] - Parameter $\alpha > -1$. - x : JAXArray - Input array (points at which to evaluate). - - Returns - ------- - L : JAXArray - Generalized Laguerre polynomial evaluated at x. - """ - - # --- 0. Input Validation (Requires static Python int n) --- - if not isinstance(n, int) or n < 0: - raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") - - # --- 1. Base Cases --- - L0 = jnp.ones_like(x) - if n == 0: - return L0 - - L1 = 1 + alpha - x - if n == 1: - return L1 - - # --- 2. JAX Recurrence Calculation --- - - def body(k, state): - # state = (L_{k-1}, L_k) - L_nm1, L_n = state - - # Recurrence relation: - # L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1) - L_np1 = ((2 * k + 1 + alpha - x) * L_n - (k + alpha) * L_nm1) / (k + 1) - - # Return new state: (L_k, L_{k+1}) - return (L_n, L_np1) - - # fori_loop(start, stop, body, init_state) - _, res_Ln = lax.fori_loop(1, n, body, (L0, L1)) - return res_Ln - -def genlaguerre_jax_summation(n, alpha, x): - """ - Generalized (associated) Laguerre polynomial L_n^alpha(x) - calculated using the explicit summation formula, optimized for JAX vectorization. - - Parameters: - n (int): Degree of the polynomial (static Python integer). - alpha (Numeric): Parameter alpha > -1. - x (Array): Input array (evaluation points). - """ - # 0. Input Validation (Requires static Python int n) - if not isinstance(n, int) or n < 0: - # Use Python's math.isnan/isinf check if n is float, otherwise type error - raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") - - # Base Case L0 - if n == 0: - return jnp.ones_like(x) - - # 1. Generate k values for summation range [0, 1, 2, ..., n] - k_values = jnp.arange(n + 1) # (n+1,) - - # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) - x_expanded = jnp.expand_dims(x, axis=-1) - k_values_expanded = jnp.expand_dims(k_values, axis=0) - - # --- A. Binomial Factor (BF) Calculation --- - # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) - - log_N_plus_alpha_fact = gammaln(n + alpha + 1) - - log_BF_k = ( - log_N_plus_alpha_fact - - gammaln(n - k_values + 1) # log( (n-k)! ) - - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) - ) - - BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) - - # --- B. Term Factor (TF) Calculation --- - # TF = (-x)^k / k! - - # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space - TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) - # TF_k Shape: (M, n+1) - - # --- C. Final Summation --- - # Sum over the last axis (axis=1), which corresponds to k - # BF_k broadcasts over the M dimension of TF_k - return jnp.sum(BF_k * TF_k, axis=1) - -class ShapeletPolar(AbstractShapelet): - def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - q: float = 1.0, - phi: float = 0.0, - intensity: float = 1.0, - beta: float = 1.0, - ): - """ - Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. - - Shapelets are defined according to: - - https://arxiv.org/abs/astro-ph/0105178 - - Shapelets are described in the context of strong lens modeling in: - - https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract - - Parameters - ---------- - n - The n order of the shapelets basis function. - m - The m order of the shapelets basis function in the x-direction. - centre - The (y,x) arc-second coordinates of the profile (shapelet) centre. - q - The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. - phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the - positive x-axis. - intensity - Overall intensity normalisation of the light profile (units are dimensionless and derived from the data - the light profile's image is compared too, which is expected to be electrons per second). - beta - The characteristic length scale of the shapelet basis function, defined in arc-seconds. - """ - - self.n = int(n) - self.m = int(m) - self.phi = float(phi) - self.q = float(q) - - super().__init__( - centre=centre, beta=beta, intensity=intensity - ) - - @property - def coefficient_tag(self) -> str: - return f"n_{self.n}_m_{self.m}" - - @aa.over_sample - @aa.grid_dec.to_array - @check_operated_only - # @aa.grid_dec.transform - def image_2d_from( - self, - grid: aa.type.Grid2DLike, - xp=np, - operated_only: Optional[bool] = None, - **kwargs, - ) -> np.ndarray: - """ - Returns the Polar Shapelet light profile's 2D image from a 2D grid of Polar (y,x) coordinates. - - If the coordinates have not been transformed to the profile's geometry (e.g. translated to the - profile `centre`), this is performed automatically. - - Parameters - ---------- - grid - The 2D (y, x) coordinates in the original reference frame of the grid. - - Returns - ------- - image - The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. - """ - # from scipy.special import genlaguerre - from jax.scipy.special import factorial - - # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) - grid = aa.util.geometry.transform_grid_2d_to_reference_frame( - grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp - ) - const = ( - ((-1) ** ((self.n - xp.abs(self.m)) // 2)) - * xp.sqrt( - factorial((self.n - xp.abs(self.m)) // 2) - / factorial((self.n + xp.abs(self.m)) // 2) - ) - / self.beta - / xp.sqrt(xp.pi) - ) - rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2 - theta = xp.arctan2(grid[:, 1], grid[:, 0]) - - m_abs = abs(self.m) - n_laguerre = (self.n - m_abs) // 2 - laguerre_vals = genlaguerre_jax_summation(n=n_laguerre, alpha=m_abs, x=rsq) - - radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals - - if self.m == 0: - azimuthal = 1 - elif self.m > 0: - azimuthal = xp.sin((-1) * self.m * theta) - else: - azimuthal = xp.cos((-1) * self.m * theta) - - return self._intensity * const * radial * azimuthal - - -class ShapeletPolarSph(ShapeletPolar): - def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - phi: float = 0.0, - intensity: float = 1.0, - beta: float = 1.0, - ): - """ - Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. - - Shapelets are defined according to: - - https://arxiv.org/abs/astro-ph/0105178 - - Shapelets are described in the context of strong lens modeling in: - - https://ui.adsabs.harvard.edu/abs/2016MNRAS.457.3066T/abstract - - Parameters - ---------- - n_y - The order of the shapelets basis function in the y-direction. - n_x - The order of the shapelets basis function in the x-direction. - centre - The (y,x) arc-second coordinates of the profile (shapelet) centre. - phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the - positive x-axis. - intensity - Overall intensity normalisation of the light profile (units are dimensionless and derived from the data - the light profile's image is compared too, which is expected to be electrons per second). - beta - The characteristic length scale of the shapelet basis function, defined in arc-seconds. - """ - - super().__init__( - n=n, - m=m, - centre=centre, - q=1.0, - phi=phi, - intensity=intensity, - beta=beta, - ) From c711bce22761fdc55407f4f22969dc46116cbcfd Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sat, 13 Dec 2025 00:06:40 +0800 Subject: [PATCH 08/17] clear the redundant comments --- .../light/standard/shapelets/polar.py | 65 +------------------ 1 file changed, 2 insertions(+), 63 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index de8beeca5..05bb7e02e 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -2,8 +2,6 @@ from typing import Optional, Tuple import autoarray as aa -import autolens as al - from autogalaxy.profiles.light.decorators import ( check_operated_only, @@ -11,66 +9,9 @@ from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet import jax.numpy as jnp -from jax import lax from jax.scipy.special import gammaln -def genlaguerre_jax_recurrence(n, alpha, x): - """ - Generalized (associated) Laguerre polynomial $L_n^{(\alpha)}(x)$ - calculated using the three-term recurrence relation in pure JAX. - - Optimized for JAX via `lax.fori_loop` for loop unrolling. - - .. math:: - (k+1) L_{k+1}^{(\alpha)}(x) = (2k + 1 + \alpha - x) L_k^{(\alpha)}(x) - (k + \alpha) L_{k-1}^{(\alpha)}(x) - - Parameters - ---------- - n : int - Degree of the polynomial. MUST be a non-negative static Python integer - for optimal JAX compilation. - alpha : Union[float, JAXArray] - Parameter $\alpha > -1$. - x : JAXArray - Input array (points at which to evaluate). - - Returns - ------- - L : JAXArray - Generalized Laguerre polynomial evaluated at x. - """ - - # --- 0. Input Validation (Requires static Python int n) --- - if not isinstance(n, int) or n < 0: - raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") - - # --- 1. Base Cases --- - L0 = jnp.ones_like(x) - if n == 0: - return L0 - - L1 = 1 + alpha - x - if n == 1: - return L1 - - # --- 2. JAX Recurrence Calculation --- - - def body(k, state): - # state = (L_{k-1}, L_k) - L_nm1, L_n = state - - # Recurrence relation: - # L_{k+1} = ((2k + 1 + alpha - x) * L_k - (k + alpha) * L_{k-1}) / (k + 1) - L_np1 = ((2 * k + 1 + alpha - x) * L_n - (k + alpha) * L_nm1) / (k + 1) - - # Return new state: (L_k, L_{k+1}) - return (L_n, L_np1) - - # fori_loop(start, stop, body, init_state) - _, res_Ln = lax.fori_loop(1, n, body, (L0, L1)) - return res_Ln - -def genlaguerre_jax_summation(n, alpha, x): +def genlaguerre_jax(n, alpha, x): """ Generalized (associated) Laguerre polynomial L_n^alpha(x) calculated using the explicit summation formula, optimized for JAX vectorization. @@ -202,10 +143,8 @@ def image_2d_from( image The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - # from scipy.special import genlaguerre from jax.scipy.special import factorial - # laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) grid = aa.util.geometry.transform_grid_2d_to_reference_frame( grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp ) @@ -223,7 +162,7 @@ def image_2d_from( m_abs = abs(self.m) n_laguerre = (self.n - m_abs) // 2 - laguerre_vals = genlaguerre_jax_summation(n=n_laguerre, alpha=m_abs, x=rsq) + laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq) radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals From cdea5ae2def1b22e59deb7a8247684816713de18 Mon Sep 17 00:00:00 2001 From: Hengkai_Pmo_StarB Date: Sat, 13 Dec 2025 00:07:39 +0800 Subject: [PATCH 09/17] Merge branch 'main' into fix/ShapeletPolar_dPIEkappa --- autogalaxy/profiles/mass/dark/nfw.py | 3 ++- autogalaxy/profiles/mass/dark/nfw_truncated.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/autogalaxy/profiles/mass/dark/nfw.py b/autogalaxy/profiles/mass/dark/nfw.py index 10a2db506..ccf9a9eaf 100644 --- a/autogalaxy/profiles/mass/dark/nfw.py +++ b/autogalaxy/profiles/mass/dark/nfw.py @@ -381,7 +381,8 @@ def deflections_2d_via_analytic_from( """ eta = xp.multiply( - 1.0 / self.scale_radius, self.radial_grid_from(grid=grid, xp=xp, **kwargs).array + 1.0 / self.scale_radius, + self.radial_grid_from(grid=grid, xp=xp, **kwargs).array, ) deflection_grid = xp.multiply( diff --git a/autogalaxy/profiles/mass/dark/nfw_truncated.py b/autogalaxy/profiles/mass/dark/nfw_truncated.py index 56dbcb726..7a1ff50fb 100644 --- a/autogalaxy/profiles/mass/dark/nfw_truncated.py +++ b/autogalaxy/profiles/mass/dark/nfw_truncated.py @@ -39,7 +39,8 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ eta = xp.multiply( - 1.0 / self.scale_radius, self.radial_grid_from(grid=grid, xp=xp, **kwargs).array + 1.0 / self.scale_radius, + self.radial_grid_from(grid=grid, xp=xp, **kwargs).array, ) deflection_grid = xp.multiply( From 802c0e563032bc93ce65eb31e921d867b36a185f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 13:20:11 +0000 Subject: [PATCH 10/17] return np and xp branch --- .../light/standard/shapelets/polar.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 05bb7e02e..c2f996ce4 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -8,9 +8,6 @@ ) from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet -import jax.numpy as jnp -from jax.scipy.special import gammaln - def genlaguerre_jax(n, alpha, x): """ Generalized (associated) Laguerre polynomial L_n^alpha(x) @@ -21,6 +18,9 @@ def genlaguerre_jax(n, alpha, x): alpha (Numeric): Parameter alpha > -1. x (Array): Input array (evaluation points). """ + import jax.numpy as jnp + from jax.scipy.special import gammaln + # 0. Input Validation (Requires static Python int n) if not isinstance(n, int) or n < 0: # Use Python's math.isnan/isinf check if n is float, otherwise type error @@ -157,12 +157,21 @@ def image_2d_from( / self.beta / xp.sqrt(xp.pi) ) - rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2 - theta = xp.arctan2(grid[:, 1], grid[:, 0]) + rsq = (grid.array[:, 0] ** 2 + (grid.array[:, 1]/self.q) ** 2) / self.beta**2 + theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) m_abs = abs(self.m) n_laguerre = (self.n - m_abs) // 2 - laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq) + + if xp is np: + + from scipy.special import genlaguerre + + laguerre_vals = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + + else: + + laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq) radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals From 9fcde1ce2fa35a0bc1b6cbf9067932043cfe3bc1 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 13:33:45 +0000 Subject: [PATCH 11/17] fix numpy shapelet implementation --- .../light/standard/shapelets/polar.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index c2f996ce4..0863fa075 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -143,11 +143,19 @@ def image_2d_from( image The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - from jax.scipy.special import factorial + if xp is np: + + from scipy.special import factorial + + else: + + from jax.scipy.special import factorial grid = aa.util.geometry.transform_grid_2d_to_reference_frame( grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp ) + grid = aa.Grid2DIrregular(values=grid) + const = ( ((-1) ** ((self.n - xp.abs(self.m)) // 2)) * xp.sqrt( @@ -167,7 +175,8 @@ def image_2d_from( from scipy.special import genlaguerre - laguerre_vals = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + laguerre_vals = laguerre(rsq) else: @@ -175,15 +184,19 @@ def image_2d_from( radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals - if self.m == 0: - azimuthal = 1 - elif self.m > 0: - azimuthal = xp.sin((-1) * self.m * theta) - else: - azimuthal = xp.cos((-1) * self.m * theta) + m = self.m - return self._intensity * const * radial * azimuthal + azimuthal = xp.where( + m == 0, + xp.ones_like(theta), + xp.where( + m > 0, + xp.sin(-m * theta), + xp.cos(-m * theta), + ), + ) + return self._intensity * const * radial * azimuthal class ShapeletPolarSph(ShapeletPolar): def __init__( From 676940cd004e00fbe3047adf64bd8c409644d65d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 13:47:08 +0000 Subject: [PATCH 12/17] added methods to convert to go from shapelet ell comp to axis ratio and q --- autogalaxy/convert.py | 126 ++++++++++++++++++ .../profiles/light/linear/shapelets/polar.py | 1 + .../light/standard/shapelets/polar.py | 1 + test_autogalaxy/test_convert.py | 6 + 4 files changed, 134 insertions(+) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 5bda71120..39fb00a22 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -337,3 +337,129 @@ def multipole_comps_from( multipole_comp_1 = k_m * xp.cos(phi_m * float(m) * units.deg.to(units.rad)) return (multipole_comp_0, multipole_comp_1) + +def shapelet_axis_ratio_and_phi_from( + shapelet_comps: Tuple[float, float], + xp=np, +) -> Tuple[float, float]: + """ + Returns the elliptical axis-ratio `q` and position angle `phi` (in degrees) from the shapelet + elliptical component parameters `shapelet_comps`. + + This conversion is intentionally identical in *spirit* to the `ell_comps` parameterization used + throughout PyAutoGalaxy: the circular case corresponds to (0.0, 0.0) at the centre of parameter space, + avoiding the sampling pathologies of directly sampling `(q, phi)` where the angle is undefined at `q=1`. + + For shapelets, these components define the *elliptical coordinate system* on which the basis functions + are evaluated. This is geometric ellipticity (an isophote ellipse), therefore we always enforce the + ellipse's 180 degree rotational symmetry by using an effective order m = 2 (and we do **not** accept + `m` as an input). + + The conversion is: + + 1) Convert components -> "ellipticity-like" amplitude and angle: + + .. math:: + e = \\sqrt{\\epsilon_1^2 + \\epsilon_2^2} + + .. math:: + \\phi = \\frac{1}{2} \\arctan2(\\epsilon_1, \\epsilon_2) + + 2) Map amplitude -> axis-ratio using the standard stable relation: + + .. math:: + e = \\frac{1 - q}{1 + q} \\;\\;\\Rightarrow\\;\\; + q = \\frac{1 - e}{1 + e} + + The returned `phi` is wrapped to prevent boundary hopping when computing marginalized error estimates + from posterior samples. The wrapping enforces a continuous interval analogous to the multipole + conversion logic, but with the fixed symmetry m = 2. + + Parameters + ---------- + shapelet_comps + The first and second components of the shapelet ellipticity. The circular limit is (0.0, 0.0). + These are unconstrained and can span (-inf, inf) during sampling. + xp + The array library used for the calculation (e.g. `numpy` or `jax.numpy`). + + Returns + ------- + axis_ratio + The axis-ratio of the elliptical coordinate system, with 0 < q <= 1. + phi + The position angle in degrees, measured counter-clockwise from the positive x-axis. + """ + eps_1, eps_2 = shapelet_comps + + # Ellipticity-like amplitude (0 at circular). Clip to keep q well-defined in (0, 1]. + e = xp.sqrt(eps_1 * eps_1 + eps_2 * eps_2) + e = xp.clip(e, 0.0, 1.0 - 1e-12) + + axis_ratio = (1.0 - e) / (1.0 + e) + + # Fixed symmetry for ellipses: m = 2. + phi = xp.arctan2(eps_1, eps_2) * 180.0 / xp.pi / 2.0 + + # Wrap phi to a continuous interval to avoid boundary hopping in posteriors. + # (Analogue of: phi_m = where(phi_m < -90/m, phi_m + 360/m, phi_m) with m=2.) + phi = xp.where(phi < -45.0, phi + 180.0, phi) + + return axis_ratio, phi + + +def shapelet_comps_from_axis_ratio_and_phi( + axis_ratio: float, + phi: float, + xp=np, +) -> Tuple[float, float]: + """ + Returns the shapelet elliptical component parameters `shapelet_comps` from an axis-ratio `q` + and position angle `phi` (in degrees). + + This is the inverse of `shapelet_axis_ratio_and_phi_from` and uses the same fixed ellipse symmetry + (effective order m = 2). The mapping is: + + 1) Convert axis-ratio -> ellipticity-like amplitude: + + .. math:: + e = \\frac{1 - q}{1 + q} + + 2) Convert amplitude and angle -> components: + + .. math:: + \\epsilon_1 = e \\, \\sin(2 \\phi) + + .. math:: + \\epsilon_2 = e \\, \\cos(2 \\phi) + + This ensures: + + .. math:: + \\phi = \\frac{1}{2} \\arctan2(\\epsilon_1, \\epsilon_2) + + Parameters + ---------- + axis_ratio + The axis-ratio of the elliptical coordinate system, with 0 < q <= 1. + phi + The position angle in degrees, measured counter-clockwise from the positive x-axis. + xp + The array library used for the calculation (e.g. `numpy` or `jax.numpy`). + + Returns + ------- + shapelet_comps + The first and second components of the shapelet ellipticity, where the circular limit is (0.0, 0.0). + """ + axis_ratio = xp.clip(axis_ratio, 1e-12, 1.0) + + e = (1.0 - axis_ratio) / (1.0 + axis_ratio) + + # Fixed symmetry for ellipses: m = 2. + ang = 2.0 * phi * xp.pi / 180.0 + + eps_1 = e * xp.sin(ang) + eps_2 = e * xp.cos(ang) + + return (eps_1, eps_2) diff --git a/autogalaxy/profiles/light/linear/shapelets/polar.py b/autogalaxy/profiles/light/linear/shapelets/polar.py index 8ad67d462..15f34468f 100644 --- a/autogalaxy/profiles/light/linear/shapelets/polar.py +++ b/autogalaxy/profiles/light/linear/shapelets/polar.py @@ -13,6 +13,7 @@ def __init__( centre: Tuple[float, float] = (0.0, 0.0), q: float = 1.0, phi: float = 0.0, + ell_comps: Tuple[float, float] = (0.0, 0.0), beta: float = 1.0, ): """ diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 0863fa075..96b523173 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -70,6 +70,7 @@ def __init__( centre: Tuple[float, float] = (0.0, 0.0), q: float = 1.0, phi: float = 0.0, + ell_comps: Tuple[float, float] = (0.0, 0.0), intensity: float = 1.0, beta: float = 1.0, ): diff --git a/test_autogalaxy/test_convert.py b/test_autogalaxy/test_convert.py index 5250864e0..a1e00f3a6 100644 --- a/test_autogalaxy/test_convert.py +++ b/test_autogalaxy/test_convert.py @@ -170,3 +170,9 @@ def test__multipole_comps_from(): multipole_comps = ag.convert.multipole_comps_from(k_m=0.14142135, phi_m=112.5, m=2) assert multipole_comps == pytest.approx((-0.1, -0.1), abs=1e-3) + + +def test__shapelet_axis_ratio_and_phi_from(): + + axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( + \ No newline at end of file From 898142b1165f27e11b334cb0bca3e2136ee64fd6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 13:50:09 +0000 Subject: [PATCH 13/17] test__shapelet_axis_ratio_and_phi_from --- autogalaxy/convert.py | 14 +++++++------- test_autogalaxy/test_convert.py | 27 ++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 39fb00a22..d8897ab78 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -339,12 +339,12 @@ def multipole_comps_from( return (multipole_comp_0, multipole_comp_1) def shapelet_axis_ratio_and_phi_from( - shapelet_comps: Tuple[float, float], + ell_comps: Tuple[float, float], xp=np, ) -> Tuple[float, float]: """ Returns the elliptical axis-ratio `q` and position angle `phi` (in degrees) from the shapelet - elliptical component parameters `shapelet_comps`. + elliptical component parameters `ell_comps`. This conversion is intentionally identical in *spirit* to the `ell_comps` parameterization used throughout PyAutoGalaxy: the circular case corresponds to (0.0, 0.0) at the centre of parameter space, @@ -377,7 +377,7 @@ def shapelet_axis_ratio_and_phi_from( Parameters ---------- - shapelet_comps + ell_comps The first and second components of the shapelet ellipticity. The circular limit is (0.0, 0.0). These are unconstrained and can span (-inf, inf) during sampling. xp @@ -390,7 +390,7 @@ def shapelet_axis_ratio_and_phi_from( phi The position angle in degrees, measured counter-clockwise from the positive x-axis. """ - eps_1, eps_2 = shapelet_comps + eps_1, eps_2 = ell_comps # Ellipticity-like amplitude (0 at circular). Clip to keep q well-defined in (0, 1]. e = xp.sqrt(eps_1 * eps_1 + eps_2 * eps_2) @@ -408,13 +408,13 @@ def shapelet_axis_ratio_and_phi_from( return axis_ratio, phi -def shapelet_comps_from_axis_ratio_and_phi( +def ell_comps_from_axis_ratio_and_phi( axis_ratio: float, phi: float, xp=np, ) -> Tuple[float, float]: """ - Returns the shapelet elliptical component parameters `shapelet_comps` from an axis-ratio `q` + Returns the shapelet elliptical component parameters `ell_comps` from an axis-ratio `q` and position angle `phi` (in degrees). This is the inverse of `shapelet_axis_ratio_and_phi_from` and uses the same fixed ellipse symmetry @@ -449,7 +449,7 @@ def shapelet_comps_from_axis_ratio_and_phi( Returns ------- - shapelet_comps + ell_comps The first and second components of the shapelet ellipticity, where the circular limit is (0.0, 0.0). """ axis_ratio = xp.clip(axis_ratio, 1e-12, 1.0) diff --git a/test_autogalaxy/test_convert.py b/test_autogalaxy/test_convert.py index a1e00f3a6..bf15b0e09 100644 --- a/test_autogalaxy/test_convert.py +++ b/test_autogalaxy/test_convert.py @@ -175,4 +175,29 @@ def test__multipole_comps_from(): def test__shapelet_axis_ratio_and_phi_from(): axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( - \ No newline at end of file + ell_comps=(0.0, 0.5) + ) + + assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) + assert phi == pytest.approx(0.0, abs=1e-4) + + axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( + ell_comps=(0.5, 0.0) + ) + + assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) + assert phi == pytest.approx(45.0, abs=1e-4) + + axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( + ell_comps=(0.0, -0.5) + ) + + assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) + assert phi == pytest.approx(90.0, abs=1e-4) + + axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( + ell_comps=(-0.5, 0.0) + ) + + assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) + assert phi == pytest.approx(-45.0, abs=1e-4) \ No newline at end of file From 9656fa044973c27ffcc90ea4dae4fd405997289c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 13:51:09 +0000 Subject: [PATCH 14/17] test__shapelet_ell_comps_from_axis_ratio_and_phi --- autogalaxy/convert.py | 2 +- test_autogalaxy/test_convert.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index d8897ab78..15e2d8e94 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -408,7 +408,7 @@ def shapelet_axis_ratio_and_phi_from( return axis_ratio, phi -def ell_comps_from_axis_ratio_and_phi( +def shapelet_ell_comps_from_axis_ratio_and_phi( axis_ratio: float, phi: float, xp=np, diff --git a/test_autogalaxy/test_convert.py b/test_autogalaxy/test_convert.py index bf15b0e09..b1ea51dee 100644 --- a/test_autogalaxy/test_convert.py +++ b/test_autogalaxy/test_convert.py @@ -200,4 +200,31 @@ def test__shapelet_axis_ratio_and_phi_from(): ) assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) - assert phi == pytest.approx(-45.0, abs=1e-4) \ No newline at end of file + assert phi == pytest.approx(-45.0, abs=1e-4) + + +def test__shapelet_ell_comps_from_axis_ratio_and_phi(): + + ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( + axis_ratio=0.3333333, phi=0.0 + ) + + assert ell_comps == pytest.approx((0.0, 0.5), abs=1e-4) + + ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( + axis_ratio=0.3333333, phi=45.0 + ) + + assert ell_comps == pytest.approx((0.5, 0.0), abs=1e-4) + + ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( + axis_ratio=0.3333333, phi=90.0 + ) + + assert ell_comps == pytest.approx((0.0, -0.5), abs=1e-4) + + ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( + axis_ratio=0.3333333, phi=-45.0 + ) + + assert ell_comps == pytest.approx((-0.5, 0.0), abs=1e-4) \ No newline at end of file From 4c7a22fbdff6b4f8d61571b44d0add4e41d32b90 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 14:35:52 +0000 Subject: [PATCH 15/17] simplify stretchihng --- autogalaxy/convert.py | 128 +----------------- .../profiles/light/linear/shapelets/polar.py | 6 +- .../light/standard/shapelets/polar.py | 100 +++++++------- .../profiles/light/shapelets/test_polar.py | 4 +- test_autogalaxy/test_convert.py | 58 -------- 5 files changed, 53 insertions(+), 243 deletions(-) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 15e2d8e94..a0d963d4c 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -336,130 +336,4 @@ def multipole_comps_from( multipole_comp_0 = k_m * xp.sin(phi_m * float(m) * units.deg.to(units.rad)) multipole_comp_1 = k_m * xp.cos(phi_m * float(m) * units.deg.to(units.rad)) - return (multipole_comp_0, multipole_comp_1) - -def shapelet_axis_ratio_and_phi_from( - ell_comps: Tuple[float, float], - xp=np, -) -> Tuple[float, float]: - """ - Returns the elliptical axis-ratio `q` and position angle `phi` (in degrees) from the shapelet - elliptical component parameters `ell_comps`. - - This conversion is intentionally identical in *spirit* to the `ell_comps` parameterization used - throughout PyAutoGalaxy: the circular case corresponds to (0.0, 0.0) at the centre of parameter space, - avoiding the sampling pathologies of directly sampling `(q, phi)` where the angle is undefined at `q=1`. - - For shapelets, these components define the *elliptical coordinate system* on which the basis functions - are evaluated. This is geometric ellipticity (an isophote ellipse), therefore we always enforce the - ellipse's 180 degree rotational symmetry by using an effective order m = 2 (and we do **not** accept - `m` as an input). - - The conversion is: - - 1) Convert components -> "ellipticity-like" amplitude and angle: - - .. math:: - e = \\sqrt{\\epsilon_1^2 + \\epsilon_2^2} - - .. math:: - \\phi = \\frac{1}{2} \\arctan2(\\epsilon_1, \\epsilon_2) - - 2) Map amplitude -> axis-ratio using the standard stable relation: - - .. math:: - e = \\frac{1 - q}{1 + q} \\;\\;\\Rightarrow\\;\\; - q = \\frac{1 - e}{1 + e} - - The returned `phi` is wrapped to prevent boundary hopping when computing marginalized error estimates - from posterior samples. The wrapping enforces a continuous interval analogous to the multipole - conversion logic, but with the fixed symmetry m = 2. - - Parameters - ---------- - ell_comps - The first and second components of the shapelet ellipticity. The circular limit is (0.0, 0.0). - These are unconstrained and can span (-inf, inf) during sampling. - xp - The array library used for the calculation (e.g. `numpy` or `jax.numpy`). - - Returns - ------- - axis_ratio - The axis-ratio of the elliptical coordinate system, with 0 < q <= 1. - phi - The position angle in degrees, measured counter-clockwise from the positive x-axis. - """ - eps_1, eps_2 = ell_comps - - # Ellipticity-like amplitude (0 at circular). Clip to keep q well-defined in (0, 1]. - e = xp.sqrt(eps_1 * eps_1 + eps_2 * eps_2) - e = xp.clip(e, 0.0, 1.0 - 1e-12) - - axis_ratio = (1.0 - e) / (1.0 + e) - - # Fixed symmetry for ellipses: m = 2. - phi = xp.arctan2(eps_1, eps_2) * 180.0 / xp.pi / 2.0 - - # Wrap phi to a continuous interval to avoid boundary hopping in posteriors. - # (Analogue of: phi_m = where(phi_m < -90/m, phi_m + 360/m, phi_m) with m=2.) - phi = xp.where(phi < -45.0, phi + 180.0, phi) - - return axis_ratio, phi - - -def shapelet_ell_comps_from_axis_ratio_and_phi( - axis_ratio: float, - phi: float, - xp=np, -) -> Tuple[float, float]: - """ - Returns the shapelet elliptical component parameters `ell_comps` from an axis-ratio `q` - and position angle `phi` (in degrees). - - This is the inverse of `shapelet_axis_ratio_and_phi_from` and uses the same fixed ellipse symmetry - (effective order m = 2). The mapping is: - - 1) Convert axis-ratio -> ellipticity-like amplitude: - - .. math:: - e = \\frac{1 - q}{1 + q} - - 2) Convert amplitude and angle -> components: - - .. math:: - \\epsilon_1 = e \\, \\sin(2 \\phi) - - .. math:: - \\epsilon_2 = e \\, \\cos(2 \\phi) - - This ensures: - - .. math:: - \\phi = \\frac{1}{2} \\arctan2(\\epsilon_1, \\epsilon_2) - - Parameters - ---------- - axis_ratio - The axis-ratio of the elliptical coordinate system, with 0 < q <= 1. - phi - The position angle in degrees, measured counter-clockwise from the positive x-axis. - xp - The array library used for the calculation (e.g. `numpy` or `jax.numpy`). - - Returns - ------- - ell_comps - The first and second components of the shapelet ellipticity, where the circular limit is (0.0, 0.0). - """ - axis_ratio = xp.clip(axis_ratio, 1e-12, 1.0) - - e = (1.0 - axis_ratio) / (1.0 + axis_ratio) - - # Fixed symmetry for ellipses: m = 2. - ang = 2.0 * phi * xp.pi / 180.0 - - eps_1 = e * xp.sin(ang) - eps_2 = e * xp.cos(ang) - - return (eps_1, eps_2) + return (multipole_comp_0, multipole_comp_1) \ No newline at end of file diff --git a/autogalaxy/profiles/light/linear/shapelets/polar.py b/autogalaxy/profiles/light/linear/shapelets/polar.py index 15f34468f..1cb57c6fd 100644 --- a/autogalaxy/profiles/light/linear/shapelets/polar.py +++ b/autogalaxy/profiles/light/linear/shapelets/polar.py @@ -11,8 +11,6 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), - q: float = 1.0, - phi: float = 0.0, ell_comps: Tuple[float, float] = (0.0, 0.0), beta: float = 1.0, ): @@ -48,7 +46,7 @@ def __init__( """ super().__init__( - n=n, m=m, centre=centre, q=q, phi=phi, beta=beta, intensity=1.0 + n=n, m=m, centre=centre, ell_comps=ell_comps, beta=beta, intensity=1.0 ) @@ -87,4 +85,4 @@ def __init__( The characteristic length scale of the shapelet basis function, defined in arc-seconds. """ - super().__init__(n=n, m=m, centre=centre, q=1.0, phi=phi, beta=beta) + super().__init__(n=n, m=m, centre=centre, beta=beta) diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 96b523173..ba60d2e74 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -6,11 +6,13 @@ from autogalaxy.profiles.light.decorators import ( check_operated_only, ) +from autogalaxy import convert from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet + def genlaguerre_jax(n, alpha, x): """ - Generalized (associated) Laguerre polynomial L_n^alpha(x) + Generalized (associated) Laguerre polynomial L_n^alpha(x) calculated using the explicit summation formula, optimized for JAX vectorization. Parameters: @@ -23,38 +25,38 @@ def genlaguerre_jax(n, alpha, x): # 0. Input Validation (Requires static Python int n) if not isinstance(n, int) or n < 0: - # Use Python's math.isnan/isinf check if n is float, otherwise type error - raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") - + # Use Python's math.isnan/isinf check if n is float, otherwise type error + raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + # Base Case L0 if n == 0: return jnp.ones_like(x) # 1. Generate k values for summation range [0, 1, 2, ..., n] - k_values = jnp.arange(n + 1) # (n+1,) + k_values = jnp.arange(n + 1) # (n+1,) # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) - x_expanded = jnp.expand_dims(x, axis=-1) + x_expanded = jnp.expand_dims(x, axis=-1) k_values_expanded = jnp.expand_dims(k_values, axis=0) # --- A. Binomial Factor (BF) Calculation --- # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) - + log_N_plus_alpha_fact = gammaln(n + alpha + 1) - + log_BF_k = ( - log_N_plus_alpha_fact - - gammaln(n - k_values + 1) # log( (n-k)! ) - - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) + log_N_plus_alpha_fact + - gammaln(n - k_values + 1) # log( (n-k)! ) + - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) ) - - BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) + + BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) # --- B. Term Factor (TF) Calculation --- # TF = (-x)^k / k! - + # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space - TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) + TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) # TF_k Shape: (M, n+1) # --- C. Final Summation --- @@ -62,17 +64,16 @@ def genlaguerre_jax(n, alpha, x): # BF_k broadcasts over the M dimension of TF_k return jnp.sum(BF_k * TF_k, axis=1) + class ShapeletPolar(AbstractShapelet): def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - q: float = 1.0, - phi: float = 0.0, - ell_comps: Tuple[float, float] = (0.0, 0.0), - intensity: float = 1.0, - beta: float = 1.0, + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + ell_comps: Tuple[float, float] = (0.0, 0.0), + intensity: float = 1.0, + beta: float = 1.0, ): """ Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. @@ -96,7 +97,7 @@ def __init__( q The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data @@ -107,11 +108,9 @@ def __init__( self.n = int(n) self.m = int(m) - self.phi = float(phi) - self.q = float(q) super().__init__( - centre=centre, beta=beta, intensity=intensity + centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity ) @property @@ -121,6 +120,7 @@ def coefficient_tag(self) -> str: @aa.over_sample @aa.grid_dec.to_array @check_operated_only + @aa.grid_dec.transform def image_2d_from( self, grid: aa.type.Grid2DLike, @@ -152,22 +152,20 @@ def image_2d_from( from jax.scipy.special import factorial - grid = aa.util.geometry.transform_grid_2d_to_reference_frame( - grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp - ) - grid = aa.Grid2DIrregular(values=grid) - const = ( - ((-1) ** ((self.n - xp.abs(self.m)) // 2)) - * xp.sqrt( - factorial((self.n - xp.abs(self.m)) // 2) - / factorial((self.n + xp.abs(self.m)) // 2) - ) - / self.beta - / xp.sqrt(xp.pi) + ((-1) ** ((self.n - xp.abs(self.m)) // 2)) + * xp.sqrt( + factorial((self.n - xp.abs(self.m)) // 2) + / factorial((self.n + xp.abs(self.m)) // 2) + ) + / self.beta + / xp.sqrt(xp.pi) ) - rsq = (grid.array[:, 0] ** 2 + (grid.array[:, 1]/self.q) ** 2) / self.beta**2 - theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) + y = grid.array[:, 0] + x = grid.array[:, 1] + + rsq = (x ** 2 + (y / self.axis_ratio(xp)) ** 2) / self.beta ** 2 + theta = xp.arctan2(y, x) m_abs = abs(self.m) n_laguerre = (self.n - m_abs) // 2 @@ -199,15 +197,15 @@ def image_2d_from( return self._intensity * const * radial * azimuthal + class ShapeletPolarSph(ShapeletPolar): def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - phi: float = 0.0, - intensity: float = 1.0, - beta: float = 1.0, + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + intensity: float = 1.0, + beta: float = 1.0, ): """ Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. @@ -229,7 +227,7 @@ def __init__( centre The (y,x) arc-second coordinates of the profile (shapelet) centre. phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data @@ -242,8 +240,6 @@ def __init__( n=n, m=m, centre=centre, - q=1.0, - phi=phi, intensity=intensity, beta=beta, ) diff --git a/test_autogalaxy/profiles/light/shapelets/test_polar.py b/test_autogalaxy/profiles/light/shapelets/test_polar.py index eef873e91..90a90500e 100644 --- a/test_autogalaxy/profiles/light/shapelets/test_polar.py +++ b/test_autogalaxy/profiles/light/shapelets/test_polar.py @@ -26,7 +26,7 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.0, -0.33177]), abs=1e-4) + assert image == pytest.approx(np.array([0.02577349206014249, -0.17434262753]), abs=1e-4) shapelet = ag.lp_linear.ShapeletPolar( n=2, m=0, centre=(0.0, 0.0), ell_comps=(0.5, 0.7), beta=1.0 @@ -34,4 +34,4 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.0, -0.33177]), abs=1e-4) + assert image == pytest.approx(np.array([0.001538188813, 0.0]), abs=1e-4) diff --git a/test_autogalaxy/test_convert.py b/test_autogalaxy/test_convert.py index b1ea51dee..5250864e0 100644 --- a/test_autogalaxy/test_convert.py +++ b/test_autogalaxy/test_convert.py @@ -170,61 +170,3 @@ def test__multipole_comps_from(): multipole_comps = ag.convert.multipole_comps_from(k_m=0.14142135, phi_m=112.5, m=2) assert multipole_comps == pytest.approx((-0.1, -0.1), abs=1e-3) - - -def test__shapelet_axis_ratio_and_phi_from(): - - axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( - ell_comps=(0.0, 0.5) - ) - - assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) - assert phi == pytest.approx(0.0, abs=1e-4) - - axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( - ell_comps=(0.5, 0.0) - ) - - assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) - assert phi == pytest.approx(45.0, abs=1e-4) - - axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( - ell_comps=(0.0, -0.5) - ) - - assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) - assert phi == pytest.approx(90.0, abs=1e-4) - - axis_ratio, phi = ag.convert.shapelet_axis_ratio_and_phi_from( - ell_comps=(-0.5, 0.0) - ) - - assert axis_ratio == pytest.approx(0.3333333, abs=1e-4) - assert phi == pytest.approx(-45.0, abs=1e-4) - - -def test__shapelet_ell_comps_from_axis_ratio_and_phi(): - - ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( - axis_ratio=0.3333333, phi=0.0 - ) - - assert ell_comps == pytest.approx((0.0, 0.5), abs=1e-4) - - ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( - axis_ratio=0.3333333, phi=45.0 - ) - - assert ell_comps == pytest.approx((0.5, 0.0), abs=1e-4) - - ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( - axis_ratio=0.3333333, phi=90.0 - ) - - assert ell_comps == pytest.approx((0.0, -0.5), abs=1e-4) - - ell_comps = ag.convert.shapelet_ell_comps_from_axis_ratio_and_phi( - axis_ratio=0.3333333, phi=-45.0 - ) - - assert ell_comps == pytest.approx((-0.5, 0.0), abs=1e-4) \ No newline at end of file From ddc41e6c6df9678a8bdb5c3408fb003c251ded7e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 14:44:46 +0000 Subject: [PATCH 16/17] fix shapelet cartesian --- .../light/standard/shapelets/cartesian.py | 95 +++++++++++-------- .../light/standard/shapelets/polar.py | 4 +- .../light/shapelets/test_cartesian.py | 4 +- 3 files changed, 61 insertions(+), 42 deletions(-) diff --git a/autogalaxy/profiles/light/standard/shapelets/cartesian.py b/autogalaxy/profiles/light/standard/shapelets/cartesian.py index 9be1b123c..906891e31 100644 --- a/autogalaxy/profiles/light/standard/shapelets/cartesian.py +++ b/autogalaxy/profiles/light/standard/shapelets/cartesian.py @@ -8,6 +8,32 @@ ) from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet +def hermite_phys(n: int, x, xp=np): + """ + Physicists' Hermite polynomial H_n(x), compatible with NumPy and JAX via `xp`. + + Recurrence: + H_0(x) = 1 + H_1(x) = 2x + H_{n+1}(x) = 2x H_n(x) - 2n H_{n-1}(x) + """ + if n < 0: + raise ValueError("n must be >= 0") + + H0 = xp.ones_like(x) + if n == 0: + return H0 + + H1 = 2.0 * x + if n == 1: + return H1 + + Hnm1 = H0 + Hn = H1 + for k in range(1, n): + Hnp1 = 2.0 * x * Hn - 2.0 * float(k) * Hnm1 + Hnm1, Hn = Hn, Hnp1 + return Hn class ShapeletCartesian(AbstractShapelet): def __init__( @@ -63,54 +89,47 @@ def coefficient_tag(self) -> str: @check_operated_only @aa.grid_dec.transform def image_2d_from( - self, - grid: aa.type.Grid2DLike, - xp=np, - operated_only: Optional[bool] = None, - **kwargs, + self, + grid: aa.type.Grid2DLike, + xp=np, + operated_only: Optional[bool] = None, + **kwargs, ) -> np.ndarray: """ Returns the Cartesian Shapelet light profile's 2D image from a 2D grid of Cartesian (y,x) coordinates. - - If the coordinates have not been transformed to the profile's geometry (e.g. translated to the - profile `centre`), this is performed automatically. - - Parameters - ---------- - grid - The 2D (y, x) coordinates in the original reference frame of the grid. - - Returns - ------- - image - The image of the Cartesian Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - from jax.scipy.special import factorial - from scipy.special import hermite - hermite_y = hermite(n=self.n_y) - hermite_x = hermite(n=self.n_x) + # factorial backend switch + if xp is np: + from scipy.special import factorial + else: + from jax.scipy.special import factorial y = grid.array[:, 0] x = grid.array[:, 1] - shapelet_y = hermite_y(y / self.beta) - shapelet_x = hermite_x(x / self.beta) - - return ( - shapelet_y - * shapelet_x - * xp.exp(-0.5 * (y**2 + x**2) / (self.beta**2)) - / self.beta - / ( - xp.sqrt( - 2 ** (self.n_x + self.n_y) - * (xp.pi) - * factorial(self.n_y) - * factorial(self.n_x) - ) - ) + # Apply axis-ratio stretching (minor axis) + q = self.axis_ratio(xp) + y_ell = y / q + x_ell = x + + # Evaluate Hermite polynomials (JAX-safe) + shapelet_y = hermite_phys(self.n_y, y_ell / self.beta, xp=xp) + shapelet_x = hermite_phys(self.n_x, x_ell / self.beta, xp=xp) + + gaussian = xp.exp(-0.5 * (x_ell ** 2 + y_ell ** 2) / (self.beta ** 2)) + + norm = ( + self.beta + * xp.sqrt( + (2.0 ** (self.n_x + self.n_y)) + * xp.pi + * factorial(self.n_y) + * factorial(self.n_x) ) + ) + + return (shapelet_y * shapelet_x * gaussian) / norm class ShapeletCartesianSph(ShapeletCartesian): diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index ba60d2e74..a88a2196c 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -155,8 +155,8 @@ def image_2d_from( const = ( ((-1) ** ((self.n - xp.abs(self.m)) // 2)) * xp.sqrt( - factorial((self.n - xp.abs(self.m)) // 2) - / factorial((self.n + xp.abs(self.m)) // 2) + factorial((self.n - xp.abs(self.m)) // 2) + / factorial((self.n + xp.abs(self.m)) // 2) ) / self.beta / xp.sqrt(xp.pi) diff --git a/test_autogalaxy/profiles/light/shapelets/test_cartesian.py b/test_autogalaxy/profiles/light/shapelets/test_cartesian.py index ad4dbf2f0..b6bb9b3fe 100644 --- a/test_autogalaxy/profiles/light/shapelets/test_cartesian.py +++ b/test_autogalaxy/profiles/light/shapelets/test_cartesian.py @@ -30,7 +30,7 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.13444, 0.122273]), 1e-4) + assert image == pytest.approx(np.array([0.1066423886714124, 0.014346163370]), 1e-4) shapelet = ag.lp_linear.ShapeletCartesian( n_y=2, n_x=3, centre=(0.0, 0.0), ell_comps=(0.2, 0.3), beta=1.0 @@ -38,4 +38,4 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.12993, 0.13719]), 1e-4) + assert image == pytest.approx(np.array([0.0322749900, -0.075487038]), 1e-4) From 4121217d67458ec36366686d1eb341393affe562 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 17 Jan 2026 14:52:11 +0000 Subject: [PATCH 17/17] final cartesian fix --- autogalaxy/convert.py | 2 +- .../profiles/light/linear/shapelets/polar.py | 4 +- .../light/standard/shapelets/cartesian.py | 21 ++++--- .../light/standard/shapelets/polar.py | 56 ++++++++++--------- .../mass/total/dual_pseudo_isothermal_mass.py | 50 +++++++++++------ .../profiles/light/shapelets/test_polar.py | 4 +- 6 files changed, 79 insertions(+), 58 deletions(-) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index a0d963d4c..5bda71120 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -336,4 +336,4 @@ def multipole_comps_from( multipole_comp_0 = k_m * xp.sin(phi_m * float(m) * units.deg.to(units.rad)) multipole_comp_1 = k_m * xp.cos(phi_m * float(m) * units.deg.to(units.rad)) - return (multipole_comp_0, multipole_comp_1) \ No newline at end of file + return (multipole_comp_0, multipole_comp_1) diff --git a/autogalaxy/profiles/light/linear/shapelets/polar.py b/autogalaxy/profiles/light/linear/shapelets/polar.py index 1cb57c6fd..1898ceaf2 100644 --- a/autogalaxy/profiles/light/linear/shapelets/polar.py +++ b/autogalaxy/profiles/light/linear/shapelets/polar.py @@ -36,7 +36,7 @@ def __init__( q The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data @@ -79,7 +79,7 @@ def __init__( centre The (y,x) arc-second coordinates of the profile (shapelet) centre. phi - The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the positive x-axis. beta The characteristic length scale of the shapelet basis function, defined in arc-seconds. diff --git a/autogalaxy/profiles/light/standard/shapelets/cartesian.py b/autogalaxy/profiles/light/standard/shapelets/cartesian.py index 906891e31..95b7c8d8f 100644 --- a/autogalaxy/profiles/light/standard/shapelets/cartesian.py +++ b/autogalaxy/profiles/light/standard/shapelets/cartesian.py @@ -8,6 +8,7 @@ ) from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet + def hermite_phys(n: int, x, xp=np): """ Physicists' Hermite polynomial H_n(x), compatible with NumPy and JAX via `xp`. @@ -35,6 +36,7 @@ def hermite_phys(n: int, x, xp=np): Hnm1, Hn = Hn, Hnp1 return Hn + class ShapeletCartesian(AbstractShapelet): def __init__( self, @@ -89,11 +91,11 @@ def coefficient_tag(self) -> str: @check_operated_only @aa.grid_dec.transform def image_2d_from( - self, - grid: aa.type.Grid2DLike, - xp=np, - operated_only: Optional[bool] = None, - **kwargs, + self, + grid: aa.type.Grid2DLike, + xp=np, + operated_only: Optional[bool] = None, + **kwargs, ) -> np.ndarray: """ Returns the Cartesian Shapelet light profile's 2D image from a 2D grid of Cartesian (y,x) coordinates. @@ -117,19 +119,16 @@ def image_2d_from( shapelet_y = hermite_phys(self.n_y, y_ell / self.beta, xp=xp) shapelet_x = hermite_phys(self.n_x, x_ell / self.beta, xp=xp) - gaussian = xp.exp(-0.5 * (x_ell ** 2 + y_ell ** 2) / (self.beta ** 2)) + gaussian = xp.exp(-0.5 * (x_ell**2 + y_ell**2) / (self.beta**2)) - norm = ( - self.beta - * xp.sqrt( + norm = self.beta * xp.sqrt( (2.0 ** (self.n_x + self.n_y)) * xp.pi * factorial(self.n_y) * factorial(self.n_x) ) - ) - return (shapelet_y * shapelet_x * gaussian) / norm + return self._intensity * (shapelet_y * shapelet_x * gaussian) / norm class ShapeletCartesianSph(ShapeletCartesian): diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index a88a2196c..e91501a15 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -26,7 +26,9 @@ def genlaguerre_jax(n, alpha, x): # 0. Input Validation (Requires static Python int n) if not isinstance(n, int) or n < 0: # Use Python's math.isnan/isinf check if n is float, otherwise type error - raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") + raise ValueError( + f"Degree n must be a non-negative Python integer (static), got {n}." + ) # Base Case L0 if n == 0: @@ -45,9 +47,9 @@ def genlaguerre_jax(n, alpha, x): log_N_plus_alpha_fact = gammaln(n + alpha + 1) log_BF_k = ( - log_N_plus_alpha_fact - - gammaln(n - k_values + 1) # log( (n-k)! ) - - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) + log_N_plus_alpha_fact + - gammaln(n - k_values + 1) # log( (n-k)! ) + - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) ) BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) @@ -56,7 +58,9 @@ def genlaguerre_jax(n, alpha, x): # TF = (-x)^k / k! # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space - TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) + TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp( + gammaln(k_values_expanded + 1) + ) # TF_k Shape: (M, n+1) # --- C. Final Summation --- @@ -67,13 +71,13 @@ def genlaguerre_jax(n, alpha, x): class ShapeletPolar(AbstractShapelet): def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - ell_comps: Tuple[float, float] = (0.0, 0.0), - intensity: float = 1.0, - beta: float = 1.0, + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + ell_comps: Tuple[float, float] = (0.0, 0.0), + intensity: float = 1.0, + beta: float = 1.0, ): """ Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. @@ -153,18 +157,18 @@ def image_2d_from( from jax.scipy.special import factorial const = ( - ((-1) ** ((self.n - xp.abs(self.m)) // 2)) - * xp.sqrt( + ((-1) ** ((self.n - xp.abs(self.m)) // 2)) + * xp.sqrt( factorial((self.n - xp.abs(self.m)) // 2) / factorial((self.n + xp.abs(self.m)) // 2) - ) - / self.beta - / xp.sqrt(xp.pi) + ) + / self.beta + / xp.sqrt(xp.pi) ) y = grid.array[:, 0] x = grid.array[:, 1] - rsq = (x ** 2 + (y / self.axis_ratio(xp)) ** 2) / self.beta ** 2 + rsq = (x**2 + (y / self.axis_ratio(xp)) ** 2) / self.beta**2 theta = xp.arctan2(y, x) m_abs = abs(self.m) @@ -174,7 +178,9 @@ def image_2d_from( from scipy.special import genlaguerre - laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + laguerre = genlaguerre( + n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m) + ) laguerre_vals = laguerre(rsq) else: @@ -200,12 +206,12 @@ def image_2d_from( class ShapeletPolarSph(ShapeletPolar): def __init__( - self, - n: int, - m: int, - centre: Tuple[float, float] = (0.0, 0.0), - intensity: float = 1.0, - beta: float = 1.0, + self, + n: int, + m: int, + centre: Tuple[float, float] = (0.0, 0.0), + intensity: float = 1.0, + beta: float = 1.0, ): """ Shapelets where the basis function is defined according to a Polar (r,theta) grid of coordinates. diff --git a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py index 8a13266f3..a981526ec 100644 --- a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py +++ b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py @@ -275,7 +275,9 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) factor = self.b0 - zis = _ci05(x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, xp=xp) + zis = _ci05( + x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, xp=xp + ) # This is in axes aligned to the major/minor axis deflection_x = zis.real @@ -287,17 +289,13 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): xp=xp, **kwargs, ) - + def _convergence(self, radii, xp=np): radsq = radii * radii a = self.ra - return ( - self.b0 - / 2 - * (1 / xp.sqrt(a**2 + radsq)) - ) + return self.b0 / 2 * (1 / xp.sqrt(a**2 + radsq)) @aa.grid_dec.to_array @aa.grid_dec.transform @@ -316,14 +314,14 @@ def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) grid_radii = xp.sqrt( - grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 ) # Compute the convergence and deflection of a *circular* profile - kappa = self._convergence(grid_radii,xp) + kappa = self._convergence(grid_radii, xp) return kappa - @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): """ @@ -340,7 +338,12 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs ellip = self._ellip() hessian_xx, hessian_xy, hessian_yx, hessian_yy = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0, xp=xp + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.ra, + b0=self.b0, + xp=xp, ) return hessian_yy, hessian_xy, hessian_yx, hessian_xx @@ -433,7 +436,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): eps=ellip, rcore=self.ra, rcut=self.rs, - xp=xp + xp=xp, ) # This is in axes aligned to the major/minor axis @@ -476,11 +479,12 @@ def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) grid_radii = xp.sqrt( - grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 ) - kappa = self._convergence(grid_radii,xp) + kappa = self._convergence(grid_radii, xp) return kappa - + @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): """ @@ -499,10 +503,20 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs t05 = self.rs / (self.rs - self.ra) g05c_a, g05c_b, g05c_c, g05c_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0, xp=xp + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.ra, + b0=self.b0, + xp=xp, ) g05cut_a, g05cut_b, g05cut_c, g05cut_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.rs, b0=self.b0, xp=xp + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.rs, + b0=self.b0, + xp=xp, ) # Compute Hessian matrix components @@ -619,7 +633,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): xp=xp, **kwargs, ) - + @aa.grid_dec.to_array @aa.grid_dec.transform def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): diff --git a/test_autogalaxy/profiles/light/shapelets/test_polar.py b/test_autogalaxy/profiles/light/shapelets/test_polar.py index 90a90500e..f4f754514 100644 --- a/test_autogalaxy/profiles/light/shapelets/test_polar.py +++ b/test_autogalaxy/profiles/light/shapelets/test_polar.py @@ -26,7 +26,9 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.02577349206014249, -0.17434262753]), abs=1e-4) + assert image == pytest.approx( + np.array([0.02577349206014249, -0.17434262753]), abs=1e-4 + ) shapelet = ag.lp_linear.ShapeletPolar( n=2, m=0, centre=(0.0, 0.0), ell_comps=(0.5, 0.7), beta=1.0