Skip to content

Fix/ShapeletPolar_dPIEkappa#259

Merged
Jammy2211 merged 20 commits intoPyAutoLabs:mainfrom
Chocologism:fix/ShapeletPolar_dPIEkappa
Jan 17, 2026
Merged

Fix/ShapeletPolar_dPIEkappa#259
Jammy2211 merged 20 commits intoPyAutoLabs:mainfrom
Chocologism:fix/ShapeletPolar_dPIEkappa

Conversation

@Chocologism
Copy link
Copy Markdown
Contributor

@Chocologism Chocologism commented Dec 12, 2025

This PR introduces two quick fixes:

  • For the mass profiles — PIEMass, dPIEMass, and dPIEMassSph — it adds the correct convergence_2d_from function, replacing the previous version that had been copied from dPIEPotential.

  • For the light profile — ShapeletPolar — it includes four modifications:
  1. Implements the Laguerre polynomial genlaguerre using JAX.
  2. Adds the stretching effect of the axis ratio on the generated image.
  3. Converts from the ell_comps convention to the (q, φ) convention.
  4. Multiplies the output image by self._intensity.

The reasons for these changes are:

  1. The old genlaguerre was taken from SciPy, but jax.scipy does not provide this function.
  2. While basis functions can always approximate complex structures by increasing order, a simple elliptical galaxy can be modeled with much lower order if the axis ratio is treated as a free parameter.
  3. Under the original ell_comps convention, the only solution that yields the circular case is (0, 0), which forces both axis_ratio=0 and position_angle=0. This is fine for circularly symmetric profiles, but higher-order shapelets contain angular components, so rotation still has meaning even in circular cases. Likewise, although increasing shapelet order can approximate anything, allowing rotation lets us describe the light distribution with fewer shapelet terms.
  4. Although the absence of intensity does not affect the inversion process, it is crucial when converting a linear profile into a standard profile; otherwise, the output image always has intensity = 1.0.

Open questions:

  1. After implementing genlaguerre in JAX, ShapeletPolar now only supports JAX, unlike the rest of the codebase where xp switches between jnp and np. I’m undecided whether to simply replace all jnp in genlaguerre_jax with xp, or to add logic that chooses between genlaguerre_jax and scipy.genlaguerre depending on xp.
  2. After converting from ell_comps to the (q, φ) convention, ShapeletPolar produces correct images but fails during fitting (error in add_value_to_hash_list). This might be due to using parameters outside the expected specification, or due to unsafe type conversions (such as self.phi = float(phi)).
  3. I have modified only ShapeletPolar for now. Extending the same updates to the other shapelets should be straightforward — for example, ShapeletCartesian only needs a JAX implementation of the Hermite functions. After the add_value_to_hash_list issue is solved, I can update those as well if you don’t have time.
  4. SersicletPolar is also on my roadmap. After the above issues are resolved, I plan to implement a Sersiclet that reduces to a Shapelet in special cases while incorporating a Sersic-like weighting function.

@Jammy2211
Copy link
Copy Markdown
Collaborator

Under the original ell_comps convention, the only solution that yields the circular case is (0, 0), which forces both axis_ratio=0 and position_angle=0. This is fine for circularly symmetric profiles, but higher-order shapelets contain angular components, so rotation still has meaning even in circular cases. Likewise, although increasing shapelet order can approximate anything, allowing rotation lets us describe the light distribution with fewer shapelet terms.

I think we should think about if there is a way to retain a parameterization like ell_comps but suitable for shapelets. We use ell_comps because in terms of sampler (e.g. lens modeling) having a parameter space which "loops in on itselfat the two ends ofposition_angleand where all solutions are identical ataxis_ratio=1makes them difficult and inefficient to sample. Modeling results will be a lot more efficient if we can work out aell_comps- like parameterization where circularity is at the centre of the parameter space (e.g. at (0.0, 0.0`).

I think we have already dealt with a very similar situation for the power-law multipole mass model:

https://github.com/Jammy2211/PyAutoGalaxy/blob/main/autogalaxy/profiles/mass/total/power_law_multipole.py

