diff --git a/autogalaxy/profiles/light/linear/shapelets/polar.py b/autogalaxy/profiles/light/linear/shapelets/polar.py index f3c34e7fa..1898ceaf2 100644 --- a/autogalaxy/profiles/light/linear/shapelets/polar.py +++ b/autogalaxy/profiles/light/linear/shapelets/polar.py @@ -33,8 +33,11 @@ def __init__( The m order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. - ell_comps - The first and second ellipticity components of the elliptical coordinate system. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -53,6 +56,7 @@ def __init__( n: int, m: int, centre: Tuple[float, float] = (0.0, 0.0), + phi: float = 0.0, beta: float = 1.0, ): """ @@ -74,8 +78,11 @@ def __init__( The order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. beta The characteristic length scale of the shapelet basis function, defined in arc-seconds. """ - super().__init__(n=n, m=m, centre=centre, ell_comps=(0.0, 0.0), beta=beta) + super().__init__(n=n, m=m, centre=centre, beta=beta) diff --git a/autogalaxy/profiles/light/standard/shapelets/cartesian.py b/autogalaxy/profiles/light/standard/shapelets/cartesian.py index 9be1b123c..95b7c8d8f 100644 --- a/autogalaxy/profiles/light/standard/shapelets/cartesian.py +++ b/autogalaxy/profiles/light/standard/shapelets/cartesian.py @@ -9,6 +9,34 @@ from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet +def hermite_phys(n: int, x, xp=np): + """ + Physicists' Hermite polynomial H_n(x), compatible with NumPy and JAX via `xp`. + + Recurrence: + H_0(x) = 1 + H_1(x) = 2x + H_{n+1}(x) = 2x H_n(x) - 2n H_{n-1}(x) + """ + if n < 0: + raise ValueError("n must be >= 0") + + H0 = xp.ones_like(x) + if n == 0: + return H0 + + H1 = 2.0 * x + if n == 1: + return H1 + + Hnm1 = H0 + Hn = H1 + for k in range(1, n): + Hnp1 = 2.0 * x * Hn - 2.0 * float(k) * Hnm1 + Hnm1, Hn = Hn, Hnp1 + return Hn + + class ShapeletCartesian(AbstractShapelet): def __init__( self, @@ -71,47 +99,37 @@ def image_2d_from( ) -> np.ndarray: """ Returns the Cartesian Shapelet light profile's 2D image from a 2D grid of Cartesian (y,x) coordinates. - - If the coordinates have not been transformed to the profile's geometry (e.g. translated to the - profile `centre`), this is performed automatically. - - Parameters - ---------- - grid - The 2D (y, x) coordinates in the original reference frame of the grid. - - Returns - ------- - image - The image of the Cartesian Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - from jax.scipy.special import factorial - from scipy.special import hermite - hermite_y = hermite(n=self.n_y) - hermite_x = hermite(n=self.n_x) + # factorial backend switch + if xp is np: + from scipy.special import factorial + else: + from jax.scipy.special import factorial y = grid.array[:, 0] x = grid.array[:, 1] - shapelet_y = hermite_y(y / self.beta) - shapelet_x = hermite_x(x / self.beta) - - return ( - shapelet_y - * shapelet_x - * xp.exp(-0.5 * (y**2 + x**2) / (self.beta**2)) - / self.beta - / ( - xp.sqrt( - 2 ** (self.n_x + self.n_y) - * (xp.pi) - * factorial(self.n_y) - * factorial(self.n_x) - ) - ) + # Apply axis-ratio stretching (minor axis) + q = self.axis_ratio(xp) + y_ell = y / q + x_ell = x + + # Evaluate Hermite polynomials (JAX-safe) + shapelet_y = hermite_phys(self.n_y, y_ell / self.beta, xp=xp) + shapelet_x = hermite_phys(self.n_x, x_ell / self.beta, xp=xp) + + gaussian = xp.exp(-0.5 * (x_ell**2 + y_ell**2) / (self.beta**2)) + + norm = self.beta * xp.sqrt( + (2.0 ** (self.n_x + self.n_y)) + * xp.pi + * factorial(self.n_y) + * factorial(self.n_x) ) + return self._intensity * (shapelet_y * shapelet_x * gaussian) / norm + class ShapeletCartesianSph(ShapeletCartesian): def __init__( diff --git a/autogalaxy/profiles/light/standard/shapelets/polar.py b/autogalaxy/profiles/light/standard/shapelets/polar.py index 1d6c98db1..e91501a15 100644 --- a/autogalaxy/profiles/light/standard/shapelets/polar.py +++ b/autogalaxy/profiles/light/standard/shapelets/polar.py @@ -3,13 +3,72 @@ import autoarray as aa - from autogalaxy.profiles.light.decorators import ( check_operated_only, ) +from autogalaxy import convert from autogalaxy.profiles.light.standard.shapelets.abstract import AbstractShapelet +def genlaguerre_jax(n, alpha, x): + """ + Generalized (associated) Laguerre polynomial L_n^alpha(x) + calculated using the explicit summation formula, optimized for JAX vectorization. + + Parameters: + n (int): Degree of the polynomial (static Python integer). + alpha (Numeric): Parameter alpha > -1. + x (Array): Input array (evaluation points). + """ + import jax.numpy as jnp + from jax.scipy.special import gammaln + + # 0. Input Validation (Requires static Python int n) + if not isinstance(n, int) or n < 0: + # Use Python's math.isnan/isinf check if n is float, otherwise type error + raise ValueError( + f"Degree n must be a non-negative Python integer (static), got {n}." + ) + + # Base Case L0 + if n == 0: + return jnp.ones_like(x) + + # 1. Generate k values for summation range [0, 1, 2, ..., n] + k_values = jnp.arange(n + 1) # (n+1,) + + # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) + x_expanded = jnp.expand_dims(x, axis=-1) + k_values_expanded = jnp.expand_dims(k_values, axis=0) + + # --- A. Binomial Factor (BF) Calculation --- + # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) + + log_N_plus_alpha_fact = gammaln(n + alpha + 1) + + log_BF_k = ( + log_N_plus_alpha_fact + - gammaln(n - k_values + 1) # log( (n-k)! ) + - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) + ) + + BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) + + # --- B. Term Factor (TF) Calculation --- + # TF = (-x)^k / k! + + # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space + TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp( + gammaln(k_values_expanded + 1) + ) + # TF_k Shape: (M, n+1) + + # --- C. Final Summation --- + # Sum over the last axis (axis=1), which corresponds to k + # BF_k broadcasts over the M dimension of TF_k + return jnp.sum(BF_k * TF_k, axis=1) + + class ShapeletPolar(AbstractShapelet): def __init__( self, @@ -39,8 +98,11 @@ def __init__( The m order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. - ell_comps - The first and second ellipticity components of the elliptical coordinate system. + q + The axis-ratio of the elliptical coordinate system, where a perfect circle has q=1.0. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -48,8 +110,8 @@ def __init__( The characteristic length scale of the shapelet basis function, defined in arc-seconds. """ - self.n = n - self.m = m + self.n = int(n) + self.m = int(m) super().__init__( centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity @@ -86,10 +148,13 @@ def image_2d_from( image The image of the Polar Shapelet evaluated at every (y,x) coordinate on the transformed grid. """ - from scipy.special import genlaguerre - from jax.scipy.special import factorial + if xp is np: - laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) + from scipy.special import factorial + + else: + + from jax.scipy.special import factorial const = ( ((-1) ** ((self.n - xp.abs(self.m)) // 2)) @@ -100,19 +165,43 @@ def image_2d_from( / self.beta / xp.sqrt(xp.pi) ) + y = grid.array[:, 0] + x = grid.array[:, 1] + + rsq = (x**2 + (y / self.axis_ratio(xp)) ** 2) / self.beta**2 + theta = xp.arctan2(y, x) - rsq = (grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2) / self.beta**2 - theta = xp.arctan2(grid.array[:, 1], grid.array[:, 0]) - radial = rsq ** (abs(self.m / 2.0)) * xp.exp(-rsq / 2.0) * laguerre(rsq) + m_abs = abs(self.m) + n_laguerre = (self.n - m_abs) // 2 + + if xp is np: + + from scipy.special import genlaguerre + + laguerre = genlaguerre( + n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m) + ) + laguerre_vals = laguerre(rsq) - if self.m == 0: - azimuthal = 1 - elif self.m > 0: - azimuthal = xp.sin((-1) * self.m * theta) else: - azimuthal = xp.cos((-1) * self.m * theta) - return const * radial * azimuthal + laguerre_vals = genlaguerre_jax(n=n_laguerre, alpha=m_abs, x=rsq) + + radial = rsq ** (xp.abs(self.m) / 2.0) * xp.exp(-rsq / 2.0) * laguerre_vals + + m = self.m + + azimuthal = xp.where( + m == 0, + xp.ones_like(theta), + xp.where( + m > 0, + xp.sin(-m * theta), + xp.cos(-m * theta), + ), + ) + + return self._intensity * const * radial * azimuthal class ShapeletPolarSph(ShapeletPolar): @@ -143,6 +232,9 @@ def __init__( The order of the shapelets basis function in the x-direction. centre The (y,x) arc-second coordinates of the profile (shapelet) centre. + phi + The position angle (in degrees) of the elliptical coordinate system, measured counter-clockwise from the + positive x-axis. intensity Overall intensity normalisation of the light profile (units are dimensionless and derived from the data the light profile's image is compared too, which is expected to be electrons per second). @@ -154,7 +246,6 @@ def __init__( n=n, m=m, centre=centre, - ell_comps=(0.0, 0.0), intensity=intensity, beta=beta, ) diff --git a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py index bb6f98d46..a981526ec 100644 --- a/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py +++ b/autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py @@ -275,7 +275,9 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) factor = self.b0 - zis = _ci05(x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra) + zis = _ci05( + x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, xp=xp + ) # This is in axes aligned to the major/minor axis deflection_x = zis.real @@ -288,6 +290,38 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): **kwargs, ) + def _convergence(self, radii, xp=np): + + radsq = radii * radii + a = self.ra + + return self.b0 / 2 * (1 / xp.sqrt(a**2 + radsq)) + + @aa.grid_dec.to_array + @aa.grid_dec.transform + def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): + """ + Returns the two-dimensional projected convergence on a grid of (y,x) + arc-second coordinates. + + The `grid_2d_to_structure` decorator reshapes the ndarrays the convergence + is outputted on. See *aa.grid_2d_to_structure* for details. + + Parameters + ---------- + grid + The grid of (y,x) arc-second coordinates on which the convergence is computed. + """ + ellip = self._ellip(xp) + grid_radii = xp.sqrt( + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 + ) + # Compute the convergence and deflection of a *circular* profile + kappa = self._convergence(grid_radii, xp) + + return kappa + @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): """ @@ -304,7 +338,12 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs ellip = self._ellip() hessian_xx, hessian_xy, hessian_yx, hessian_yy = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0 + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.ra, + b0=self.b0, + xp=xp, ) return hessian_yy, hessian_xy, hessian_yx, hessian_xx @@ -397,6 +436,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): eps=ellip, rcore=self.ra, rcut=self.rs, + xp=xp, ) # This is in axes aligned to the major/minor axis @@ -410,22 +450,6 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): **kwargs, ) - def _deflection_angle(self, radii, xp=np): - """ - For a circularly symmetric dPIEPotential profile, computes the magnitude of the deflection at each radius. - """ - a, s = self.ra, self.rs - radii = xp.maximum(radii, 1e-8) - f = radii / (a + xp.sqrt(a**2 + radii**2)) - radii / ( - s + xp.sqrt(s**2 + radii**2) - ) - - # c.f. Eliasdottir '07 eq. A23 - # magnitude of deflection - # alpha = self.E0 * (s + a) / s * f - alpha = self.b0 * s / (s - a) * f - return alpha - def _convergence(self, radii, xp=np): radsq = radii * radii @@ -439,7 +463,7 @@ def _convergence(self, radii, xp=np): * (1 / xp.sqrt(a**2 + radsq) - 1 / xp.sqrt(s**2 + radsq)) ) - @aa.grid_dec.to_vector_yx + @aa.grid_dec.to_array @aa.grid_dec.transform def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ @@ -455,22 +479,11 @@ def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): """ ellip = self._ellip(xp) grid_radii = xp.sqrt( - grid.array[:, 1] ** 2 * (1 - ellip) + grid.array[:, 0] ** 2 * (1 + ellip) + grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 ) - - # Compute the convergence and deflection of a *circular* profile - kappa_circ = self._convergence(grid_radii, xp) - alpha_circ = self._deflection_angle(grid_radii, xp) - - asymm_term = ( - ellip * (1 - ellip) * grid.array[:, 1] ** 2 - - ellip * (1 + ellip) * grid.array[:, 0] ** 2 - ) / grid_radii**2 - - # convergence = 1/2 \nabla \alpha = 1/2 \nabla^2 potential - # The "asymm_term" is asymmetric on x and y, so averages out to - # zero over all space - return kappa_circ * (1 - asymm_term) + (alpha_circ / grid_radii) * asymm_term + kappa = self._convergence(grid_radii, xp) + return kappa @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): @@ -490,10 +503,20 @@ def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs t05 = self.rs / (self.rs - self.ra) g05c_a, g05c_b, g05c_c, g05c_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.ra, b0=self.b0 + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.ra, + b0=self.b0, + xp=xp, ) g05cut_a, g05cut_b, g05cut_c, g05cut_d = _mdci05( - x=grid.array[:, 1], y=grid.array[:, 0], eps=ellip, rcore=self.rs, b0=self.b0 + x=grid.array[:, 1], + y=grid.array[:, 0], + eps=ellip, + rcore=self.rs, + b0=self.b0, + xp=xp, ) # Compute Hessian matrix components @@ -611,6 +634,25 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): **kwargs, ) + @aa.grid_dec.to_array + @aa.grid_dec.transform + def convergence_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): + """ + Returns the two dimensional projected convergence on a grid of (y,x) arc-second coordinates. + + The `grid_2d_to_structure` decorator reshapes the ndarrays the convergence is outputted on. See + *aa.grid_2d_to_structure* for a description of the output. + + Parameters + ---------- + grid + The grid of (y,x) arc-second coordinates the convergence is computed on. + """ + # already transformed to center on profile centre so this works + radsq = grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2 + + return self._convergence(xp.sqrt(radsq), xp) + @aa.grid_dec.transform def analytical_hessian_2d_from(self, grid: "aa.type.Grid2DLike", xp=np, **kwargs): """ diff --git a/test_autogalaxy/profiles/light/shapelets/test_cartesian.py b/test_autogalaxy/profiles/light/shapelets/test_cartesian.py index ad4dbf2f0..b6bb9b3fe 100644 --- a/test_autogalaxy/profiles/light/shapelets/test_cartesian.py +++ b/test_autogalaxy/profiles/light/shapelets/test_cartesian.py @@ -30,7 +30,7 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.13444, 0.122273]), 1e-4) + assert image == pytest.approx(np.array([0.1066423886714124, 0.014346163370]), 1e-4) shapelet = ag.lp_linear.ShapeletCartesian( n_y=2, n_x=3, centre=(0.0, 0.0), ell_comps=(0.2, 0.3), beta=1.0 @@ -38,4 +38,4 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.12993, 0.13719]), 1e-4) + assert image == pytest.approx(np.array([0.0322749900, -0.075487038]), 1e-4) diff --git a/test_autogalaxy/profiles/light/shapelets/test_polar.py b/test_autogalaxy/profiles/light/shapelets/test_polar.py index eef873e91..f4f754514 100644 --- a/test_autogalaxy/profiles/light/shapelets/test_polar.py +++ b/test_autogalaxy/profiles/light/shapelets/test_polar.py @@ -26,7 +26,9 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.0, -0.33177]), abs=1e-4) + assert image == pytest.approx( + np.array([0.02577349206014249, -0.17434262753]), abs=1e-4 + ) shapelet = ag.lp_linear.ShapeletPolar( n=2, m=0, centre=(0.0, 0.0), ell_comps=(0.5, 0.7), beta=1.0 @@ -34,4 +36,4 @@ def test__elliptical__image_2d_from(): image = shapelet.image_2d_from(grid=ag.Grid2DIrregular([[0.0, 1.0], [0.5, 0.25]])) - assert image == pytest.approx(np.array([0.0, -0.33177]), abs=1e-4) + assert image == pytest.approx(np.array([0.001538188813, 0.0]), abs=1e-4)