Skip to content

Commit 1f4a81a

Browse files
Jammy2211claude
authored andcommitted
feat: AnalysisEllipse.fit_from + JAX pytree registration
Wires AnalysisEllipse for end-to-end JAX traceability. Adds use_jax flag (default True, matching AnalysisImaging), a fit_from method returning a new FitEllipseSummed aggregate, and a lazy idempotent _register_fit_ellipse_pytrees that registers FitEllipse, FitEllipseSummed, DatasetModel, Ellipse, EllipseMultipole, and EllipseMultipoleScaled as JAX pytrees on the first fit_from call. log_likelihood_function now returns figure_of_merit (sum over the per-ellipse fit_list) instead of log_likelihood, matching AnalysisImaging exactly. The two differ by the noise_normalization term per FitEllipse: figure_of_merit = -0.5 * (chi_squared + noise_normalization), log_likelihood = -0.5 * chi_squared. Step 7 of 7 in z_features/ellipse_fitting_jax.md — the keystone. After this lands, ellipse modeling runs inside any JAX-compatible search via AnalysisEllipse(dataset, use_jax=True). Workspace follow-up flips prompt 2's workspace_test scripts to a JIT round-trip with rtol=1e-4 parity against the locked-in numpy reference numbers. Issue PyAutoGalaxy#411. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8e50837 commit 1f4a81a

4 files changed

Lines changed: 413 additions & 230 deletions

File tree

autogalaxy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .ellipse.ellipse.ellipse_multipole import EllipseMultipole
6666
from .ellipse.ellipse.ellipse_multipole import EllipseMultipoleScaled
6767
from .ellipse.fit_ellipse import FitEllipse
68+
from .ellipse.fit_ellipse import FitEllipseSummed
6869
from .ellipse.model.analysis import AnalysisEllipse
6970
from .operate.image import OperateImage
7071
from .operate.image import OperateImageList

autogalaxy/ellipse/fit_ellipse.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
This is the classical isophote-fitting approach used in tools such as IRAF/ELLIPSE and galfit, and is
1212
appropriate for measuring galaxy morphology, position angles, axis ratios, and multipole perturbations
1313
directly from imaging data without fitting a parametric light profile model.
14+
15+
`FitEllipseSummed` aggregates multiple `FitEllipse` objects (one per ellipse in a model) and exposes
16+
sum-over-list versions of `figure_of_merit`, `log_likelihood`, and `chi_squared`. This is the return
17+
type of `AnalysisEllipse.fit_from`, mirroring the single-object return of `AnalysisImaging.fit_from`.
1418
"""
1519
import numpy as np
1620
from typing import List, Optional
@@ -333,3 +337,54 @@ def figure_of_merit(self) -> float:
333337
The figure of merit of the fit.
334338
"""
335339
return -0.5 * (self.chi_squared + self.noise_normalization)
340+
341+
342+
class FitEllipseSummed:
343+
"""
344+
Aggregate of one or more :class:`FitEllipse` objects whose
345+
``figure_of_merit`` / ``log_likelihood`` / ``chi_squared`` properties
346+
sum over the contained fits.
347+
348+
Used by :class:`AnalysisEllipse.fit_from` so the return type is a
349+
single object (matching :class:`AnalysisImaging.fit_from`'s pattern)
350+
even when a model contains multiple ellipses. Each contained
351+
:class:`FitEllipse` carries the same shared ``dataset``; this class
352+
exposes that dataset for JAX pytree-registration purposes
353+
(``no_flatten=("dataset",)``).
354+
"""
355+
356+
def __init__(self, fit_list: List[FitEllipse]):
357+
self.fit_list = fit_list
358+
359+
@property
360+
def dataset(self):
361+
"""
362+
All fits in the list share the same dataset; expose the first
363+
for pytree no_flatten purposes.
364+
"""
365+
return self.fit_list[0].dataset
366+
367+
@property
368+
def figure_of_merit(self):
369+
"""
370+
The sum of the ``figure_of_merit`` values of every contained
371+
:class:`FitEllipse`, used by the non-linear search as the
372+
overall objective.
373+
"""
374+
return sum(f.figure_of_merit for f in self.fit_list)
375+
376+
@property
377+
def log_likelihood(self):
378+
"""
379+
The sum of the ``log_likelihood`` values of every contained
380+
:class:`FitEllipse`.
381+
"""
382+
return sum(f.log_likelihood for f in self.fit_list)
383+
384+
@property
385+
def chi_squared(self):
386+
"""
387+
The sum of the ``chi_squared`` values of every contained
388+
:class:`FitEllipse`.
389+
"""
390+
return sum(f.chi_squared for f in self.fit_list)

0 commit comments

Comments
 (0)