diff --git a/scripts/cluster/simulator.py b/scripts/cluster/simulator.py index e201876f..179ed828 100644 --- a/scripts/cluster/simulator.py +++ b/scripts/cluster/simulator.py @@ -27,7 +27,6 @@ import autolens as al import autolens.plot as aplt -from autofit.jax import register_model as _register_model_pytrees from autoarray.abstract_ndarray import register_instance_pytree from autolens.lens.tracer import Tracer @@ -165,7 +164,6 @@ galaxies=af.Collection(*(_lens_models + [_halo_model] + _source_models)) ) -_register_model_pytrees(_registration_model) register_instance_pytree(Tracer, no_flatten=("cosmology",)) diff --git a/scripts/imaging/modeling_visualization_jit.py b/scripts/imaging/modeling_visualization_jit.py index 570fb2b9..1dcb552b 100644 --- a/scripts/imaging/modeling_visualization_jit.py +++ b/scripts/imaging/modeling_visualization_jit.py @@ -24,9 +24,9 @@ fit_for_visualization fires correctly during the live search callback. This script deliberately opts in with -``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default -model-fit scripts elsewhere in the workspace leave both flags at ``False`` -and are therefore untouched by this change. +``AnalysisImaging(use_jax=True)``. Default model-fit scripts elsewhere in the +workspace leave the flag at ``False`` and are therefore untouched by this +change. """ import shutil @@ -40,9 +40,7 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -129,12 +127,10 @@ model_mge = af.Collection(galaxies=af.Collection(lens=lens_mge, source=source_mge)) -register_model(model_mge) analysis_mge = al.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) instance_mge = model_mge.instance_from_prior_medians() @@ -162,9 +158,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_mge._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: MGE jit-cached fit_for_visualization works and is reused.") @@ -288,12 +281,10 @@ model_mge2 = af.Collection(galaxies=af.Collection(lens=lens_mge2, source=source_mge2)) -register_model(model_mge2) analysis_mge2 = al.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit" @@ -325,10 +316,6 @@ f"no fit.png produced under {output_search_root} — " "quick-update visualization did not fire" ) -assert ( - analysis_mge2._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance during search" - print( "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " "with MGE linear profiles, fit.png written, no KeyError from " diff --git a/scripts/imaging/modeling_visualization_jit_delaunay.py b/scripts/imaging/modeling_visualization_jit_delaunay.py index 06f846e0..f09524ea 100644 --- a/scripts/imaging/modeling_visualization_jit_delaunay.py +++ b/scripts/imaging/modeling_visualization_jit_delaunay.py @@ -37,9 +37,7 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -159,7 +157,6 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) """ @@ -176,7 +173,6 @@ adapt_images=adapt_images, raise_inversion_positions_likelihood_exception=False, use_jax=True, - use_jax_for_visualization=True, ) instance_probe = model.instance_from_prior_medians() @@ -261,9 +257,6 @@ def _assert_likelihood_sanity(label, analysis, model): f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_probe._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: Delaunay jit-cached fit_for_visualization works and is reused.") @@ -337,7 +330,6 @@ def _assert_likelihood_sanity(label, analysis, model): adapt_images=adapt_images, raise_inversion_positions_likelihood_exception=False, use_jax=True, - use_jax_for_visualization=True, ) output_root = ( @@ -369,10 +361,6 @@ def _assert_likelihood_sanity(label, analysis, model): f"no fit.png produced under {output_search_root} — " "quick-update visualization did not fire" ) -assert ( - analysis_live._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance during search" - print( "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " "with a Delaunay-pixelization source, fit.png written." diff --git a/scripts/imaging/modeling_visualization_jit_rectangular.py b/scripts/imaging/modeling_visualization_jit_rectangular.py index bcd0123e..cf583ac8 100644 --- a/scripts/imaging/modeling_visualization_jit_rectangular.py +++ b/scripts/imaging/modeling_visualization_jit_rectangular.py @@ -22,7 +22,7 @@ ``galaxy_image_plane_mesh_grid_dict`` / ``galaxy_image_dict`` lookups. This script deliberately opts in with -``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. +``AnalysisImaging(use_jax=True)``. """ import shutil @@ -36,9 +36,7 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -131,7 +129,6 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) galaxy_name_image_dict = { @@ -160,7 +157,6 @@ use_mixed_precision=True, ), use_jax=True, - use_jax_for_visualization=True, ) instance_probe = model.instance_from_prior_medians() @@ -249,9 +245,6 @@ def _assert_likelihood_sanity(label, analysis, model): f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_probe._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: rectangular jit-cached fit_for_visualization works and is reused.") @@ -325,7 +318,6 @@ def _assert_likelihood_sanity(label, analysis, model): use_mixed_precision=True, ), use_jax=True, - use_jax_for_visualization=True, ) output_root = ( @@ -357,10 +349,6 @@ def _assert_likelihood_sanity(label, analysis, model): f"no fit.png produced under {output_search_root} — " "quick-update visualization did not fire" ) -assert ( - analysis_live._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance during search" - print( "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " "with a rectangular-pixelization source, fit.png written." diff --git a/scripts/imaging/visualization_jax.py b/scripts/imaging/visualization_jax.py index e226243a..76d08269 100644 --- a/scripts/imaging/visualization_jax.py +++ b/scripts/imaging/visualization_jax.py @@ -6,13 +6,13 @@ Goal ---- -Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind -``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443 -(2026-04-19) the imaging visualizer dispatches through -``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in -``jax.jit``. To trace across that boundary the model and fit return type -must be JAX pytrees, so this script enables pytree registration before -constructing the model. Parametric MGE source — simplest case (no +Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end via +``use_jax=True`` on ``Analysis``. After PyAutoLens #443 (2026-04-19) the +imaging visualizer dispatches through ``analysis.fit_for_visualization``, +which lazily wraps ``fit_from`` in ``jax.jit``. Visualization now follows +``use_jax`` automatically. To trace across that boundary the model and fit +return type must be JAX pytrees, so this script enables pytree registration +before constructing the model. Parametric MGE source — simplest case (no pixelization, no inversion). Scope @@ -38,10 +38,8 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model from autolens.imaging.model.visualizer import VisualizerImaging -enable_pytrees() """ @@ -102,20 +100,17 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` -tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` -via the new ``Analysis.fit_for_visualization`` helper. +``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows +``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper. """ analysis = al.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -137,7 +132,7 @@ """ instance = model.instance_from_prior_medians() -print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...") +print("Running VisualizerImaging.visualize with use_jax=True ...") VisualizerImaging.visualize( analysis=analysis, paths=paths, diff --git a/scripts/interferometer/modeling_visualization_jit.py b/scripts/interferometer/modeling_visualization_jit.py index ce4c62d7..e694a9bc 100644 --- a/scripts/interferometer/modeling_visualization_jit.py +++ b/scripts/interferometer/modeling_visualization_jit.py @@ -24,9 +24,9 @@ fires correctly during the live search callback. This script deliberately opts in with -``AnalysisInterferometer(use_jax=True, use_jax_for_visualization=True)``. -Default model-fit scripts elsewhere in the workspace leave both flags at -``False`` and are therefore untouched by this change. +``AnalysisInterferometer(use_jax=True)``. Default model-fit scripts elsewhere +in the workspace leave the flag at ``False`` and are therefore untouched by +this change. """ import shutil @@ -40,9 +40,7 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -131,13 +129,11 @@ model_mge = af.Collection(galaxies=af.Collection(lens=lens_mge, source=source_mge)) -register_model(model_mge) analysis_mge = al.AnalysisInterferometer( dataset=dataset, positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], use_jax=True, - use_jax_for_visualization=True, ) instance_mge = model_mge.instance_from_prior_medians() @@ -165,9 +161,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_mge._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: MGE jit-cached fit_for_visualization works and is reused.") @@ -284,13 +277,11 @@ model_mge2 = af.Collection(galaxies=af.Collection(lens=lens_mge2, source=source_mge2)) -register_model(model_mge2) analysis_mge2 = al.AnalysisInterferometer( dataset=dataset, positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], use_jax=True, - use_jax_for_visualization=True, ) output_root = ( @@ -330,10 +321,6 @@ f"no fit.png produced under {output_search_root} — " "quick-update visualization did not fire" ) -assert ( - analysis_mge2._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance during search" - print( "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " "with MGE linear profiles, fit.png written, no KeyError from " diff --git a/scripts/interferometer/visualization_jax.py b/scripts/interferometer/visualization_jax.py index 7bdff78a..97fbf70d 100644 --- a/scripts/interferometer/visualization_jax.py +++ b/scripts/interferometer/visualization_jax.py @@ -6,15 +6,16 @@ Goal ---- -Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end, gated -behind ``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443 -the interferometer visualizer dispatches through -``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in -``jax.jit`` (autolens/interferometer/model/visualizer.py:96). To trace across -that boundary the model and fit return type must be JAX pytrees, so this script -enables pytree registration before constructing the model. Parametric MGE -source — simplest case (no PSF convolution; interferometer operates in Fourier -space via DFT, no pixelization, no inversion). +Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end via +``use_jax=True`` on ``Analysis``. After PyAutoLens #443 the interferometer +visualizer dispatches through ``analysis.fit_for_visualization``, which lazily +wraps ``fit_from`` in ``jax.jit`` +(autolens/interferometer/model/visualizer.py:96). Visualization now follows +``use_jax`` automatically. To trace across that boundary the model and fit +return type must be JAX pytrees, so this script enables pytree registration +before constructing the model. Parametric MGE source — simplest case (no PSF +convolution; interferometer operates in Fourier space via DFT, no +pixelization, no inversion). Scope ----- @@ -31,10 +32,8 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model from autolens.interferometer.model.visualizer import VisualizerInterferometer -enable_pytrees() """ @@ -95,21 +94,18 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` -tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` -via the new ``Analysis.fit_for_visualization`` helper. +``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows +``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper. """ analysis = al.AnalysisInterferometer( dataset=dataset, positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -132,7 +128,7 @@ instance = model.instance_from_prior_medians() print( - "Running VisualizerInterferometer.visualize with use_jax_for_visualization=True ..." + "Running VisualizerInterferometer.visualize with use_jax=True ..." ) VisualizerInterferometer.visualize( analysis=analysis, diff --git a/scripts/jax_likelihood_functions/datacube/delaunay.py b/scripts/jax_likelihood_functions/datacube/delaunay.py index 53751be7..605f4cb7 100644 --- a/scripts/jax_likelihood_functions/datacube/delaunay.py +++ b/scripts/jax_likelihood_functions/datacube/delaunay.py @@ -254,10 +254,7 @@ Matches ``multi/delaunay.py``: jit-wrap ``factor_graph.log_likelihood_function`` through ``instance_from_vector`` and assert the result matches the vmap value. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/datacube/rectangular.py b/scripts/jax_likelihood_functions/datacube/rectangular.py index e3de74c0..c0953951 100644 --- a/scripts/jax_likelihood_functions/datacube/rectangular.py +++ b/scripts/jax_likelihood_functions/datacube/rectangular.py @@ -237,10 +237,7 @@ Matches ``multi/delaunay.py``: jit-wrap ``factor_graph.log_likelihood_function`` through ``instance_from_vector`` and assert the result matches the vmap value. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/imaging/delaunay.py b/scripts/jax_likelihood_functions/imaging/delaunay.py index a9797f02..ce59e787 100644 --- a/scripts/jax_likelihood_functions/imaging/delaunay.py +++ b/scripts/jax_likelihood_functions/imaging/delaunay.py @@ -298,10 +298,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/delaunay_mge.py b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py index e089a88c..5b183a91 100644 --- a/scripts/jax_likelihood_functions/imaging/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py @@ -318,10 +318,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/lp.py b/scripts/jax_likelihood_functions/imaging/lp.py index c31f5914..e817291b 100644 --- a/scripts/jax_likelihood_functions/imaging/lp.py +++ b/scripts/jax_likelihood_functions/imaging/lp.py @@ -1,243 +1,240 @@ -""" -Func Grad: Light Parametric Operated -==================================== - -This script test if JAX can successfully compute the gradient of the log likelihood of an `Imaging` dataset with a -model which uses operated light profiles. - - __Operated Fitting__ - -It is common for galaxies to have point-source emission, for example bright emission right at their centre due to -an active galactic nuclei or very compact knot of star formation. - -This point-source emission is subject to blurring during data accquisiton due to the telescope optics, and therefore -is not seen as a single pixel of light but spread over multiple pixels as a convolution with the telescope -Point Spread Function (PSF). - -It is difficult to model this compact point source emission using a point-source light profile (or an extremely -compact Gaussian / Sersic profile). This is because when the model-image of a compact point source of light is -convolved with the PSF, the solution to this convolution is extremely sensitive to which pixel (and sub-pixel) the -compact model emission lands in. - -Operated light profiles offer an alternative approach, whereby the light profile is assumed to have already been -convolved with the PSF. This operated light profile is then fitted directly to the point-source emission, which as -discussed above shows the PSF features. -""" - -# %matplotlib inline -# from pyprojroot import here -# workspace_path = str(here()) -# %cd $workspace_path -# print(f"Working Directory has been set to `{workspace_path}`") - -import numpy as np -import jax.numpy as jnp -import jax -from jax import grad -from os import path - -import autofit as af -import autolens as al -from autoconf import conf - -""" -__Dataset__ - -Load and plot the galaxy dataset via .fits files. -""" -dataset_path = path.join("dataset", "imaging", "jax_test") - -""" -__Dataset Auto-Simulation__ - -If the dataset does not already exist on your system, it will be created by running the corresponding -simulator script. This ensures that all example scripts can be run without manually simulating data first. -""" -if al.util.dataset.should_simulate(dataset_path): - import subprocess - import sys - - subprocess.run( - [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], - check=True, - ) - -dataset = al.Imaging.from_fits( - data_path=path.join(dataset_path, "data.fits"), - psf_path=path.join(dataset_path, "psf.fits"), - noise_map_path=path.join(dataset_path, "noise_map.fits"), - pixel_scales=0.2, -) - - -""" -__Mask__ - -The model-fit requires a 2D mask defining the regions of the image we fit the model to the data, which we define -and use to set up the `Imaging` object that the model fits. -""" -mask_radius = 3.0 - -mask = al.Mask2D.circular( - shape_native=dataset.shape_native, - pixel_scales=dataset.pixel_scales, - radius=mask_radius, -) - -dataset = dataset.apply_mask(mask=mask) - -dataset = dataset.apply_over_sampling(over_sample_size_lp=1) - -positions = al.Grid2DIrregular( - al.from_json(file_path=path.join(dataset_path, "positions.json")) -) - -# over_sample_size = al.util.over_sample.over_sample_size_via_radial_bins_from( -# grid=dataset.grid, -# sub_size_list=[4, 2, 1], -# radial_list=[0.3, 0.6], -# centre_list=[(0.0, 0.0)], -# ) -# -# dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) -# - -""" -__Model__ - -We compose our model using `Model` objects, which represent the galaxies we fit to our data. In this -example we fit a model where: - - - The galaxy's bulge is a parametric `Sersic` bulge [7 parameters]. - - The galaxy's point source emission is a parametric operated `Gaussian` centred on the bulge [4 parameters]. - -The number of free parameters and therefore the dimensionality of non-linear parameter space is N=11. -""" -# Lens: - -bulge = af.Model(al.lp_linear.Sersic) - -mass = af.Model(al.mp.PowerLaw) - -shear = af.Model(al.mp.ExternalShear) - -lens = af.Model( - al.Galaxy, - redshift=0.5, - bulge=bulge, - mass=mass, - shear=shear, -) - -# Source: - -bulge = af.Model(al.lp_linear.Sersic) - -source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge) - -# Overall Lens Model: - -model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) - -""" -The `info` attribute shows the model in a readable format. -""" -print(model.info) - -""" -__Analysis__ - -The `AnalysisImaging` object defines the `log_likelihood_function` which will be used to determine if JAX -can compute its gradient. -""" -analysis = al.AnalysisImaging( - dataset=dataset, - positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], - # settings=al.Settings(use_positive_only_solver=False) -) - - -""" -The analysis and `log_likelihood_function` are internally wrapped into a `Fitness` class in **PyAutoFit**, which pairs -the model with likelihood. - -This is the function on which JAX gradients are computed, so we create this class here. -""" -from autofit.non_linear.fitness import Fitness -import time - -batch_size = 50 - -fitness = Fitness( - model=model, - analysis=analysis, - fom_is_log_likelihood=True, - resample_figure_of_merit=-1.0e99, -) - -param_vector = jnp.array(model.physical_values_from_prior_medians) - -parameters = np.zeros((batch_size, model.total_free_parameters)) - -for i in range(batch_size): - parameters[i, :] = model.physical_values_from_prior_medians - -parameters = jnp.array(parameters) - -start = time.time() -print() -print(fitness._vmap(parameters)) -print("JAX Time To VMAP + JIT Function", time.time() - start) - -start = time.time() -print() -result = fitness._vmap(parameters) -print(result) -print("JAX Time Taken using VMAP:", time.time() - start) -print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) - -np.testing.assert_allclose( - np.array(result), - -1.34797827e09, - rtol=1e-4, - err_msg="lp: JAX vmap likelihood mismatch", -) - - -""" -__Path A: jit-wrap ``analysis.fit_from``__ - -Wrap ``analysis.fit_from`` in ``jax.jit`` and assert the returned ``FitImaging`` -has a ``jax.Array`` ``log_likelihood`` that matches the NumPy-path scalar. -""" -from autofit.jax.pytrees import enable_pytrees, register_model - -enable_pytrees() -register_model(model) - -instance = model.instance_from_prior_medians() - -analysis_np = al.AnalysisImaging( - dataset=dataset, - positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], - use_jax=False, -) -fit_np = analysis_np.fit_from(instance=instance) -print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) - -analysis_jit = al.AnalysisImaging( - dataset=dataset, - positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], - use_jax=True, -) -fit_jit_fn = jax.jit(analysis_jit.fit_from) -fit = fit_jit_fn(instance) - -print("JIT fit.log_likelihood:", fit.log_likelihood) -assert isinstance( - fit.log_likelihood, jnp.ndarray -), f"expected jax.Array, got {type(fit.log_likelihood)}" -np.testing.assert_allclose( - float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 -) -print("PASS: jit(fit_from) round-trip matches NumPy scalar.") +""" +Func Grad: Light Parametric Operated +==================================== + +This script test if JAX can successfully compute the gradient of the log likelihood of an `Imaging` dataset with a +model which uses operated light profiles. + + __Operated Fitting__ + +It is common for galaxies to have point-source emission, for example bright emission right at their centre due to +an active galactic nuclei or very compact knot of star formation. + +This point-source emission is subject to blurring during data accquisiton due to the telescope optics, and therefore +is not seen as a single pixel of light but spread over multiple pixels as a convolution with the telescope +Point Spread Function (PSF). + +It is difficult to model this compact point source emission using a point-source light profile (or an extremely +compact Gaussian / Sersic profile). This is because when the model-image of a compact point source of light is +convolved with the PSF, the solution to this convolution is extremely sensitive to which pixel (and sub-pixel) the +compact model emission lands in. + +Operated light profiles offer an alternative approach, whereby the light profile is assumed to have already been +convolved with the PSF. This operated light profile is then fitted directly to the point-source emission, which as +discussed above shows the PSF features. +""" + +# %matplotlib inline +# from pyprojroot import here +# workspace_path = str(here()) +# %cd $workspace_path +# print(f"Working Directory has been set to `{workspace_path}`") + +import numpy as np +import jax.numpy as jnp +import jax +from jax import grad +from os import path + +import autofit as af +import autolens as al +from autoconf import conf + +""" +__Dataset__ + +Load and plot the galaxy dataset via .fits files. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +""" +__Dataset Auto-Simulation__ + +If the dataset does not already exist on your system, it will be created by running the corresponding +simulator script. This ensures that all example scripts can be run without manually simulating data first. +""" +if al.util.dataset.should_simulate(dataset_path): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = al.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + + +""" +__Mask__ + +The model-fit requires a 2D mask defining the regions of the image we fit the model to the data, which we define +and use to set up the `Imaging` object that the model fits. +""" +mask_radius = 3.0 + +mask = al.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +dataset = dataset.apply_over_sampling(over_sample_size_lp=1) + +positions = al.Grid2DIrregular( + al.from_json(file_path=path.join(dataset_path, "positions.json")) +) + +# over_sample_size = al.util.over_sample.over_sample_size_via_radial_bins_from( +# grid=dataset.grid, +# sub_size_list=[4, 2, 1], +# radial_list=[0.3, 0.6], +# centre_list=[(0.0, 0.0)], +# ) +# +# dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) +# + +""" +__Model__ + +We compose our model using `Model` objects, which represent the galaxies we fit to our data. In this +example we fit a model where: + + - The galaxy's bulge is a parametric `Sersic` bulge [7 parameters]. + - The galaxy's point source emission is a parametric operated `Gaussian` centred on the bulge [4 parameters]. + +The number of free parameters and therefore the dimensionality of non-linear parameter space is N=11. +""" +# Lens: + +bulge = af.Model(al.lp_linear.Sersic) + +mass = af.Model(al.mp.PowerLaw) + +shear = af.Model(al.mp.ExternalShear) + +lens = af.Model( + al.Galaxy, + redshift=0.5, + bulge=bulge, + mass=mass, + shear=shear, +) + +# Source: + +bulge = af.Model(al.lp_linear.Sersic) + +source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge) + +# Overall Lens Model: + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + +""" +The `info` attribute shows the model in a readable format. +""" +print(model.info) + +""" +__Analysis__ + +The `AnalysisImaging` object defines the `log_likelihood_function` which will be used to determine if JAX +can compute its gradient. +""" +analysis = al.AnalysisImaging( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + # settings=al.Settings(use_positive_only_solver=False) +) + + +""" +The analysis and `log_likelihood_function` are internally wrapped into a `Fitness` class in **PyAutoFit**, which pairs +the model with likelihood. + +This is the function on which JAX gradients are computed, so we create this class here. +""" +from autofit.non_linear.fitness import Fitness +import time + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +param_vector = jnp.array(model.physical_values_from_prior_medians) + +parameters = np.zeros((batch_size, model.total_free_parameters)) + +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians + +parameters = jnp.array(parameters) + +start = time.time() +print() +print(fitness._vmap(parameters)) +print("JAX Time To VMAP + JIT Function", time.time() - start) + +start = time.time() +print() +result = fitness._vmap(parameters) +print(result) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +np.testing.assert_allclose( + np.array(result), + -1.34797827e09, + rtol=1e-4, + err_msg="lp: JAX vmap likelihood mismatch", +) + + +""" +__Path A: jit-wrap ``analysis.fit_from``__ + +Wrap ``analysis.fit_from`` in ``jax.jit`` and assert the returned ``FitImaging`` +has a ``jax.Array`` ``log_likelihood`` that matches the NumPy-path scalar. +""" + + +instance = model.instance_from_prior_medians() + +analysis_np = al.AnalysisImaging( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + use_jax=False, +) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = al.AnalysisImaging( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + use_jax=True, +) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance( + fit.log_likelihood, jnp.ndarray +), f"expected jax.Array, got {type(fit.log_likelihood)}" +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/mge.py b/scripts/jax_likelihood_functions/imaging/mge.py index b3ce64e2..922e2e20 100644 --- a/scripts/jax_likelihood_functions/imaging/mge.py +++ b/scripts/jax_likelihood_functions/imaging/mge.py @@ -1,276 +1,273 @@ -""" -Func Grad: Light Parametric Operated -==================================== - -This script test if JAX can successfully compute the gradient of the log likelihood of an `Imaging` dataset with a -model which uses operated light profiles. - - __Operated Fitting__ - -It is common for galaxies to have point-source emission, for example bright emission right at their centre due to -an active galactic nuclei or very compact knot of star formation. - -This point-source emission is subject to blurring during data accquisiton due to the telescope optics, and therefore -is not seen as a single pixel of light but spread over multiple pixels as a convolution with the telescope -Point Spread Function (PSF). - -It is difficult to model this compact point source emission using a point-source light profile (or an extremely -compact Gaussian / Sersic profile). This is because when the model-image of a compact point source of light is -convolved with the PSF, the solution to this convolution is extremely sensitive to which pixel (and sub-pixel) the -compact model emission lands in. - -Operated light profiles offer an alternative approach, whereby the light profile is assumed to have already been -convolved with the PSF. This operated light profile is then fitted directly to the point-source emission, which as -discussed above shows the PSF features. -""" - -# %matplotlib inline -# from pyprojroot import here -# workspace_path = str(here()) -# %cd $workspace_path -# print(f"Working Directory has been set to `{workspace_path}`") - -import numpy as np -import jax -import jax.numpy as jnp -from jax import grad -from os import path - -import autofit as af -import autolens as al -from autoconf import conf - - -""" -__Dataset__ - -Load and plot the galaxy dataset via .fits files. -""" -dataset_path = path.join("dataset", "imaging", "jax_test") - -""" -__Dataset Auto-Simulation__ - -If the dataset does not already exist on your system, it will be created by running the corresponding -simulator script. This ensures that all example scripts can be run without manually simulating data first. -""" -if al.util.dataset.should_simulate(dataset_path): - import subprocess - import sys - - subprocess.run( - [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], - check=True, - ) - -dataset = al.Imaging.from_fits( - data_path=path.join(dataset_path, "data.fits"), - psf_path=path.join(dataset_path, "psf.fits"), - noise_map_path=path.join(dataset_path, "noise_map.fits"), - pixel_scales=0.2, -) - -""" -__Mask__ - -The model-fit requires a 2D mask defining the regions of the image we fit the model to the data, which we define -and use to set up the `Imaging` object that the model fits. -""" -mask_radius = 3.5 - -mask = al.Mask2D.circular( - shape_native=dataset.shape_native, - pixel_scales=dataset.pixel_scales, - radius=mask_radius, -) - -dataset = dataset.apply_mask(mask=mask) - -dataset = dataset.apply_over_sampling(over_sample_size_lp=4) - -# positions = al.Grid2DIrregular( -# al.from_json(file_path=path.join(dataset_path, "positions.json")) -# ) - -over_sample_size = al.util.over_sample.over_sample_size_via_radial_bins_from( - grid=dataset.grid, - sub_size_list=[4, 2, 1], - radial_list=[0.3, 0.6], - centre_list=[(0.0, 0.0)], -) - -dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) - - -""" -__Model__ - -We compose our model using `Model` objects, which represent the galaxies we fit to our data. In this -example we fit a model where: - - - The galaxy's bulge is a parametric `Sersic` bulge [7 parameters]. - - The galaxy's point source emission is a parametric operated `Gaussian` centred on the bulge [4 parameters]. - -The number of free parameters and therefore the dimensionality of non-linear parameter space is N=11. -""" -# Lens: - -bulge = al.model_util.mge_model_from( - mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True -) - -# mass = af.Model(al.mp.Gaussian) - -mass = af.Model(al.mp.NFWSph) - -total_gaussians = 3 - -# The sigma values of the Gaussians will be fixed to values spanning 0.01 to the mask radius, 3.0". -mask_radius = 3.0 -log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) - -# By defining the centre here, it creates two free parameters that are assigned below to all Gaussians. - -centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) -centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) - -bulge_gaussian_list = [] - -gaussian_list = af.Collection( - af.Model(al.lmp_linear.GaussianGradient) for _ in range(total_gaussians) -) - -for i, gaussian in enumerate(gaussian_list): - gaussian.centre.centre_0 = centre_0 # All Gaussians have same y centre. - gaussian.centre.centre_1 = centre_1 # All Gaussians have same x centre. - gaussian.ell_comps = gaussian_list[ - 0 - ].ell_comps # All Gaussians have same elliptical components. - gaussian.sigma = ( - 10 ** log10_sigma_list[i] - ) # All Gaussian sigmas are fixed to values above. - gaussian.mass_to_light_ratio = 10.0 - gaussian.mass_to_light_gradient = 1.0 - -bulge_gaussian_list += gaussian_list - -# The Basis object groups many light profiles together into a single model component. - -bulge = af.Model( - al.lp_basis.Basis, - profile_list=bulge_gaussian_list, -) - -shear = af.Model(al.mp.ExternalShear) - -lens = af.Model(al.Galaxy, redshift=0.5, bulge=bulge, mass=mass, shear=shear) - -# Source: - -total_gaussians = 30 -gaussian_per_basis = 1 - -# By defining the centre here, it creates two free parameters that are assigned to the source Gaussians. - -bulge = al.model_util.mge_model_from( - mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False -) - -source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge) - -# Overall Lens Model: - -model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) - -""" -The `info` attribute shows the model in a readable format. -""" -print(model.info) - -""" -__Analysis__ - -The `AnalysisImaging` object defines the `log_likelihood_function` which will be used to determine if JAX -can compute its gradient. -""" -analysis = al.AnalysisImaging( - dataset=dataset, -) - - -""" -The analysis and `log_likelihood_function` are internally wrapped into a `Fitness` class in **PyAutoFit**, which pairs -the model with likelihood. - -This is the function on which JAX gradients are computed, so we create this class here. -""" -from autofit.non_linear.fitness import Fitness -import time - -batch_size = 3 - -fitness = Fitness( - model=model, - analysis=analysis, - fom_is_log_likelihood=True, - resample_figure_of_merit=-1.0e99, -) - -param_vector = jnp.array(model.physical_values_from_prior_medians) - -parameters = np.zeros((batch_size, model.total_free_parameters)) - -for i in range(batch_size): - parameters[i, :] = model.physical_values_from_prior_medians - -parameters = jnp.array(parameters) - -start = time.time() -print() -print(fitness._vmap(parameters)) -print("JAX Time To VMAP + JIT Function", time.time() - start) - -start = time.time() -print() -result = fitness._vmap(parameters) -print(result) -print("JAX Time Taken using VMAP:", time.time() - start) -print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) - -np.testing.assert_allclose( - np.array(result), - -86629.349379, - rtol=1e-4, - err_msg="mge: JAX vmap likelihood mismatch", -) - - -""" -__Path A: jit-wrap ``analysis.fit_from``__ - -Wrap ``analysis.fit_from`` in ``jax.jit`` and assert the returned ``FitImaging`` -has a ``jax.Array`` ``log_likelihood`` that matches the NumPy-path scalar. -""" -from autofit.jax.pytrees import enable_pytrees, register_model - -enable_pytrees() -register_model(model) - -instance = model.instance_from_prior_medians() - -analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False) -fit_np = analysis_np.fit_from(instance=instance) -print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) - -analysis_jit = al.AnalysisImaging(dataset=dataset, use_jax=True) -fit_jit_fn = jax.jit(analysis_jit.fit_from) -fit = fit_jit_fn(instance) - -print("JIT fit.log_likelihood:", fit.log_likelihood) -assert isinstance( - fit.log_likelihood, jnp.ndarray -), f"expected jax.Array, got {type(fit.log_likelihood)}" -np.testing.assert_allclose( - float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 -) -print("PASS: jit(fit_from) round-trip matches NumPy scalar.") +""" +Func Grad: Light Parametric Operated +==================================== + +This script test if JAX can successfully compute the gradient of the log likelihood of an `Imaging` dataset with a +model which uses operated light profiles. + + __Operated Fitting__ + +It is common for galaxies to have point-source emission, for example bright emission right at their centre due to +an active galactic nuclei or very compact knot of star formation. + +This point-source emission is subject to blurring during data accquisiton due to the telescope optics, and therefore +is not seen as a single pixel of light but spread over multiple pixels as a convolution with the telescope +Point Spread Function (PSF). + +It is difficult to model this compact point source emission using a point-source light profile (or an extremely +compact Gaussian / Sersic profile). This is because when the model-image of a compact point source of light is +convolved with the PSF, the solution to this convolution is extremely sensitive to which pixel (and sub-pixel) the +compact model emission lands in. + +Operated light profiles offer an alternative approach, whereby the light profile is assumed to have already been +convolved with the PSF. This operated light profile is then fitted directly to the point-source emission, which as +discussed above shows the PSF features. +""" + +# %matplotlib inline +# from pyprojroot import here +# workspace_path = str(here()) +# %cd $workspace_path +# print(f"Working Directory has been set to `{workspace_path}`") + +import numpy as np +import jax +import jax.numpy as jnp +from jax import grad +from os import path + +import autofit as af +import autolens as al +from autoconf import conf + + +""" +__Dataset__ + +Load and plot the galaxy dataset via .fits files. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +""" +__Dataset Auto-Simulation__ + +If the dataset does not already exist on your system, it will be created by running the corresponding +simulator script. This ensures that all example scripts can be run without manually simulating data first. +""" +if al.util.dataset.should_simulate(dataset_path): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = al.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +""" +__Mask__ + +The model-fit requires a 2D mask defining the regions of the image we fit the model to the data, which we define +and use to set up the `Imaging` object that the model fits. +""" +mask_radius = 3.5 + +mask = al.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +dataset = dataset.apply_over_sampling(over_sample_size_lp=4) + +# positions = al.Grid2DIrregular( +# al.from_json(file_path=path.join(dataset_path, "positions.json")) +# ) + +over_sample_size = al.util.over_sample.over_sample_size_via_radial_bins_from( + grid=dataset.grid, + sub_size_list=[4, 2, 1], + radial_list=[0.3, 0.6], + centre_list=[(0.0, 0.0)], +) + +dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) + + +""" +__Model__ + +We compose our model using `Model` objects, which represent the galaxies we fit to our data. In this +example we fit a model where: + + - The galaxy's bulge is a parametric `Sersic` bulge [7 parameters]. + - The galaxy's point source emission is a parametric operated `Gaussian` centred on the bulge [4 parameters]. + +The number of free parameters and therefore the dimensionality of non-linear parameter space is N=11. +""" +# Lens: + +bulge = al.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True +) + +# mass = af.Model(al.mp.Gaussian) + +mass = af.Model(al.mp.NFWSph) + +total_gaussians = 3 + +# The sigma values of the Gaussians will be fixed to values spanning 0.01 to the mask radius, 3.0". +mask_radius = 3.0 +log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + +# By defining the centre here, it creates two free parameters that are assigned below to all Gaussians. + +centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +bulge_gaussian_list = [] + +gaussian_list = af.Collection( + af.Model(al.lmp_linear.GaussianGradient) for _ in range(total_gaussians) +) + +for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = centre_0 # All Gaussians have same y centre. + gaussian.centre.centre_1 = centre_1 # All Gaussians have same x centre. + gaussian.ell_comps = gaussian_list[ + 0 + ].ell_comps # All Gaussians have same elliptical components. + gaussian.sigma = ( + 10 ** log10_sigma_list[i] + ) # All Gaussian sigmas are fixed to values above. + gaussian.mass_to_light_ratio = 10.0 + gaussian.mass_to_light_gradient = 1.0 + +bulge_gaussian_list += gaussian_list + +# The Basis object groups many light profiles together into a single model component. + +bulge = af.Model( + al.lp_basis.Basis, + profile_list=bulge_gaussian_list, +) + +shear = af.Model(al.mp.ExternalShear) + +lens = af.Model(al.Galaxy, redshift=0.5, bulge=bulge, mass=mass, shear=shear) + +# Source: + +total_gaussians = 30 +gaussian_per_basis = 1 + +# By defining the centre here, it creates two free parameters that are assigned to the source Gaussians. + +bulge = al.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False +) + +source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge) + +# Overall Lens Model: + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + +""" +The `info` attribute shows the model in a readable format. +""" +print(model.info) + +""" +__Analysis__ + +The `AnalysisImaging` object defines the `log_likelihood_function` which will be used to determine if JAX +can compute its gradient. +""" +analysis = al.AnalysisImaging( + dataset=dataset, +) + + +""" +The analysis and `log_likelihood_function` are internally wrapped into a `Fitness` class in **PyAutoFit**, which pairs +the model with likelihood. + +This is the function on which JAX gradients are computed, so we create this class here. +""" +from autofit.non_linear.fitness import Fitness +import time + +batch_size = 3 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +param_vector = jnp.array(model.physical_values_from_prior_medians) + +parameters = np.zeros((batch_size, model.total_free_parameters)) + +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians + +parameters = jnp.array(parameters) + +start = time.time() +print() +print(fitness._vmap(parameters)) +print("JAX Time To VMAP + JIT Function", time.time() - start) + +start = time.time() +print() +result = fitness._vmap(parameters) +print(result) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +np.testing.assert_allclose( + np.array(result), + -86629.349379, + rtol=1e-4, + err_msg="mge: JAX vmap likelihood mismatch", +) + + +""" +__Path A: jit-wrap ``analysis.fit_from``__ + +Wrap ``analysis.fit_from`` in ``jax.jit`` and assert the returned ``FitImaging`` +has a ``jax.Array`` ``log_likelihood`` that matches the NumPy-path scalar. +""" + + +instance = model.instance_from_prior_medians() + +analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = al.AnalysisImaging(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance( + fit.log_likelihood, jnp.ndarray +), f"expected jax.Array, got {type(fit.log_likelihood)}" +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/mge_group.py b/scripts/jax_likelihood_functions/imaging/mge_group.py index d4bac633..48713206 100644 --- a/scripts/jax_likelihood_functions/imaging/mge_group.py +++ b/scripts/jax_likelihood_functions/imaging/mge_group.py @@ -314,10 +314,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/rectangular.py b/scripts/jax_likelihood_functions/imaging/rectangular.py index 20e37662..3ae56e41 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular.py @@ -268,10 +268,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/rectangular_dspl.py b/scripts/jax_likelihood_functions/imaging/rectangular_dspl.py index 0a5d373a..eb99c9a7 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular_dspl.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular_dspl.py @@ -283,10 +283,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/rectangular_mge.py b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py index 26c41ab4..966c5aee 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py @@ -306,10 +306,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/subhalo.py b/scripts/jax_likelihood_functions/imaging/subhalo.py index bbfa2e8b..2c4fc0c9 100644 --- a/scripts/jax_likelihood_functions/imaging/subhalo.py +++ b/scripts/jax_likelihood_functions/imaging/subhalo.py @@ -191,11 +191,7 @@ def build_model(redshift_subhalo, subhalo_mass_factory): """ from autofit.non_linear.fitness import Fitness -from autofit.jax.pytrees import enable_pytrees, register_model -# enable_pytrees once globally so all scenarios benefit from it for the -# single-instance jit wrap. ``register_model`` is called per-scenario below. -enable_pytrees() def run_scenario( @@ -214,7 +210,6 @@ def run_scenario( print("=" * 72) model = build_model(redshift_subhalo, subhalo_mass_factory) - register_model(model) analysis = al.AnalysisImaging( dataset=dataset, diff --git a/scripts/jax_likelihood_functions/interferometer/delaunay.py b/scripts/jax_likelihood_functions/interferometer/delaunay.py index 10eccd8b..0b69adcf 100644 --- a/scripts/jax_likelihood_functions/interferometer/delaunay.py +++ b/scripts/jax_likelihood_functions/interferometer/delaunay.py @@ -239,10 +239,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py b/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py index fa35834c..9fcf5ef7 100644 --- a/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py @@ -247,10 +247,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/lp.py b/scripts/jax_likelihood_functions/interferometer/lp.py index 140bd356..3793925a 100644 --- a/scripts/jax_likelihood_functions/interferometer/lp.py +++ b/scripts/jax_likelihood_functions/interferometer/lp.py @@ -186,10 +186,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/mge.py b/scripts/jax_likelihood_functions/interferometer/mge.py index 0d4261cd..58eee918 100644 --- a/scripts/jax_likelihood_functions/interferometer/mge.py +++ b/scripts/jax_likelihood_functions/interferometer/mge.py @@ -202,10 +202,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/mge_group.py b/scripts/jax_likelihood_functions/interferometer/mge_group.py index cfb60e9d..3c63c32b 100644 --- a/scripts/jax_likelihood_functions/interferometer/mge_group.py +++ b/scripts/jax_likelihood_functions/interferometer/mge_group.py @@ -131,10 +131,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular.py b/scripts/jax_likelihood_functions/interferometer/rectangular.py index 57782666..de5da478 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular.py @@ -272,10 +272,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular_dspl.py b/scripts/jax_likelihood_functions/interferometer/rectangular_dspl.py index 45af0774..fdc1797a 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular_dspl.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular_dspl.py @@ -231,10 +231,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py b/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py index 9ccbae48..1a4c1c03 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py @@ -231,10 +231,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py b/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py index 57f6525f..1e00f75c 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py @@ -222,10 +222,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/light_multipole/multipole.py b/scripts/jax_likelihood_functions/light_multipole/multipole.py index 574e1da7..e3c9af7d 100644 --- a/scripts/jax_likelihood_functions/light_multipole/multipole.py +++ b/scripts/jax_likelihood_functions/light_multipole/multipole.py @@ -149,10 +149,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/multi/dataset_model.py b/scripts/jax_likelihood_functions/multi/dataset_model.py index fc62ddc3..427b0ff1 100644 --- a/scripts/jax_likelihood_functions/multi/dataset_model.py +++ b/scripts/jax_likelihood_functions/multi/dataset_model.py @@ -185,10 +185,7 @@ ``register_model`` is what registers ``DatasetModel`` (along with every other class in the model tree) as a JAX pytree. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ al.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/delaunay.py b/scripts/jax_likelihood_functions/multi/delaunay.py index c6fa37ee..18d29457 100644 --- a/scripts/jax_likelihood_functions/multi/delaunay.py +++ b/scripts/jax_likelihood_functions/multi/delaunay.py @@ -215,10 +215,7 @@ """ __Path A: jit-wrap parameter-vector entry point__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/multi/delaunay_mge.py b/scripts/jax_likelihood_functions/multi/delaunay_mge.py index 8d9fbdba..5153c664 100644 --- a/scripts/jax_likelihood_functions/multi/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/multi/delaunay_mge.py @@ -213,10 +213,7 @@ """ __Path A: jit-wrap parameter-vector entry point__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/multi/lp.py b/scripts/jax_likelihood_functions/multi/lp.py index dafd35a9..1113bc22 100644 --- a/scripts/jax_likelihood_functions/multi/lp.py +++ b/scripts/jax_likelihood_functions/multi/lp.py @@ -167,10 +167,7 @@ on the instance, and JAX pytree-flattens the whole instance and chokes on that non-registered leaf. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ al.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/mge_group.py b/scripts/jax_likelihood_functions/multi/mge_group.py index fe7eef0e..bae81274 100644 --- a/scripts/jax_likelihood_functions/multi/mge_group.py +++ b/scripts/jax_likelihood_functions/multi/mge_group.py @@ -239,10 +239,7 @@ """ __Path A: jit-wrap parameter-vector entry point__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ al.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/rectangular.py b/scripts/jax_likelihood_functions/multi/rectangular.py index be6ff13f..9c9a6253 100644 --- a/scripts/jax_likelihood_functions/multi/rectangular.py +++ b/scripts/jax_likelihood_functions/multi/rectangular.py @@ -211,10 +211,7 @@ """ __Path A: jit-wrap parameter-vector entry point__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/multi/rectangular_mge.py b/scripts/jax_likelihood_functions/multi/rectangular_mge.py index 528b49ba..4864fdff 100644 --- a/scripts/jax_likelihood_functions/multi/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/multi/rectangular_mge.py @@ -196,10 +196,7 @@ """ __Path A: jit-wrap parameter-vector entry point__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) @jax.jit diff --git a/scripts/jax_likelihood_functions/point_source/image_plane.py b/scripts/jax_likelihood_functions/point_source/image_plane.py index bf4a19da..ebe9cfdd 100644 --- a/scripts/jax_likelihood_functions/point_source/image_plane.py +++ b/scripts/jax_likelihood_functions/point_source/image_plane.py @@ -145,12 +145,9 @@ as ``point.py``): the cosmology distance calc caches intermediate values in global state, triggering ``UnexpectedTracerError`` under ``jit``. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() model_jit = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model_jit) instance = model_jit.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/point_source/point.py b/scripts/jax_likelihood_functions/point_source/point.py index 8bc47b76..28312c63 100644 --- a/scripts/jax_likelihood_functions/point_source/point.py +++ b/scripts/jax_likelihood_functions/point_source/point.py @@ -248,12 +248,9 @@ the vmap path above handles it fine. Once that library-level leak is fixed, this block can reuse ``model`` directly. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() model_jit = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model_jit) instance = model_jit.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/point_source/source_plane.py b/scripts/jax_likelihood_functions/point_source/source_plane.py index 48a7c24f..abb5e442 100644 --- a/scripts/jax_likelihood_functions/point_source/source_plane.py +++ b/scripts/jax_likelihood_functions/point_source/source_plane.py @@ -161,12 +161,9 @@ xp-propagation bug. The eager NumPy log-likelihood is still asserted for regression coverage. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() model_jit = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model_jit) instance = model_jit.instance_from_prior_medians() diff --git a/scripts/point_source/modeling_visualization_jit.py b/scripts/point_source/modeling_visualization_jit.py index 9fb1de8f..0716050c 100644 --- a/scripts/point_source/modeling_visualization_jit.py +++ b/scripts/point_source/modeling_visualization_jit.py @@ -3,17 +3,16 @@ ============================================================================ Exercises the full JAX visualization pipeline for the point-source analysis -path: ``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)`` with -an ``Isothermal`` lens mass and ``PointFlux`` source (image-plane chi-squared -via ``FitPositionsImagePairAll``). +path: ``AnalysisPoint(use_jax=True)`` with an ``Isothermal`` lens mass and +``PointFlux`` source (image-plane chi-squared via +``FitPositionsImagePairAll``). This test runs in two parts: Part 1 — **Caching probe.** Calls ``analysis.fit_for_visualization(instance)`` twice and asserts the second call is much faster than the first (confirming the compiled function is cached on the analysis instance, not recompiled per -visualization call). Also asserts ``analysis._jitted_fit_from is not None`` -after the first call. +visualization call). Part 2 — **Live Nautilus quick-update.** Runs a real (short) Nautilus fit. The live search fires quick-update visualization every @@ -23,8 +22,8 @@ search callback. This script deliberately opts in with -``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)``. -Default model-fit scripts elsewhere in the workspace leave both flags at +``AnalysisPoint(use_jax=True)``. +Default model-fit scripts elsewhere in the workspace leave the flag at ``False`` and are therefore untouched. """ @@ -38,9 +37,7 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -102,14 +99,12 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) analysis = al.AnalysisPoint( dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True, - use_jax_for_visualization=True, ) instance = model.instance_from_prior_medians() @@ -137,9 +132,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: Point-source jit-cached fit_for_visualization works and is reused.") @@ -202,8 +194,8 @@ Part 2 — Live Nautilus quick-update ============================================================================ -Rebuild the model fresh (register_model on the new instance), create a -separate analysis object, and run a short Nautilus fit. The search fires +Rebuild the model fresh, create a separate analysis object, and run a short +Nautilus fit. The search fires quick-update visualization every ``iterations_per_quick_update`` calls; we assert that ``fit.png`` lands on disk under the Nautilus output tree. """ @@ -228,14 +220,12 @@ model2 = af.Collection(galaxies=af.Collection(lens=lens2, source=source2)) -register_model(model2) analysis_run = al.AnalysisPoint( dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True, - use_jax_for_visualization=True, ) output_root = Path("scripts") / "point_source" / "images" / "modeling_visualization_jit" @@ -271,10 +261,6 @@ f"no fit.png produced under {output_search_root} — " "quick-update visualization did not fire" ) -assert ( - analysis_run._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance during search" - print( "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " f"for point source, fit.png written." diff --git a/scripts/point_source/visualization_jax.py b/scripts/point_source/visualization_jax.py index c35e6c91..846ef7f9 100644 --- a/scripts/point_source/visualization_jax.py +++ b/scripts/point_source/visualization_jax.py @@ -6,12 +6,12 @@ Goal ---- -Run ``VisualizerPoint.visualize`` with ``use_jax=True`` and -``use_jax_for_visualization=True`` on ``AnalysisPoint``. The point -visualizer dispatches through ``analysis.fit_for_visualization``, which -lazily wraps ``fit_from`` in ``jax.jit``. To trace across that boundary the -model and fit return type must be JAX pytrees, so this script enables pytree -registration before constructing the model. +Run ``VisualizerPoint.visualize`` with ``use_jax=True`` on ``AnalysisPoint``. +Visualization now follows ``use_jax`` automatically — the point visualizer +dispatches through ``analysis.fit_for_visualization``, which lazily wraps +``fit_from`` in ``jax.jit``. To trace across that boundary the model and fit +return type must be JAX pytrees, so this script enables pytree registration +before constructing the model. Scope ----- @@ -27,10 +27,8 @@ import autofit as af import autolens as al -from autofit.jax.pytrees import enable_pytrees, register_model from autolens.point.model.visualizer import VisualizerPoint -enable_pytrees() """ @@ -86,15 +84,13 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX ``_xp`` path; -``use_jax_for_visualization=True`` tells the visualization path to wrap -``fit_from`` in ``jax.jit`` via ``Analysis.fit_for_visualization``. +``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows +``use_jax`` automatically via ``Analysis.fit_for_visualization``. ``title_prefix`` is passed through via PR #506's **kwargs fix. """ analysis = al.AnalysisPoint( @@ -102,7 +98,6 @@ solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -124,7 +119,7 @@ """ instance = model.instance_from_prior_medians() -print("Running VisualizerPoint.visualize with use_jax_for_visualization=True ...") +print("Running VisualizerPoint.visualize with use_jax=True ...") VisualizerPoint.visualize( analysis=analysis, paths=paths,