Fix/ShapeletPolar_dPIEkappa#259
Conversation
I think we should think about if there is a way to retain a parameterization like I think we have already dealt with a very similar situation for the power-law multipole mass model: https://github.com/Jammy2211/PyAutoGalaxy/blob/main/autogalaxy/convert.py Its not exactly the same but basically the
If it makes sense to do the latter, I would do that, an example of me doing this is in autoarray's convolution functionality: https://github.com/Jammy2211/PyAutoArray/blob/main/autoarray/structures/arrays/kernel_2d.py So do the latter if it feels cleaner!
I will investigate this tomrorow, shouldnt be hard to narrow down but will be easier with my understanding of autofit :).
SersicletPolar is also on my roadmap. After the above issues are resolved, I plan to implement a Sersiclet that reduces to a Shapelet in special cases while incorporating a Sersic-like weighting function. Both sound great :) |
There was a problem hiding this comment.
Pull request overview
This PR introduces fixes for mass profile convergence calculations and updates to the ShapeletPolar light profile, including JAX implementation of generalized Laguerre polynomials and a change from elliptical components to axis-ratio/position-angle parameterization.
Changes:
- Fixes
convergence_2d_frommethods for PIEMass, dPIEMass, and dPIEMassSph by replacing incorrect implementations that were copied from dPIEPotential with correct formulas based on the Kassiola & Kovner (1993) formulation - Adds
xpparameter to helper functions (_ci05, _ci05f, _mdci05) for proper numpy/JAX compatibility - Implements
genlaguerre_jaxfunction to replace scipy's generalized Laguerre polynomial with a JAX-compatible version - Changes ShapeletPolar parameterization from
ell_compsto(q, phi)convention to allow independent control of axis ratio and rotation - Updates ShapeletPolar image calculation to include axis ratio stretching and intensity multiplication
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| autogalaxy/profiles/mass/total/dual_pseudo_isothermal_mass.py | Adds correct convergence_2d_from implementations for PIEMass, dPIEMass, and dPIEMassSph; adds xp parameter to helper functions for numpy/JAX compatibility |
| autogalaxy/profiles/light/standard/shapelets/polar.py | Implements JAX-based Laguerre polynomials, changes from ell_comps to (q, phi) parameterization, removes transform decorator, adds manual grid transformation and axis ratio scaling |
| autogalaxy/profiles/light/linear/shapelets/polar.py | Updates parameter signatures to use (q, phi) instead of ell_comps for consistency with standard shapelet changes |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| laguerre = genlaguerre(n=(self.n - xp.abs(self.m)) / 2.0, alpha=xp.abs(self.m)) | ||
|
|
||
| grid = aa.util.geometry.transform_grid_2d_to_reference_frame( | ||
| grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp |
There was a problem hiding this comment.
The decorator @aa.grid_dec.transform has been removed from the image_2d_from method. This decorator typically handles the transformation of grid coordinates to the profile's reference frame. By removing it and manually calling transform_grid_2d_to_reference_frame, the code may not properly integrate with the parent class's geometry system and could produce incorrect results when the profile has elliptical components.
| grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp | |
| grid_2d=grid, centre=self.centre, angle=self.phi, xp=xp |
| # zero over all space | ||
| return kappa_circ * (1 - asymm_term) + (alpha_circ / grid_radii) * asymm_term | ||
|
|
||
| kappa = self._convergence(grid_radii,xp) |
There was a problem hiding this comment.
Missing space after comma in function call. Should be self._convergence(grid_radii, xp) for consistency with Python style guidelines.
| kappa = self._convergence(grid_radii,xp) | |
| kappa = self._convergence(grid_radii, xp) |
| super().__init__( | ||
| centre=centre, ell_comps=ell_comps, beta=beta, intensity=intensity | ||
| centre=centre, beta=beta, intensity=intensity | ||
| ) |
There was a problem hiding this comment.
The parent class AbstractShapelet requires ell_comps as a parameter, but this change removes it from the super().init() call. This will cause an error because the parent class LightProfile (through AbstractShapelet) expects ell_comps to be provided. The parent's init signature is def __init__(self, centre, ell_comps, intensity, beta), so omitting ell_comps will result in a TypeError.
| def genlaguerre_jax(n, alpha, x): | ||
| """ | ||
| Generalized (associated) Laguerre polynomial L_n^alpha(x) | ||
| calculated using the explicit summation formula, optimized for JAX vectorization. | ||
|
|
||
| Parameters: | ||
| n (int): Degree of the polynomial (static Python integer). | ||
| alpha (Numeric): Parameter alpha > -1. | ||
| x (Array): Input array (evaluation points). | ||
| """ | ||
| # 0. Input Validation (Requires static Python int n) | ||
| if not isinstance(n, int) or n < 0: | ||
| # Use Python's math.isnan/isinf check if n is float, otherwise type error | ||
| raise ValueError(f"Degree n must be a non-negative Python integer (static), got {n}.") | ||
|
|
||
| # Base Case L0 | ||
| if n == 0: | ||
| return jnp.ones_like(x) | ||
|
|
||
| # 1. Generate k values for summation range [0, 1, 2, ..., n] | ||
| k_values = jnp.arange(n + 1) # (n+1,) | ||
|
|
||
| # 2. Reshape inputs for broadcasting (x: (M, 1), k: (1, n+1)) | ||
| x_expanded = jnp.expand_dims(x, axis=-1) | ||
| k_values_expanded = jnp.expand_dims(k_values, axis=0) | ||
|
|
||
| # --- A. Binomial Factor (BF) Calculation --- | ||
| # BF = exp( log( (n+alpha)! / ((n-k)! * (alpha+k)!) ) ) | ||
|
|
||
| log_N_plus_alpha_fact = gammaln(n + alpha + 1) | ||
|
|
||
| log_BF_k = ( | ||
| log_N_plus_alpha_fact | ||
| - gammaln(n - k_values + 1) # log( (n-k)! ) | ||
| - gammaln(alpha + k_values + 1) # log( (alpha+k)! ) | ||
| ) | ||
|
|
||
| BF_k = jnp.exp(log_BF_k) # Shape: (n+1,) | ||
|
|
||
| # --- B. Term Factor (TF) Calculation --- | ||
| # TF = (-x)^k / k! | ||
|
|
||
| # Note: jnp.math.gamma(k_values + 1) is equivalent to k! in log-gamma space | ||
| TF_k = jnp.power(-x_expanded, k_values_expanded) / jnp.exp(gammaln(k_values_expanded + 1)) | ||
| # TF_k Shape: (M, n+1) | ||
|
|
||
| # --- C. Final Summation --- | ||
| # Sum over the last axis (axis=1), which corresponds to k | ||
| # BF_k broadcasts over the M dimension of TF_k | ||
| return jnp.sum(BF_k * TF_k, axis=1) |
There was a problem hiding this comment.
The function uses hardcoded jnp operations instead of respecting the xp parameter that is passed to image_2d_from. This means the function will fail if numpy is expected (when xp=np). The genlaguerre_jax function should either use the xp parameter or there should be conditional logic to choose between a JAX and NumPy implementation.
| grid = aa.util.geometry.transform_grid_2d_to_reference_frame( | ||
| grid_2d=grid.array, centre=self.centre, angle=self.phi, xp=xp | ||
| ) |
There was a problem hiding this comment.
The manual transformation uses self.phi directly as the rotation angle, but this doesn't account for the elliptical coordinate system transformations that the parent class handles. The removed @aa.grid_dec.transform decorator would have applied the proper transformations including centre translation, rotation, and elliptical scaling based on ell_comps. The manual approach bypasses this and may produce incorrect results.
| / self.beta | ||
| / xp.sqrt(xp.pi) | ||
| ) | ||
| rsq = (grid[:, 0] ** 2 + (grid[:, 1]/self.q) ** 2) / self.beta**2 |
There was a problem hiding this comment.
The axis ratio self.q is applied inconsistently in the calculation. The code divides grid[:, 1] by self.q when computing rsq, but this doesn't account for the full elliptical transformation. In an elliptical coordinate system, both coordinates should be scaled properly, and the transformation should be applied after rotation to the major/minor axis frame. The current implementation may not correctly represent an elliptical shapelet.
| grid.array[:, 1] ** 2 / (1 + ellip) ** 2 + grid.array[:, 0] ** 2 / (1 - ellip) ** 2 | ||
| ) | ||
| # Compute the convergence and deflection of a *circular* profile | ||
| kappa = self._convergence(grid_radii,xp) |
There was a problem hiding this comment.
Missing space after comma in function call. Should be self._convergence(grid_radii, xp) for consistency with Python style guidelines.
| kappa = self._convergence(grid_radii,xp) | |
| kappa = self._convergence(grid_radii, xp) |
| self.n = int(n) | ||
| self.m = int(m) | ||
| self.phi = float(phi) | ||
| self.q = float(q) |
There was a problem hiding this comment.
The parameters q and phi are stored as instance attributes but they are not being used by the parent class. Since the parent class EllProfile expects ell_comps to define the elliptical geometry and uses it in transformation methods, replacing the parameter convention breaks the integration with the parent class geometry system. The code should either maintain ell_comps compatibility or the AbstractShapelet base class needs to be updated to support the new parameter convention.
| import jax.numpy as jnp | ||
| from jax.scipy.special import gammaln |
There was a problem hiding this comment.
The hardcoded import of jax.numpy as jnp at the module level forces a dependency on JAX for all users of this module, even if they are not using JAX. This breaks the existing pattern in the codebase where xp is used as a parameter to switch between numpy and JAX. Consider making JAX an optional dependency and importing it conditionally, or implementing a fallback to scipy.special.genlaguerre when JAX is not available.
|
Ok, I dug deeper on this and in the end think that the only thing that needed changing was that the axis ratio divison was missing here:
Which is now: This means I can retain the I have deleted the code which rotates the grid by phi as it just does what the I have also fixed Cartesian shapelets based on the issues you spottted. Once you're back on the project please check that the merged code works compared to what you had before, I could of re introduced a stretching error but I think my changes simplify the code whilst maintaining the correct functionality. |
This PR introduces two quick fixes:
PIEMass,dPIEMass, anddPIEMassSph— it adds the correctconvergence_2d_fromfunction, replacing the previous version that had been copied fromdPIEPotential.The reasons for these changes are:
genlaguerrewas taken from SciPy, butjax.scipydoes not provide this function.(0, 0), which forces bothaxis_ratio=0andposition_angle=0. This is fine for circularly symmetric profiles, but higher-order shapelets contain angular components, so rotation still has meaning even in circular cases. Likewise, although increasing shapelet order can approximate anything, allowing rotation lets us describe the light distribution with fewer shapelet terms.intensity = 1.0.Open questions:
genlaguerrein JAX, ShapeletPolar now only supports JAX, unlike the rest of the codebase wherexpswitches betweenjnpandnp. I’m undecided whether to simply replace alljnpingenlaguerre_jaxwithxp, or to add logic that chooses betweengenlaguerre_jaxandscipy.genlaguerredepending onxp.ell_compsto the(q, φ)convention,ShapeletPolarproduces correct images but fails during fitting (error inadd_value_to_hash_list). This might be due to using parameters outside the expected specification, or due to unsafe type conversions (such asself.phi = float(phi)).ShapeletPolarfor now. Extending the same updates to the other shapelets should be straightforward — for example,ShapeletCartesianonly needs a JAX implementation of the Hermite functions. After theadd_value_to_hash_listissue is solved, I can update those as well if you don’t have time.SersicletPolaris also on my roadmap. After the above issues are resolved, I plan to implement aSersicletthat reduces to a Shapelet in special cases while incorporating a Sersic-like weighting function.