https://github.com/Jammy2211/PyAutoGalaxy/blob/main/autogalaxy/convert.py

def multipole_k_m_and_phi_m_from(
    multipole_comps: Tuple[float, float], m: int, xp=np
) -> Tuple[float, float]:
    """
    Returns the multipole normalization value `k_m` and angle `phi` from the multipole component parameters.

    The normalization and angle are given by:

    .. math::
        \phi^{\rm mass}_m = \frac{1}{m} \arctan{\frac{\epsilon_{\rm 2}^{\rm mp}}{\epsilon_{\rm 1}^{\rm mp}}}, \, \,
        k^{\rm mass}_m = \sqrt{{\epsilon_{\rm 1}^{\rm mp}}^2 + {\epsilon_{\rm 2}^{\rm mp}}^2} \, .

    The conversion depends on the multipole order `m`, to ensure that all possible rotationally symmetric
    multiple mass profiles are available in the conversion for multiple components spanning -inf to inf.

    Additional checks are performed which requires the angle `phi_m` is between -45 and 135 degrees. This ensures that
    for certain multipole component values the angle does not jump from one boundary to another (e.g. without
    this check certain values of `gamma_1` and `gamma_2` return -1.0 degrees and others 179.0 degrees).

    This ensures that when error estimates are computed from samples of a lens model via marginalization, the
    calculation is not biased by the angle jumping between these two values.

    Parameters
    ----------
    multipole_comps
        The first and second components of the multipole.

    Returns
    -------
    The normalization and angle parameters of the multipole.
    """
    phi_m = (
        xp.arctan2(multipole_comps[0], multipole_comps[1]) * 180.0 / xp.pi / float(m)
    )
    k_m = xp.sqrt(multipole_comps[1] ** 2 + multipole_comps[0] ** 2)

    phi_m = xp.where(phi_m < -90.0 / m, phi_m + 360.0 / m, phi_m)

    return k_m, phi_m

Its not exactly the same but basically the 90 / m and 360 / m terms make sure that increasing with order the rotations only span the sensible ranges.

After implementing genlaguerre in JAX, ShapeletPolar now only supports JAX, unlike the rest of the codebase where xp switches between jnp and np. I’m undecided whether to simply replace all jnp in genlaguerre_jax with xp, or to add logic that chooses between genlaguerre_jax and scipy.genlaguerre depending on xp.

If it makes sense to do the latter, I would do that, an example of me doing this is in autoarray's convolution functionality:

https://github.com/Jammy2211/PyAutoArray/blob/main/autoarray/structures/arrays/kernel_2d.py

        if xp is np:
            return self.convolved_mapping_matrix_via_real_space_np_from(
                mapping_matrix=mapping_matrix,
                mask=mask,
                blurring_mapping_matrix=blurring_mapping_matrix,
                blurring_mask=blurring_mask,
                xp=xp,
            )

... otherwise it goes on to the do the JAX functionality

So do the latter if it feels cleaner!

After converting from ell_comps to the (q, φ) convention, ShapeletPolar produces correct images but fails during fitting (error in add_value_to_hash_list). This might be due to using parameters outside the expected specification, or due to unsafe type conversions (such as self.phi = float(phi)).

I will investigate this tomrorow, shouldnt be hard to narrow down but will be easier with my understanding of autofit :).

I have modified only ShapeletPolar for now. Extending the same updates to the other shapelets should be straightforward — for example, ShapeletCartesian only needs a JAX implementation of the Hermite functions. After the add_value_to_hash_list issue is solved, I can update those as well if you don’t have time.

SersicletPolar is also on my roadmap. After the above issues are resolved, I plan to implement a Sersiclet that reduces to a Shapelet in special cases while incorporating a Sersic-like weighting function.

Both sound great :)

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces fixes for mass profile convergence calculations and updates to the ShapeletPolar light profile, including JAX implementation of generalized Laguerre polynomials and a change from elliptical components to axis-ratio/position-angle parameterization.

