Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions scripts/jax_likelihood_functions/point_source/point_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
Path A PoC: jit-wrap ``analysis.fit_from`` for FitPointDataset
===============================================================

Sibling of ``imaging/mge_pytree.py`` and ``interferometer/mge_pytree.py`` for
the point-source path (see admin_jammy/prompt/issued/fit_point_pytree.md).

Unlike the imaging / interferometer variants, the point-source data type is
``PointDataset`` (positions + optional fluxes + time-delays), and the "model
prediction" comes from an iterative ``PointSolver`` that solves the lens
equation for image-plane positions.

Success criterion:
- ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitPointDataset`` whose
``log_likelihood`` is a ``jax.Array`` matching the NumPy-path scalar.

If the solver surfaces a non-JIT-safe operation, stop and document the blocker
inline rather than forcing a workaround.
"""

from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np

import autofit as af
import autolens as al

from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
__Dataset__ — same on-disk dataset used by ``point_source/point.py``.
"""
dataset_name = "simple"
dataset_path = Path("dataset") / "point_source" / dataset_name

if al.util.dataset.should_simulate(str(dataset_path)):
import subprocess
import sys

subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"],
check=True,
)

dataset = al.from_json(
file_path=dataset_path / "point_dataset_positions_only.json",
)


"""
__Point Solver__ — constructed once, rides as aux through the pytree.
"""
grid = al.Grid2D.uniform(
shape_native=(100, 100),
pixel_scales=0.2,
)

solver = al.PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp
)


"""
__Model__ — Isothermal lens + Point source, fixed cosmology.
"""
mass = af.Model(al.mp.Isothermal)

mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02)
mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8)

lens = af.Model(al.Galaxy, redshift=0.5, mass=mass)

point_0 = af.Model(al.ps.PointFlux)
point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)
point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08)

source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0)

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))


"""
__Analysis__ on the JAX path (``use_jax=True``).
"""
analysis = al.AnalysisPoint(
dataset=dataset,
solver=solver,
fit_positions_cls=al.FitPositionsImagePairAll,
use_jax=True,
)

register_model(model)

instance = model.instance_from_prior_medians()


"""
__NumPy reference scalar__.
"""
analysis_np = al.AnalysisPoint(
dataset=dataset,
solver=solver,
fit_positions_cls=al.FitPositionsImagePairAll,
use_jax=False,
)
fit_np = analysis_np.fit_from(instance=instance)
print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood))


"""
__Path A: jit-wrap ``analysis.fit_from``__.
"""
fit_jit_fn = jax.jit(analysis.fit_from)
fit = fit_jit_fn(instance)

print("JIT fit type:", type(fit).__name__)
print("JIT fit.log_likelihood:", fit.log_likelihood)
assert isinstance(fit.log_likelihood, jnp.ndarray), (
f"expected jax.Array, got {type(fit.log_likelihood)}"
)
np.testing.assert_allclose(
float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4
)

print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
Loading