Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config/build/env_vars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ overrides:
# full resolution; just unset the cap.
- pattern: "multi/visualization_imaging"
unset: [PYAUTO_SMALL_DATASETS]
# point_source/visualization_jax exercises the jit-cached
# fit_for_visualization path on the point-source side.
- pattern: "point_source/visualization_jax"
unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS]
# point_source/modeling_visualization_jit — live Nautilus + JIT path.
- pattern: "point_source/modeling_visualization_jit"
unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS]
227 changes: 227 additions & 0 deletions scripts/point_source/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""
End-to-end test: jit-cached visualization during a real Nautilus model-fit.
============================================================================

Exercises the full JAX visualization pipeline for the point-source analysis
path: ``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)`` with
an ``Isothermal`` lens mass and ``PointFlux`` source (image-plane chi-squared
via ``FitPositionsImagePairAll``).

This test runs in two parts:

Part 1 — **Caching probe.** Calls ``analysis.fit_for_visualization(instance)``
twice and asserts the second call is much faster than the first (confirming
the compiled function is cached on the analysis instance, not recompiled per
visualization call). Also asserts ``analysis._jitted_fit_from is not None``
after the first call.

Part 2 — **Live Nautilus quick-update.** Runs a real (short) Nautilus fit.
The live search fires quick-update visualization every
``iterations_per_quick_update`` likelihood evaluations; we verify that
``fit.png`` lands on disk under the Nautilus output tree, proving the
JIT-cached ``fit_for_visualization`` fires correctly during the live
search callback.

This script deliberately opts in with
``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)``.
Default model-fit scripts elsewhere in the workspace leave both flags at
``False`` and are therefore untouched.
"""

import shutil
import time
from os import path
from pathlib import Path

import jax
import jax.numpy as jnp

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
__Dataset__
"""
dataset_path = Path("dataset") / "point_source" / "simple"

if al.util.dataset.should_simulate(str(dataset_path)):
import subprocess
import sys

subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"],
check=True,
)

dataset = al.from_json(
file_path=dataset_path / "point_dataset_positions_only.json",
)


"""
__Point Solver__
"""
grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.2)

solver = al.PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1
)


"""
============================================================================
Part 1 — Caching probe
============================================================================

Model: Isothermal lens mass + PointFlux source. Same tight priors as the
other point-source scripts so the prior-median instance produces multiple
images. No free cosmology (breaks JIT via global-state distance caching).
"""
print("\n" + "=" * 72)
print("Part 1: Point-source caching probe")
print("=" * 72)

mass = af.Model(al.mp.Isothermal)
mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8)

lens = af.Model(al.Galaxy, redshift=0.5, mass=mass)

point_0 = af.Model(al.ps.PointFlux)
point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)
point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)

source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0)

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

register_model(model)

analysis = al.AnalysisPoint(
dataset=dataset,
solver=solver,
fit_positions_cls=al.FitPositionsImagePairAll,
use_jax=True,
use_jax_for_visualization=True,
)

instance = model.instance_from_prior_medians()

t0 = time.perf_counter()
fit_1 = analysis.fit_for_visualization(instance)
jax.block_until_ready(fit_1.log_likelihood)
t1 = time.perf_counter()
compile_time = t1 - t0
print(f"First call (compile + run): {compile_time:.3f}s")
print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}")
assert isinstance(
fit_1.log_likelihood, jnp.ndarray
), f"expected jax.Array, got {type(fit_1.log_likelihood)}"

t0 = time.perf_counter()
fit_2 = analysis.fit_for_visualization(instance)
jax.block_until_ready(fit_2.log_likelihood)
t1 = time.perf_counter()
cached_time = t1 - t0
print(f"Second call (cached): {cached_time:.3f}s")
print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x")

assert cached_time < compile_time * 0.5, (
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: Point-source jit-cached fit_for_visualization works and is reused.")


"""
============================================================================
Part 2 — Live Nautilus quick-update
============================================================================

Rebuild the model fresh (register_model on the new instance), create a
separate analysis object, and run a short Nautilus fit. The search fires
quick-update visualization every ``iterations_per_quick_update`` calls;
we assert that ``fit.png`` lands on disk under the Nautilus output tree.
"""
print("\n" + "=" * 72)
print("Part 2: Live Nautilus + jit-visualization for point source")
print("=" * 72)

mass2 = af.Model(al.mp.Isothermal)
mass2.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass2.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass2.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass2.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass2.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8)

lens2 = af.Model(al.Galaxy, redshift=0.5, mass=mass2)

point_02 = af.Model(al.ps.PointFlux)
point_02.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)
point_02.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)

source2 = af.Model(al.Galaxy, redshift=1.0, point_0=point_02)

model2 = af.Collection(galaxies=af.Collection(lens=lens2, source=source2))

register_model(model2)

analysis_run = al.AnalysisPoint(
dataset=dataset,
solver=solver,
fit_positions_cls=al.FitPositionsImagePairAll,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = Path("scripts") / "point_source" / "images" / "modeling_visualization_jit"
if output_root.exists():
shutil.rmtree(output_root)
output_root.mkdir(parents=True)

# Also clean the autofit search output so Nautilus performs live sampling
# instead of resuming from a cached samples.csv — without this the
# quick-update visualizer never fires on reruns.
output_search_root = Path("output") / output_root / "point_image_plane"
if output_search_root.exists():
shutil.rmtree(output_search_root)

search = af.Nautilus(
path_prefix=str(output_root),
name="point_image_plane",
n_live=50,
n_like_max=1500,
iterations_per_quick_update=500,
number_of_cores=1,
)

print("Running Nautilus ...")
result = search.fit(model=model2, analysis=analysis_run)

# Nautilus writes quick-update images to output/<path_prefix>/<name>/<hash>/image/
produced_pngs = list(output_search_root.rglob("fit.png"))
print(f"fit.png files produced: {len(produced_pngs)}")
for p in produced_pngs:
print(f" {p}")
assert len(produced_pngs) > 0, (
f"no fit.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)
assert (
analysis_run._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance during search"

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
f"for point source, fit.png written."
)
122 changes: 122 additions & 0 deletions scripts/point_source/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Visualization NumPy: Point Source Analysis
==========================================

Tests that ``VisualizerPoint.visualize`` runs end-to-end on a
``PointDataset`` using the NumPy (non-JAX) code path and that ``fit.png``
lands on disk.

Uses the ``simple/point_dataset_positions_only.json`` dataset (auto-simulated
if missing) with an ``Isothermal`` lens mass and ``PointFlux`` source — the
same model that is proven to JIT end-to-end in
``scripts/jax_likelihood_functions/point_source/image_plane.py``.

No ``try/except`` — any failure in the visualizer surfaces immediately.
"""

import shutil
from pathlib import Path
from types import SimpleNamespace

import autofit as af
import autolens as al
from autolens.point.model.visualizer import VisualizerPoint


"""
__Dataset__
"""
dataset_path = Path("dataset") / "point_source" / "simple"

if al.util.dataset.should_simulate(str(dataset_path)):
import subprocess
import sys

subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"],
check=True,
)

dataset = al.from_json(
file_path=dataset_path / "point_dataset_positions_only.json",
)


"""
__Point Solver__
"""
grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.2)

solver = al.PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1
)


"""
__Model__

Tight priors centred on the true values so the prior-median instance
produces a sensible lens configuration (multiple images exist).
No free cosmology — cosmology distance caching breaks JIT.
"""
mass = af.Model(al.mp.Isothermal)
mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8)

lens = af.Model(al.Galaxy, redshift=0.5, mass=mass)

point_0 = af.Model(al.ps.PointFlux)
point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)
point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)

source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0)

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))


"""
__Analysis__

Explicit NumPy path (use_jax=False).
"""
analysis = al.AnalysisPoint(
dataset=dataset,
solver=solver,
fit_positions_cls=al.FitPositionsImagePairAll,
use_jax=False,
)


"""
__Paths__
"""
image_path = Path("scripts") / "point_source" / "images" / "visualization"
if image_path.exists():
shutil.rmtree(image_path)
image_path.mkdir(parents=True)
output_path = image_path / "output"
output_path.mkdir(parents=True)
paths = SimpleNamespace(image_path=image_path, output_path=output_path)


"""
__Visualize__
"""
instance = model.instance_from_prior_medians()

print("Running VisualizerPoint.visualize (NumPy) ...")
VisualizerPoint.visualize(
analysis=analysis,
paths=paths,
instance=instance,
during_analysis=False,
)

print("Files in image_path:", list(image_path.iterdir()))
assert (image_path / "fit.png").exists(), (
f"fit.png was not produced. Files present: {list(image_path.iterdir())}"
)
print("NumPy point-source visualization produced fit.png.")
Loading
Loading