Changes:

  • Fixes convergence_2d_from methods for PIEMass, dPIEMass, and dPIEMassSph by replacing incorrect implementations that were copied from dPIEPotential with correct formulas based on the Kassiola & Kovner (1993) formulation
  • Adds xp parameter to helper functions (_ci05, _ci05f, _mdci05) for proper numpy/JAX compatibility
  • Implements genlaguerre_jax function to replace scipy's generalized Laguerre polynomial with a JAX-compatible version
  • Changes ShapeletPolar parameterization from ell_comps to (q, phi) convention to allow independent control of axis ratio and rotation
  • Updates ShapeletPolar image calculation to include axis ratio stretching and intensity multiplication

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.

File Description
autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py Adds correct convergence_2d_from implementations for PIEMass, dPIEMass, and dPIEMassSph; adds xp parameter to helper functions for numpy/JAX compatibility
autogalaxy/profiles/light/standard/shapelets/polar.py Implements JAX-based Laguerre polynomials, changes from ell_comps to (q, phi) parameterization, removes transform decorator, adds manual grid transformation and axis ratio scaling
autogalaxy/profiles/light/linear/shapelets/polar.py Updates parameter signatures to use (q, phi) instead of ell_comps for consistency with standard shapelet changes

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m))

grid = aa.util.geometry.transform_grid_2d_to_reference_frame(
grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decorator @aa.grid_dec.transform has been removed from the image_2d_from method. This decorator typically handles the transformation of grid coordinates to the profile's reference frame. By removing it and manually calling transform_grid_2d_to_reference_frame, the code may not properly integrate with the parent class's geometry system and could produce incorrect results when the profile has elliptical components.

Suggested change
grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp
grid_2d=grid, centre=self.centre, angle=self.phi, xp=xp

Copilot uses AI. Check for mistakes.
# zero over all space
return kappa_circ * (1 - asymm_term) + (alpha_circ / grid_radii) * asymm_term

kappa = self._convergence(grid_radii,xp)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing space after comma in function call. Should be self._convergence(grid_radii, xp) for consistency with Python style guidelines.

Suggested change
kappa = self._convergence(grid_radii,xp)
kappa = self._convergence(grid_radii, xp)

Copilot uses AI. Check for mistakes.
Comment on lines 112 to 114
super().__init__(
centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity
centre=centre, beta=beta, intensity=intensity
)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parent class AbstractShapelet requires ell_comps as a parameter, but this change removes it from the super().init() call. This will cause an error because the parent class LightProfile (through AbstractShapelet) expects ell_comps to be provided. The parent's init signature is def __init__(self, centre, ell_comps, intensity, beta), so omitting ell_comps will result in a TypeError.

Copilot uses AI. Check for mistakes.
Comment on lines +14 to +63
def genlaguerre_jax(n, alpha, x):
"""
Generalized (associated) Laguerre polynomial L_n^alpha(x)
calculated using the explicit summation formula, optimized for JAX vectorization.

Parameters:
n (int): Degree of the polynomial (static Python integer).
alpha (Numeric): Parameter alpha > -1.
x (Array): Input array (evaluation points).
"""
# 0. Input Validation (Requires static Python int n)
if not isinstance(n, int) or n < 0:
# Use Python's math.isnan/isinf check if n is float, otherwise type error
raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.")

# Base Case L0
if n == 0:
return jnp.ones_like(x)

# 1. Generate k values for summation range [0, 1, 2, ..., n]
k_values = jnp.arange(n + 1) # (n+1,)

# 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1))
x_expanded = jnp.expand_dims(x, axis=-1)
k_values_expanded = jnp.expand_dims(k_values, axis=0)

# --- A. Binomial Factor (BF) Calculation ---
# BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) )

log_N_plus_alpha_fact = gammaln(n + alpha + 1)

log_BF_k = (
log_N_plus_alpha_fact
- gammaln(n - k_values + 1) # log( (n-k)! )
- gammaln(alpha + k_values + 1) # log( (alpha+k)! )
)

BF_k = jnp.exp(log_BF_k) # Shape: (n+1,)

# --- B. Term Factor (TF) Calculation ---
# TF = (-x)^k / k!

# Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space
TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1))
# TF_k Shape: (M, n+1)

