Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions autogalaxy/analysis/jax_pytrees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Shared JAX pytree registrations for autogalaxy analysis classes.

Each ``Analysis*`` class registers its own ``Fit*`` and per-analysis
constants inline (so the call site stays self-documenting), but the
``Galaxies`` registration is shared from here because the custom
flatten/unflatten logic is non-trivial and identical across all
analyses that hold a ``Galaxies`` aggregate.
"""


def register_galaxies_pytree() -> None:
"""Register ``Galaxies`` as a JAX pytree.

``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten
in ``register_instance_pytree`` would drop the list contents. This
registers a custom flatten that carries the list items as dynamic
children and the ``__dict__`` entries as aux.

Idempotent — guarded by ``_pytree_registered_classes`` so repeated
calls (e.g. from each ``Analysis*.fit_from``) are cheap.
"""
from autoarray.abstract_ndarray import _pytree_registered_classes
from autoconf.jax_wrapper import register_pytree_node
from autogalaxy.galaxy.galaxies import Galaxies

if Galaxies in _pytree_registered_classes:
return

def _flatten_galaxies(galaxies):
dict_items = tuple(sorted(galaxies.__dict__.items()))
return tuple(galaxies), dict_items

def _unflatten_galaxies(aux, children):
new = Galaxies.__new__(Galaxies)
list.__init__(new, children)
for key, value in aux:
setattr(new, key, value)
return new

register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies)
_pytree_registered_classes.add(Galaxies)
27 changes: 3 additions & 24 deletions autogalaxy/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,37 +175,16 @@ def _register_fit_imaging_pytrees() -> None:
else (``galaxies``, ``dataset_model`` and the autoarray wrappers they
carry) is dynamic per fit.
"""
from autoarray.abstract_ndarray import (
_pytree_registered_classes,
register_instance_pytree,
)
from autoarray.abstract_ndarray import register_instance_pytree
from autoarray.dataset.dataset_model import DatasetModel
from autoconf.jax_wrapper import register_pytree_node
from autogalaxy.galaxy.galaxies import Galaxies
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

register_instance_pytree(
FitImaging,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)

# ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten
# in ``register_instance_pytree`` would drop the list contents. Register
# a custom flatten that carries the list items as dynamic children.
if Galaxies not in _pytree_registered_classes:
def _flatten_galaxies(galaxies):
dict_items = tuple(sorted(galaxies.__dict__.items()))
return tuple(galaxies), dict_items

def _unflatten_galaxies(aux, children):
new = Galaxies.__new__(Galaxies)
list.__init__(new, children)
for key, value in aux:
setattr(new, key, value)
return new

register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies)
_pytree_registered_classes.add(Galaxies)
register_galaxies_pytree()

def save_attributes(self, paths: af.DirectoryPaths):
"""
Expand Down
25 changes: 25 additions & 0 deletions autogalaxy/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
FitInterferometer
The fit of the galaxies to the interferometer dataset, which includes the log likelihood.
"""

if self._use_jax:
self._register_fit_interferometer_pytrees()

galaxies = self.galaxies_via_instance_from(
instance=instance,
)
Expand All @@ -158,6 +162,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
xp=self._xp,
)

@staticmethod
def _register_fit_interferometer_pytrees() -> None:
"""Register every type reachable from a ``FitInterferometer`` return
value so ``jax.jit(fit_from)`` can flatten its output.

``dataset``, ``adapt_images`` and ``settings`` are constants per
analysis — ride as aux so JAX does not recurse into them. Everything
else (``galaxies`` and the autoarray wrappers it carries) is dynamic
per fit.
"""
from autoarray.abstract_ndarray import register_instance_pytree
from autoarray.dataset.dataset_model import DatasetModel
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

register_instance_pytree(
FitInterferometer,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)
register_galaxies_pytree()

def save_attributes(self, paths: af.DirectoryPaths):
"""
Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder
Expand Down
Loading