Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions scripts/cluster/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import autolens as al
import autolens.plot as aplt

from autofit.jax import register_model as _register_model_pytrees
from autoarray.abstract_ndarray import register_instance_pytree
from autolens.lens.tracer import Tracer

Expand Down Expand Up @@ -165,7 +164,6 @@
galaxies=af.Collection(*(_lens_models + [_halo_model] + _source_models))
)

_register_model_pytrees(_registration_model)
register_instance_pytree(Tracer, no_flatten=("cosmology",))


Expand Down
19 changes: 3 additions & 16 deletions scripts/imaging/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
fit_for_visualization fires correctly during the live search callback.

This script deliberately opts in with
``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default
model-fit scripts elsewhere in the workspace leave both flags at ``False``
and are therefore untouched by this change.
``AnalysisImaging(use_jax=True)``. Default model-fit scripts elsewhere in the
workspace leave the flag at ``False`` and are therefore untouched by this
change.
"""

import shutil
Expand All @@ -40,9 +40,7 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -129,12 +127,10 @@

model_mge = af.Collection(galaxies=af.Collection(lens=lens_mge, source=source_mge))

register_model(model_mge)

analysis_mge = al.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()
Expand Down Expand Up @@ -162,9 +158,6 @@
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: MGE jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -288,12 +281,10 @@

model_mge2 = af.Collection(galaxies=af.Collection(lens=lens_mge2, source=source_mge2))

register_model(model_mge2)

analysis_mge2 = al.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit"
Expand Down Expand Up @@ -325,10 +316,6 @@
f"no fit.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)
assert (
analysis_mge2._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance during search"

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
"with MGE linear profiles, fit.png written, no KeyError from "
Expand Down
12 changes: 0 additions & 12 deletions scripts/imaging/modeling_visualization_jit_delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -159,7 +157,6 @@

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

register_model(model)


"""
Expand All @@ -176,7 +173,6 @@
adapt_images=adapt_images,
raise_inversion_positions_likelihood_exception=False,
use_jax=True,
use_jax_for_visualization=True,
)

instance_probe = model.instance_from_prior_medians()
Expand Down Expand Up @@ -261,9 +257,6 @@ def _assert_likelihood_sanity(label, analysis, model):
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_probe._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: Delaunay jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -337,7 +330,6 @@ def _assert_likelihood_sanity(label, analysis, model):
adapt_images=adapt_images,
raise_inversion_positions_likelihood_exception=False,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = (
Expand Down Expand Up @@ -369,10 +361,6 @@ def _assert_likelihood_sanity(label, analysis, model):
f"no fit.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)
assert (
analysis_live._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance during search"

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
"with a Delaunay-pixelization source, fit.png written."
Expand Down
14 changes: 1 addition & 13 deletions scripts/imaging/modeling_visualization_jit_rectangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
``galaxy_image_plane_mesh_grid_dict`` / ``galaxy_image_dict`` lookups.

This script deliberately opts in with
``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``.
``AnalysisImaging(use_jax=True)``.
"""

import shutil
Expand All @@ -36,9 +36,7 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -131,7 +129,6 @@

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

register_model(model)


galaxy_name_image_dict = {
Expand Down Expand Up @@ -160,7 +157,6 @@
use_mixed_precision=True,
),
use_jax=True,
use_jax_for_visualization=True,
)

instance_probe = model.instance_from_prior_medians()
Expand Down Expand Up @@ -249,9 +245,6 @@ def _assert_likelihood_sanity(label, analysis, model):
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_probe._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: rectangular jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -325,7 +318,6 @@ def _assert_likelihood_sanity(label, analysis, model):
use_mixed_precision=True,
),
use_jax=True,
use_jax_for_visualization=True,
)

output_root = (
Expand Down Expand Up @@ -357,10 +349,6 @@ def _assert_likelihood_sanity(label, analysis, model):
f"no fit.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)
assert (
analysis_live._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance during search"

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
"with a rectangular-pixelization source, fit.png written."
Expand Down
25 changes: 10 additions & 15 deletions scripts/imaging/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

Goal
----
Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind
``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443
(2026-04-19) the imaging visualizer dispatches through
``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in
``jax.jit``. To trace across that boundary the model and fit return type
must be JAX pytrees, so this script enables pytree registration before
constructing the model. Parametric MGE source — simplest case (no
Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end via
``use_jax=True`` on ``Analysis``. After PyAutoLens #443 (2026-04-19) the
imaging visualizer dispatches through ``analysis.fit_for_visualization``,
which lazily wraps ``fit_from`` in ``jax.jit``. Visualization now follows
``use_jax`` automatically. To trace across that boundary the model and fit
return type must be JAX pytrees, so this script enables pytree registration
before constructing the model. Parametric MGE source — simplest case (no
pixelization, no inversion).

Scope
Expand All @@ -38,10 +38,8 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model
from autolens.imaging.model.visualizer import VisualizerImaging

enable_pytrees()


"""
Expand Down Expand Up @@ -102,20 +100,17 @@

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

register_model(model)


"""
__Analysis__

``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True``
tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit``
via the new ``Analysis.fit_for_visualization`` helper.
``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows
``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper.
"""
analysis = al.AnalysisImaging(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
title_prefix="JAX_PILOT",
)

Expand All @@ -137,7 +132,7 @@
"""
instance = model.instance_from_prior_medians()

print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...")
print("Running VisualizerImaging.visualize with use_jax=True ...")
VisualizerImaging.visualize(
analysis=analysis,
paths=paths,
Expand Down
19 changes: 3 additions & 16 deletions scripts/interferometer/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
fires correctly during the live search callback.

This script deliberately opts in with
``AnalysisInterferometer(use_jax=True, use_jax_for_visualization=True)``.
Default model-fit scripts elsewhere in the workspace leave both flags at
``False`` and are therefore untouched by this change.
``AnalysisInterferometer(use_jax=True)``. Default model-fit scripts elsewhere
in the workspace leave the flag at ``False`` and are therefore untouched by
this change.
"""

import shutil
Expand All @@ -40,9 +40,7 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
Expand Down Expand Up @@ -131,13 +129,11 @@

model_mge = af.Collection(galaxies=af.Collection(lens=lens_mge, source=source_mge))

register_model(model_mge)

analysis_mge = al.AnalysisInterferometer(
dataset=dataset,
positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)],
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()
Expand Down Expand Up @@ -165,9 +161,6 @@
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: MGE jit-cached fit_for_visualization works and is reused.")


Expand Down Expand Up @@ -284,13 +277,11 @@

model_mge2 = af.Collection(galaxies=af.Collection(lens=lens_mge2, source=source_mge2))

register_model(model_mge2)

analysis_mge2 = al.AnalysisInterferometer(
dataset=dataset,
positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)],
use_jax=True,
use_jax_for_visualization=True,
)

output_root = (
Expand Down Expand Up @@ -330,10 +321,6 @@
f"no fit.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)
assert (
analysis_mge2._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance during search"

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
"with MGE linear profiles, fit.png written, no KeyError from "
Expand Down
Loading
Loading