# --- C. Final Summation ---
# Sum over the last axis (axis=1), which corresponds to k
# BF_k broadcasts over the M dimension of TF_k
return jnp.sum(BF_k * TF_k, axis=1)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function uses hardcoded jnp operations instead of respecting the xp parameter that is passed to image_2d_from. This means the function will fail if numpy is expected (when xp=np). The genlaguerre_jax function should either use the xp parameter or there should be conditional logic to choose between a JAX and NumPy implementation.

Copilot uses AI. Check for mistakes.
Comment on lines +148 to +150
grid = aa.util.geometry.transform_grid_2d_to_reference_frame(
grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp
)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The manual transformation uses self.phi directly as the rotation angle, but this doesn't account for the elliptical coordinate system transformations that the parent class handles. The removed @aa.grid_dec.transform decorator would have applied the proper transformations including centre translation, rotation, and elliptical scaling based on ell_comps. The manual approach bypasses this and may produce incorrect results.

Copilot uses AI. Check for mistakes.
/ self.beta
/ xp.sqrt(xp.pi)
)
rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The axis ratio self.q is applied inconsistently in the calculation. The code divides grid[:, 1] by self.q when computing rsq, but this doesn't account for the full elliptical transformation. In an elliptical coordinate system, both coordinates should be scaled properly, and the transformation should be applied after rotation to the major/minor axis frame. The current implementation may not correctly represent an elliptical shapelet.

Copilot uses AI. Check for mistakes.
grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2
)
# Compute the convergence and deflection of a *circular* profile
kappa = self._convergence(grid_radii,xp)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing space after comma in function call. Should be self._convergence(grid_radii, xp) for consistency with Python style guidelines.

Suggested change
kappa = self._convergence(grid_radii,xp)
kappa = self._convergence(grid_radii, xp)

Copilot uses AI. Check for mistakes.
Comment on lines +107 to +110
self.n = int(n)
self.m = int(m)
self.phi = float(phi)
self.q = float(q)
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameters q and phi are stored as instance attributes but they are not being used by the parent class. Since the parent class EllProfile expects ell_comps to define the elliptical geometry and uses it in transformation methods, replacing the parameter convention breaks the integration with the parent class geometry system. The code should either maintain ell_comps compatibility or the AbstractShapelet base class needs to be updated to support the new parameter convention.

Copilot uses AI. Check for mistakes.
Comment on lines +11 to +12
import jax.numpy as jnp
from jax.scipy.special import gammaln
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded import of jax.numpy as jnp at the module level forces a dependency on JAX for all users of this module, even if they are not using JAX. This breaks the existing pattern in the codebase where xp is used as a parameter to switch between numpy and JAX. Consider making JAX an optional dependency and importing it conditionally, or implementing a fallback to scipy.special.genlaguerre when JAX is not available.

Copilot uses AI. Check for mistakes.
@Jammy2211
Copy link
Copy Markdown
Collaborator

Ok, I dug deeper on this and in the end think that the only thing that needed changing was that the axis ratio divison was missing here:

rsq = (grid.array[:, 0] ** 2 + grid.array[:, 1] ** 2) / self.beta**2

Which is now:

        y = grid.array[:, 0]
        x = grid.array[:, 1]

        rsq = (x ** 2 + (y / self.axis_ratio(xp)) ** 2) / self.beta ** 2
        theta = xp.arctan2(y, x)

This means I can retain the ell_comps being the free parameter of the shapelets, we just needed their axis ratio to be worked into the division.

I have deleted the code which rotates the grid by phi as it just does what the @aa.grid_dec.transform did originally.

I have also fixed Cartesian shapelets based on the issues you spottted.

Once you're back on the project please check that the merged code works compared to what you had before, I could of re introduced a stretching error but I think my changes simplify the code whilst maintaining the correct functionality.

@Jammy2211 Jammy2211 merged commit 48797d4 into PyAutoLabs:main Jan 17, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants