Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions scripts/ellipse/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
correctly during the live search callback.

This script deliberately opts in with
``AnalysisEllipse(use_jax=True, use_jax_for_visualization=True)``.
Default ellipse model-fit scripts elsewhere in the workspace leave both
flags at ``False`` and are therefore untouched by this change.
``AnalysisEllipse(use_jax=True)``.
Default ellipse model-fit scripts elsewhere in the workspace leave the flag
at ``False`` and are therefore untouched by this change.
"""

import shutil
Expand All @@ -37,9 +37,7 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -92,12 +90,10 @@

model_mge = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_mge))

register_model(model_mge)

analysis_mge = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()
Expand Down Expand Up @@ -125,9 +121,6 @@
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: Ellipse jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -177,12 +170,10 @@

model_mge2 = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_2))

register_model(model_mge2)

analysis_mge2 = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = Path("scripts") / "ellipse" / "images" / "modeling_visualization_jit"
Expand Down
31 changes: 13 additions & 18 deletions scripts/ellipse/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
Visualization JAX Pilot: Ellipse Analysis (autogalaxy)
======================================================

Tests that ``VisualizerEllipse.visualize`` with
``use_jax_for_visualization=True`` dispatches through the JIT-cached
``fit_for_visualization`` path that the parent ``af.Analysis`` already
provides. ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its
parent, so ``use_jax_for_visualization=True`` flows through to the
PyAutoFit-level dispatch without a library-side change.
Tests that ``VisualizerEllipse.visualize`` with ``use_jax=True`` dispatches
through the JIT-cached ``fit_for_visualization`` path that the parent
``af.Analysis`` already provides. Visualization follows ``use_jax``
automatically — ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its
parent without a library-side change needed.

Scope
-----
Expand All @@ -16,8 +15,8 @@
- Calls ``VisualizerEllipse.visualize`` only (not ``visualize_before_fit``).
- Reuses the ``dataset/imaging/jax_test`` dataset that the
``jax_likelihood_functions`` scripts produce.
- ``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True``
routes ``Visualizer*.visualize`` through ``analysis.fit_for_visualization``.
- ``use_jax=True`` turns on the JAX path and routes ``Visualizer*.visualize``
through ``analysis.fit_for_visualization``.
"""

import shutil
Expand All @@ -36,10 +35,8 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model
from autogalaxy.ellipse.model.visualizer import VisualizerEllipse

enable_pytrees()


"""
Expand Down Expand Up @@ -85,22 +82,20 @@

model = af.Collection(ellipses=af.Collection(ellipse_0=ellipse))

register_model(model)


"""
__Analysis__

``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True``
tells the visualizer to dispatch through the JIT-cached
``fit_for_visualization`` helper on the parent ``af.Analysis``.
``AnalysisEllipse.__init__`` accepts ``**kwargs`` and forwards them to
``super().__init__``, so no AnalysisEllipse signature change is needed.
``use_jax=True`` turns on the JAX path. Visualization follows ``use_jax``
automatically via the JIT-cached ``fit_for_visualization`` helper on the
parent ``af.Analysis``. ``AnalysisEllipse.__init__`` accepts ``**kwargs``
and forwards them to ``super().__init__``, so no AnalysisEllipse signature
change is needed.
"""
analysis = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
title_prefix="JAX_PILOT",
)

Expand All @@ -122,7 +117,7 @@
"""
instance = model.instance_from_prior_medians()

print("Running VisualizerEllipse.visualize with use_jax_for_visualization=True ...")
print("Running VisualizerEllipse.visualize with use_jax=True ...")
VisualizerEllipse.visualize(
analysis=analysis,
paths=paths,
Expand Down
15 changes: 3 additions & 12 deletions scripts/imaging/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
during the live search callback.

This script deliberately opts in with
``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default
model-fit scripts elsewhere in the workspace leave both flags at ``False``
and are therefore untouched by this change.
``AnalysisImaging(use_jax=True)``. Default model-fit scripts elsewhere in the
workspace leave the flag at ``False`` and are therefore untouched by this
change.
"""

import shutil
Expand All @@ -38,9 +38,7 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -108,12 +106,10 @@

model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge))

register_model(model_mge)

analysis_mge = ag.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()
Expand Down Expand Up @@ -141,9 +137,6 @@
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: MGE jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -213,12 +206,10 @@

model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2))

register_model(model_mge2)

analysis_mge2 = ag.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit"
Expand Down
25 changes: 10 additions & 15 deletions scripts/imaging/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

Goal
----
Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind
``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoGalaxy #390
(2026-05-08) the imaging visualizer dispatches through
``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in
``jax.jit``. To trace across that boundary the model and fit return type
must be JAX pytrees, so this script enables pytree registration before
constructing the model. Parametric MGE galaxy — simplest case (no
Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end via
``use_jax=True`` on ``Analysis``. After PyAutoGalaxy #390 (2026-05-08) the
imaging visualizer dispatches through ``analysis.fit_for_visualization``,
which lazily wraps ``fit_from`` in ``jax.jit``. Visualization now follows
``use_jax`` automatically. To trace across that boundary the model and fit
return type must be JAX pytrees, so this script enables pytree registration
before constructing the model. Parametric MGE galaxy — simplest case (no
pixelization, no inversion).

Scope
Expand All @@ -39,10 +39,8 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model
from autogalaxy.imaging.model.visualizer import VisualizerImaging

enable_pytrees()


"""
Expand Down Expand Up @@ -87,20 +85,17 @@
galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge)
model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

register_model(model)


"""
__Analysis__

``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True``
tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit``
via the ``Analysis.fit_for_visualization`` helper.
``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows
``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper.
"""
analysis = ag.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
title_prefix="JAX_PILOT",
)

Expand All @@ -122,7 +117,7 @@
"""
instance = model.instance_from_prior_medians()

print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...")
print("Running VisualizerImaging.visualize with use_jax=True ...")
VisualizerImaging.visualize(
analysis=analysis,
paths=paths,
Expand Down
15 changes: 3 additions & 12 deletions scripts/interferometer/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
during the live search callback.

This script deliberately opts in with
``AnalysisInterferometer(use_jax=True, use_jax_for_visualization=True)``.
Default model-fit scripts elsewhere in the workspace leave both flags at
``False`` and are therefore untouched by this change.
``AnalysisInterferometer(use_jax=True)``. Default model-fit scripts elsewhere
in the workspace leave the flag at ``False`` and are therefore untouched by
this change.
"""

import shutil
Expand All @@ -40,9 +40,7 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -114,12 +112,10 @@

model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge))

register_model(model_mge)

analysis_mge = ag.AnalysisInterferometer(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()
Expand Down Expand Up @@ -147,9 +143,6 @@
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: MGE jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -224,12 +217,10 @@

model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2))

register_model(model_mge2)

analysis_mge2 = ag.AnalysisInterferometer(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = (
Expand Down
28 changes: 11 additions & 17 deletions scripts/interferometer/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

Goal
----
Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end, gated
behind ``use_jax_for_visualization=True`` on ``AnalysisInterferometer``. The
interferometer visualizer dispatches through ``analysis.fit_for_visualization``,
which lazily wraps ``fit_from`` in ``jax.jit``. To trace across that boundary
the model and fit return type must be JAX pytrees, so this script enables
pytree registration before constructing the model.
Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end via
``use_jax=True`` on ``AnalysisInterferometer``. Visualization now follows
``use_jax`` automatically — the interferometer visualizer dispatches through
``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in
``jax.jit``. To trace across that boundary the model and fit return type must
be JAX pytrees, so this script enables pytree registration before constructing
the model.

Scope
-----
Expand All @@ -30,10 +31,8 @@

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model
from autogalaxy.interferometer.model.visualizer import VisualizerInterferometer

enable_pytrees()


"""
Expand Down Expand Up @@ -75,30 +74,25 @@
"""
__Model__

Single-galaxy MGE parametric model. Pytree registration is required before
constructing the model so that the model and fit return type are valid JAX
pytrees at the ``jax.jit`` boundary.
Single-galaxy MGE parametric model.
"""
bulge = ag.model_util.mge_model_from(
mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True
)
galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge)
model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

register_model(model)


"""
__Analysis__

``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True``
tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit``
via the ``Analysis.fit_for_visualization`` helper.
``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows
``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper.
"""
analysis = ag.AnalysisInterferometer(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
title_prefix="JAX_PILOT",
)

Expand All @@ -121,7 +115,7 @@
instance = model.instance_from_prior_medians()

print(
"Running VisualizerInterferometer.visualize with use_jax_for_visualization=True ..."
"Running VisualizerInterferometer.visualize with use_jax=True ..."
)
VisualizerInterferometer.visualize(
analysis=analysis,
Expand Down
Loading
Loading