Skip to content

Commit f317c2e

Browse files
Jammy2211claude
authored andcommitted
Register FitInterferometer return value as a JAX pytree
Adds ``_register_fit_interferometer_pytrees`` to ``AnalysisInterferometer``, mirroring the imaging-side ``_register_fit_imaging_pytrees``. Called at the top of ``fit_from`` when ``use_jax=True`` so ``jax.jit(analysis.fit_from)`` can flatten the returned ``FitInterferometer``. Registers: - ``FitInterferometer`` with dataset / adapt_images / settings as aux - ``Tracer`` with cosmology as aux (already covered by the imaging path but harmless to re-register under the interferometer site) - ``DatasetModel`` — always present on ``FitDataset`` subclasses even when ``fit_from`` doesn't pass one explicitly (``FitDataset.__init__`` falls back to ``DatasetModel()``) Unlocks the Path A PoC in ``autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/mge_pytree.py`` for the MGE-source interferometer variant. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent f2d5cd6 commit f317c2e

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

autolens/interferometer/model/analysis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
173173
The fit of the plane to the interferometer dataset, which includes the log likelihood.
174174
"""
175175

176+
if self._use_jax:
177+
self._register_fit_interferometer_pytrees()
178+
176179
tracer = self.tracer_via_instance_from(
177180
instance=instance,
178181
)
@@ -187,6 +190,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
187190
xp=self._xp,
188191
)
189192

193+
@staticmethod
194+
def _register_fit_interferometer_pytrees() -> None:
195+
"""Register every type reachable from a ``FitInterferometer`` return
196+
value so ``jax.jit(fit_from)`` can flatten its output.
197+
198+
``dataset``, ``adapt_images`` and ``settings`` are constants per
199+
analysis — ride as aux so JAX does not recurse into them. Everything
200+
else (``tracer`` and the autoarray wrappers it carries) is dynamic
201+
per fit.
202+
"""
203+
from autoarray.abstract_ndarray import register_instance_pytree
204+
from autoarray.dataset.dataset_model import DatasetModel # fit-interferometer-pytree-mge
205+
from autolens.lens.tracer import Tracer
206+
207+
register_instance_pytree(
208+
FitInterferometer,
209+
no_flatten=("dataset", "adapt_images", "settings"),
210+
)
211+
register_instance_pytree(Tracer, no_flatten=("cosmology",))
212+
register_instance_pytree(DatasetModel) # fit-interferometer-pytree-mge
213+
190214
def save_attributes(self, paths: af.DirectoryPaths):
191215
"""
192216
Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder

0 commit comments

Comments
 (0)