diff --git a/scripts/jax_likelihood_functions/point_source/point_pytree.py b/scripts/jax_likelihood_functions/point_source/point_pytree.py new file mode 100644 index 00000000..66eed625 --- /dev/null +++ b/scripts/jax_likelihood_functions/point_source/point_pytree.py @@ -0,0 +1,133 @@ +""" +Path A PoC: jit-wrap ``analysis.fit_from`` for FitPointDataset +=============================================================== + +Sibling of ``imaging/mge_pytree.py`` and ``interferometer/mge_pytree.py`` for +the point-source path (see admin_jammy/prompt/issued/fit_point_pytree.md). + +Unlike the imaging / interferometer variants, the point-source data type is +``PointDataset`` (positions + optional fluxes + time-delays), and the "model +prediction" comes from an iterative ``PointSolver`` that solves the lens +equation for image-plane positions. + +Success criterion: + - ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitPointDataset`` whose + ``log_likelihood`` is a ``jax.Array`` matching the NumPy-path scalar. + +If the solver surfaces a non-JIT-safe operation, stop and document the blocker +inline rather than forcing a workaround. +""" + +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autolens as al + +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ — same on-disk dataset used by ``point_source/point.py``. +""" +dataset_name = "simple" +dataset_path = Path("dataset") / "point_source" / dataset_name + +if al.util.dataset.should_simulate(str(dataset_path)): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"], + check=True, + ) + +dataset = al.from_json( + file_path=dataset_path / "point_dataset_positions_only.json", +) + + +""" +__Point Solver__ — constructed once, rides as aux through the pytree. +""" +grid = al.Grid2D.uniform( + shape_native=(100, 100), + pixel_scales=0.2, +) + +solver = al.PointSolver.for_grid( + grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp +) + + +""" +__Model__ — Isothermal lens + Point source, fixed cosmology. +""" +mass = af.Model(al.mp.Isothermal) + +mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8) + +lens = af.Model(al.Galaxy, redshift=0.5, mass=mass) + +point_0 = af.Model(al.ps.PointFlux) +point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) +point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) + +source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0) + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + +""" +__Analysis__ on the JAX path (``use_jax=True``). +""" +analysis = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=True, +) + +register_model(model) + +instance = model.instance_from_prior_medians() + + +""" +__NumPy reference scalar__. +""" +analysis_np = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=False, +) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + + +""" +__Path A: jit-wrap ``analysis.fit_from``__. +""" +fit_jit_fn = jax.jit(analysis.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit type:", type(fit).__name__) +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) + +print("PASS: jit(fit_from) round-trip matches NumPy scalar.")