diff --git a/autogalaxy/analysis/jax_pytrees.py b/autogalaxy/analysis/jax_pytrees.py new file mode 100644 index 00000000..cfc03305 --- /dev/null +++ b/autogalaxy/analysis/jax_pytrees.py @@ -0,0 +1,41 @@ +"""Shared JAX pytree registrations for autogalaxy analysis classes. + +Each ``Analysis*`` class registers its own ``Fit*`` and per-analysis +constants inline (so the call site stays self-documenting), but the +``Galaxies`` registration is shared from here because the custom +flatten/unflatten logic is non-trivial and identical across all +analyses that hold a ``Galaxies`` aggregate. +""" + + +def register_galaxies_pytree() -> None: + """Register ``Galaxies`` as a JAX pytree. + + ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten + in ``register_instance_pytree`` would drop the list contents. This + registers a custom flatten that carries the list items as dynamic + children and the ``__dict__`` entries as aux. + + Idempotent — guarded by ``_pytree_registered_classes`` so repeated + calls (e.g. from each ``Analysis*.fit_from``) are cheap. + """ + from autoarray.abstract_ndarray import _pytree_registered_classes + from autoconf.jax_wrapper import register_pytree_node + from autogalaxy.galaxy.galaxies import Galaxies + + if Galaxies in _pytree_registered_classes: + return + + def _flatten_galaxies(galaxies): + dict_items = tuple(sorted(galaxies.__dict__.items())) + return tuple(galaxies), dict_items + + def _unflatten_galaxies(aux, children): + new = Galaxies.__new__(Galaxies) + list.__init__(new, children) + for key, value in aux: + setattr(new, key, value) + return new + + register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies) + _pytree_registered_classes.add(Galaxies) diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index d6610a32..f24e45bd 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -175,37 +175,16 @@ def _register_fit_imaging_pytrees() -> None: else (``galaxies``, ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit. """ - from autoarray.abstract_ndarray import ( - _pytree_registered_classes, - register_instance_pytree, - ) + from autoarray.abstract_ndarray import register_instance_pytree from autoarray.dataset.dataset_model import DatasetModel - from autoconf.jax_wrapper import register_pytree_node - from autogalaxy.galaxy.galaxies import Galaxies + from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree register_instance_pytree( FitImaging, no_flatten=("dataset", "adapt_images", "settings"), ) register_instance_pytree(DatasetModel) - - # ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten - # in ``register_instance_pytree`` would drop the list contents. Register - # a custom flatten that carries the list items as dynamic children. - if Galaxies not in _pytree_registered_classes: - def _flatten_galaxies(galaxies): - dict_items = tuple(sorted(galaxies.__dict__.items())) - return tuple(galaxies), dict_items - - def _unflatten_galaxies(aux, children): - new = Galaxies.__new__(Galaxies) - list.__init__(new, children) - for key, value in aux: - setattr(new, key, value) - return new - - register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies) - _pytree_registered_classes.add(Galaxies) + register_galaxies_pytree() def save_attributes(self, paths: af.DirectoryPaths): """ diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index d45e6da6..2d0be29d 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -142,6 +142,10 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer: FitInterferometer The fit of the galaxies to the interferometer dataset, which includes the log likelihood. """ + + if self._use_jax: + self._register_fit_interferometer_pytrees() + galaxies = self.galaxies_via_instance_from( instance=instance, ) @@ -158,6 +162,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer: xp=self._xp, ) + @staticmethod + def _register_fit_interferometer_pytrees() -> None: + """Register every type reachable from a ``FitInterferometer`` return + value so ``jax.jit(fit_from)`` can flatten its output. + + ``dataset``, ``adapt_images`` and ``settings`` are constants per + analysis — ride as aux so JAX does not recurse into them. Everything + else (``galaxies`` and the autoarray wrappers it carries) is dynamic + per fit. + """ + from autoarray.abstract_ndarray import register_instance_pytree + from autoarray.dataset.dataset_model import DatasetModel + from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree + + register_instance_pytree( + FitInterferometer, + no_flatten=("dataset", "adapt_images", "settings"), + ) + register_instance_pytree(DatasetModel) + register_galaxies_pytree() + def save_attributes(self, paths: af.DirectoryPaths): """ Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder