From 608534a579a7da5de928dd40bb37bafd6328b934 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 27 May 2026 12:58:57 +0100 Subject: [PATCH] Remove use_jax_for_visualization from all scripts The flag has been removed from Analysis.__init__ in PyAutoFit PR #1297. Visualization now follows use_jax automatically. Also removes enable_pytrees() / register_model() calls that were only needed for the jax.jit(fit_from) visualization path which no longer exists. Co-Authored-By: Claude Opus 4.7 --- scripts/ellipse/modeling_visualization_jit.py | 15 ++------- scripts/ellipse/visualization_jax.py | 31 ++++++++----------- scripts/imaging/modeling_visualization_jit.py | 15 ++------- scripts/imaging/visualization_jax.py | 25 ++++++--------- .../modeling_visualization_jit.py | 15 ++------- scripts/interferometer/visualization_jax.py | 28 +++++++---------- .../jax_likelihood_functions/ellipse/fit.py | 3 -- .../ellipse/multipoles.py | 3 -- .../ellipse/multipoles_scaled.py | 3 -- .../imaging/delaunay.py | 3 -- .../imaging/delaunay_mge.py | 3 -- .../jax_likelihood_functions/imaging/lp.py | 3 -- .../jax_likelihood_functions/imaging/mge.py | 3 -- .../imaging/mge_group.py | 3 -- .../imaging/rectangular.py | 3 -- .../imaging/rectangular_mge.py | 3 -- .../interferometer/delaunay.py | 3 -- .../interferometer/delaunay_mge.py | 3 -- .../interferometer/lp.py | 3 -- .../interferometer/mge.py | 3 -- .../interferometer/mge_group.py | 3 -- .../interferometer/rectangular.py | 3 -- .../interferometer/rectangular_mge.py | 3 -- .../light_multipole/multipole.py | 3 -- .../multi/dataset_model.py | 3 -- .../multi/delaunay.py | 3 -- .../multi/delaunay_mge.py | 3 -- scripts/jax_likelihood_functions/multi/lp.py | 3 -- scripts/jax_likelihood_functions/multi/mge.py | 3 -- .../multi/mge_group.py | 3 -- .../multi/rectangular.py | 3 -- .../multi/rectangular_mge.py | 3 -- 32 files changed, 43 insertions(+), 164 deletions(-) diff --git a/scripts/ellipse/modeling_visualization_jit.py b/scripts/ellipse/modeling_visualization_jit.py index 58131b0..761af78 100644 --- a/scripts/ellipse/modeling_visualization_jit.py +++ b/scripts/ellipse/modeling_visualization_jit.py @@ -19,9 +19,9 @@ correctly during the live search callback. This script deliberately opts in with -``AnalysisEllipse(use_jax=True, use_jax_for_visualization=True)``. -Default ellipse model-fit scripts elsewhere in the workspace leave both -flags at ``False`` and are therefore untouched by this change. +``AnalysisEllipse(use_jax=True)``. +Default ellipse model-fit scripts elsewhere in the workspace leave the flag +at ``False`` and are therefore untouched by this change. """ import shutil @@ -37,9 +37,7 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -92,12 +90,10 @@ model_mge = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_mge)) -register_model(model_mge) analysis_mge = ag.AnalysisEllipse( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) instance_mge = model_mge.instance_from_prior_medians() @@ -125,9 +121,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_mge._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: Ellipse jit-cached fit_for_visualization works and is reused.") @@ -177,12 +170,10 @@ model_mge2 = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_2)) -register_model(model_mge2) analysis_mge2 = ag.AnalysisEllipse( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) output_root = Path("scripts") / "ellipse" / "images" / "modeling_visualization_jit" diff --git a/scripts/ellipse/visualization_jax.py b/scripts/ellipse/visualization_jax.py index e470b39..7b13c7c 100644 --- a/scripts/ellipse/visualization_jax.py +++ b/scripts/ellipse/visualization_jax.py @@ -2,12 +2,11 @@ Visualization JAX Pilot: Ellipse Analysis (autogalaxy) ====================================================== -Tests that ``VisualizerEllipse.visualize`` with -``use_jax_for_visualization=True`` dispatches through the JIT-cached -``fit_for_visualization`` path that the parent ``af.Analysis`` already -provides. ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its -parent, so ``use_jax_for_visualization=True`` flows through to the -PyAutoFit-level dispatch without a library-side change. +Tests that ``VisualizerEllipse.visualize`` with ``use_jax=True`` dispatches +through the JIT-cached ``fit_for_visualization`` path that the parent +``af.Analysis`` already provides. Visualization follows ``use_jax`` +automatically — ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its +parent without a library-side change needed. Scope ----- @@ -16,8 +15,8 @@ - Calls ``VisualizerEllipse.visualize`` only (not ``visualize_before_fit``). - Reuses the ``dataset/imaging/jax_test`` dataset that the ``jax_likelihood_functions`` scripts produce. -- ``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True`` - routes ``Visualizer*.visualize`` through ``analysis.fit_for_visualization``. +- ``use_jax=True`` turns on the JAX path and routes ``Visualizer*.visualize`` + through ``analysis.fit_for_visualization``. """ import shutil @@ -36,10 +35,8 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model from autogalaxy.ellipse.model.visualizer import VisualizerEllipse -enable_pytrees() """ @@ -85,22 +82,20 @@ model = af.Collection(ellipses=af.Collection(ellipse_0=ellipse)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True`` -tells the visualizer to dispatch through the JIT-cached -``fit_for_visualization`` helper on the parent ``af.Analysis``. -``AnalysisEllipse.__init__`` accepts ``**kwargs`` and forwards them to -``super().__init__``, so no AnalysisEllipse signature change is needed. +``use_jax=True`` turns on the JAX path. Visualization follows ``use_jax`` +automatically via the JIT-cached ``fit_for_visualization`` helper on the +parent ``af.Analysis``. ``AnalysisEllipse.__init__`` accepts ``**kwargs`` +and forwards them to ``super().__init__``, so no AnalysisEllipse signature +change is needed. """ analysis = ag.AnalysisEllipse( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -122,7 +117,7 @@ """ instance = model.instance_from_prior_medians() -print("Running VisualizerEllipse.visualize with use_jax_for_visualization=True ...") +print("Running VisualizerEllipse.visualize with use_jax=True ...") VisualizerEllipse.visualize( analysis=analysis, paths=paths, diff --git a/scripts/imaging/modeling_visualization_jit.py b/scripts/imaging/modeling_visualization_jit.py index e5fed30..2612f03 100644 --- a/scripts/imaging/modeling_visualization_jit.py +++ b/scripts/imaging/modeling_visualization_jit.py @@ -22,9 +22,9 @@ during the live search callback. This script deliberately opts in with -``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default -model-fit scripts elsewhere in the workspace leave both flags at ``False`` -and are therefore untouched by this change. +``AnalysisImaging(use_jax=True)``. Default model-fit scripts elsewhere in the +workspace leave the flag at ``False`` and are therefore untouched by this +change. """ import shutil @@ -38,9 +38,7 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -108,12 +106,10 @@ model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge)) -register_model(model_mge) analysis_mge = ag.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) instance_mge = model_mge.instance_from_prior_medians() @@ -141,9 +137,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_mge._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: MGE jit-cached fit_for_visualization works and is reused.") @@ -213,12 +206,10 @@ model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2)) -register_model(model_mge2) analysis_mge2 = ag.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit" diff --git a/scripts/imaging/visualization_jax.py b/scripts/imaging/visualization_jax.py index 30157e3..a1f28fc 100644 --- a/scripts/imaging/visualization_jax.py +++ b/scripts/imaging/visualization_jax.py @@ -7,13 +7,13 @@ Goal ---- -Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind -``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoGalaxy #390 -(2026-05-08) the imaging visualizer dispatches through -``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in -``jax.jit``. To trace across that boundary the model and fit return type -must be JAX pytrees, so this script enables pytree registration before -constructing the model. Parametric MGE galaxy — simplest case (no +Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end via +``use_jax=True`` on ``Analysis``. After PyAutoGalaxy #390 (2026-05-08) the +imaging visualizer dispatches through ``analysis.fit_for_visualization``, +which lazily wraps ``fit_from`` in ``jax.jit``. Visualization now follows +``use_jax`` automatically. To trace across that boundary the model and fit +return type must be JAX pytrees, so this script enables pytree registration +before constructing the model. Parametric MGE galaxy — simplest case (no pixelization, no inversion). Scope @@ -39,10 +39,8 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model from autogalaxy.imaging.model.visualizer import VisualizerImaging -enable_pytrees() """ @@ -87,20 +85,17 @@ galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge) model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` -tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` -via the ``Analysis.fit_for_visualization`` helper. +``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows +``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper. """ analysis = ag.AnalysisImaging( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -122,7 +117,7 @@ """ instance = model.instance_from_prior_medians() -print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...") +print("Running VisualizerImaging.visualize with use_jax=True ...") VisualizerImaging.visualize( analysis=analysis, paths=paths, diff --git a/scripts/interferometer/modeling_visualization_jit.py b/scripts/interferometer/modeling_visualization_jit.py index b5b0a62..b9d21bc 100644 --- a/scripts/interferometer/modeling_visualization_jit.py +++ b/scripts/interferometer/modeling_visualization_jit.py @@ -24,9 +24,9 @@ during the live search callback. This script deliberately opts in with -``AnalysisInterferometer(use_jax=True, use_jax_for_visualization=True)``. -Default model-fit scripts elsewhere in the workspace leave both flags at -``False`` and are therefore untouched by this change. +``AnalysisInterferometer(use_jax=True)``. Default model-fit scripts elsewhere +in the workspace leave the flag at ``False`` and are therefore untouched by +this change. """ import shutil @@ -40,9 +40,7 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() """ @@ -114,12 +112,10 @@ model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge)) -register_model(model_mge) analysis_mge = ag.AnalysisInterferometer( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) instance_mge = model_mge.instance_from_prior_medians() @@ -147,9 +143,6 @@ f"Cached call ({cached_time:.3f}s) not faster than compile " f"({compile_time:.3f}s) — JIT cache is not being hit." ) -assert ( - analysis_mge._jitted_fit_from is not None -), "expected _jitted_fit_from to be cached on the analysis instance after first call" print("PASS: MGE jit-cached fit_for_visualization works and is reused.") @@ -224,12 +217,10 @@ model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2)) -register_model(model_mge2) analysis_mge2 = ag.AnalysisInterferometer( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, ) output_root = ( diff --git a/scripts/interferometer/visualization_jax.py b/scripts/interferometer/visualization_jax.py index 4b56b55..a7300dd 100644 --- a/scripts/interferometer/visualization_jax.py +++ b/scripts/interferometer/visualization_jax.py @@ -7,12 +7,13 @@ Goal ---- -Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end, gated -behind ``use_jax_for_visualization=True`` on ``AnalysisInterferometer``. The -interferometer visualizer dispatches through ``analysis.fit_for_visualization``, -which lazily wraps ``fit_from`` in ``jax.jit``. To trace across that boundary -the model and fit return type must be JAX pytrees, so this script enables -pytree registration before constructing the model. +Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end via +``use_jax=True`` on ``AnalysisInterferometer``. Visualization now follows +``use_jax`` automatically — the interferometer visualizer dispatches through +``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in +``jax.jit``. To trace across that boundary the model and fit return type must +be JAX pytrees, so this script enables pytree registration before constructing +the model. Scope ----- @@ -30,10 +31,8 @@ import autofit as af import autogalaxy as ag -from autofit.jax.pytrees import enable_pytrees, register_model from autogalaxy.interferometer.model.visualizer import VisualizerInterferometer -enable_pytrees() """ @@ -75,9 +74,7 @@ """ __Model__ -Single-galaxy MGE parametric model. Pytree registration is required before -constructing the model so that the model and fit return type are valid JAX -pytrees at the ``jax.jit`` boundary. +Single-galaxy MGE parametric model. """ bulge = ag.model_util.mge_model_from( mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True @@ -85,20 +82,17 @@ galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) -register_model(model) """ __Analysis__ -``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` -tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` -via the ``Analysis.fit_for_visualization`` helper. +``use_jax=True`` turns on the JAX ``_xp`` path. Visualization now follows +``use_jax`` automatically via the ``Analysis.fit_for_visualization`` helper. """ analysis = ag.AnalysisInterferometer( dataset=dataset, use_jax=True, - use_jax_for_visualization=True, title_prefix="JAX_PILOT", ) @@ -121,7 +115,7 @@ instance = model.instance_from_prior_medians() print( - "Running VisualizerInterferometer.visualize with use_jax_for_visualization=True ..." + "Running VisualizerInterferometer.visualize with use_jax=True ..." ) VisualizerInterferometer.visualize( analysis=analysis, diff --git a/scripts/jax_likelihood_functions/ellipse/fit.py b/scripts/jax_likelihood_functions/ellipse/fit.py index a3e3e40..7fc8449 100644 --- a/scripts/jax_likelihood_functions/ellipse/fit.py +++ b/scripts/jax_likelihood_functions/ellipse/fit.py @@ -131,10 +131,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) fit_jit_fn = jax.jit(analysis_jit.fit_from) diff --git a/scripts/jax_likelihood_functions/ellipse/multipoles.py b/scripts/jax_likelihood_functions/ellipse/multipoles.py index 3026108..0d2cbcb 100644 --- a/scripts/jax_likelihood_functions/ellipse/multipoles.py +++ b/scripts/jax_likelihood_functions/ellipse/multipoles.py @@ -143,10 +143,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) fit_jit_fn = jax.jit(analysis_jit.fit_from) diff --git a/scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py b/scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py index 164aaf6..10372c8 100644 --- a/scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py +++ b/scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py @@ -145,10 +145,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) fit_jit_fn = jax.jit(analysis_jit.fit_from) diff --git a/scripts/jax_likelihood_functions/imaging/delaunay.py b/scripts/jax_likelihood_functions/imaging/delaunay.py index 5cc8960..0ff2fbf 100644 --- a/scripts/jax_likelihood_functions/imaging/delaunay.py +++ b/scripts/jax_likelihood_functions/imaging/delaunay.py @@ -158,10 +158,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/delaunay_mge.py b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py index cebfd3e..04feecd 100644 --- a/scripts/jax_likelihood_functions/imaging/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py @@ -175,10 +175,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/lp.py b/scripts/jax_likelihood_functions/imaging/lp.py index 4a11956..66a3232 100644 --- a/scripts/jax_likelihood_functions/imaging/lp.py +++ b/scripts/jax_likelihood_functions/imaging/lp.py @@ -104,10 +104,7 @@ with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. This is the part unblocked by ``_register_fit_imaging_pytrees``. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/mge.py b/scripts/jax_likelihood_functions/imaging/mge.py index 2a7dead..091810f 100644 --- a/scripts/jax_likelihood_functions/imaging/mge.py +++ b/scripts/jax_likelihood_functions/imaging/mge.py @@ -117,10 +117,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/mge_group.py b/scripts/jax_likelihood_functions/imaging/mge_group.py index 97a1452..454ea4c 100644 --- a/scripts/jax_likelihood_functions/imaging/mge_group.py +++ b/scripts/jax_likelihood_functions/imaging/mge_group.py @@ -179,10 +179,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/rectangular.py b/scripts/jax_likelihood_functions/imaging/rectangular.py index 86e3397..3d4b191 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular.py @@ -127,10 +127,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/imaging/rectangular_mge.py b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py index 2936b26..c16e77b 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py @@ -152,10 +152,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/delaunay.py b/scripts/jax_likelihood_functions/interferometer/delaunay.py index dc15bd7..33280a1 100644 --- a/scripts/jax_likelihood_functions/interferometer/delaunay.py +++ b/scripts/jax_likelihood_functions/interferometer/delaunay.py @@ -165,10 +165,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py b/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py index d3cef0e..694ecf9 100644 --- a/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/interferometer/delaunay_mge.py @@ -175,10 +175,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/lp.py b/scripts/jax_likelihood_functions/interferometer/lp.py index af9ea27..20fccc4 100644 --- a/scripts/jax_likelihood_functions/interferometer/lp.py +++ b/scripts/jax_likelihood_functions/interferometer/lp.py @@ -109,10 +109,7 @@ NumPy-path scalar. This is the part unblocked by ``_register_fit_interferometer_pytrees``. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/mge.py b/scripts/jax_likelihood_functions/interferometer/mge.py index 0b30e12..5905939 100644 --- a/scripts/jax_likelihood_functions/interferometer/mge.py +++ b/scripts/jax_likelihood_functions/interferometer/mge.py @@ -115,10 +115,7 @@ NumPy-path scalar. This is the part unblocked by ``_register_fit_interferometer_pytrees``. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/mge_group.py b/scripts/jax_likelihood_functions/interferometer/mge_group.py index ad28dbe..3579bcc 100644 --- a/scripts/jax_likelihood_functions/interferometer/mge_group.py +++ b/scripts/jax_likelihood_functions/interferometer/mge_group.py @@ -180,10 +180,7 @@ ``FitInterferometer`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular.py b/scripts/jax_likelihood_functions/interferometer/rectangular.py index 7db3a89..0acc437 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular.py @@ -132,10 +132,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py b/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py index bee6513..84fccdc 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular_mge.py @@ -147,10 +147,7 @@ """ __Path A: jit-wrap ``analysis.fit_from``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/light_multipole/multipole.py b/scripts/jax_likelihood_functions/light_multipole/multipole.py index e43b7f8..fb86064 100644 --- a/scripts/jax_likelihood_functions/light_multipole/multipole.py +++ b/scripts/jax_likelihood_functions/light_multipole/multipole.py @@ -119,10 +119,7 @@ Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(model) instance = model.instance_from_prior_medians() diff --git a/scripts/jax_likelihood_functions/multi/dataset_model.py b/scripts/jax_likelihood_functions/multi/dataset_model.py index 95ba35e..4673d59 100644 --- a/scripts/jax_likelihood_functions/multi/dataset_model.py +++ b/scripts/jax_likelihood_functions/multi/dataset_model.py @@ -167,10 +167,7 @@ """ __Path A: jit-wrap ``factor_graph.log_likelihood_function``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/delaunay.py b/scripts/jax_likelihood_functions/multi/delaunay.py index 7f10a88..19756eb 100644 --- a/scripts/jax_likelihood_functions/multi/delaunay.py +++ b/scripts/jax_likelihood_functions/multi/delaunay.py @@ -206,10 +206,7 @@ drift — same as the single-dataset autogalaxy ``imaging/delaunay.py`` and ``interferometer/delaunay.py``, so the rtol=1e-2 convention applies. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging( diff --git a/scripts/jax_likelihood_functions/multi/delaunay_mge.py b/scripts/jax_likelihood_functions/multi/delaunay_mge.py index 32b1407..d7fe158 100644 --- a/scripts/jax_likelihood_functions/multi/delaunay_mge.py +++ b/scripts/jax_likelihood_functions/multi/delaunay_mge.py @@ -222,10 +222,7 @@ drift — same as the single-dataset autogalaxy ``imaging/delaunay_mge.py`` and ``interferometer/delaunay_mge.py``, so the rtol=1e-2 convention applies. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging( diff --git a/scripts/jax_likelihood_functions/multi/lp.py b/scripts/jax_likelihood_functions/multi/lp.py index 92c339b..fe3f296 100644 --- a/scripts/jax_likelihood_functions/multi/lp.py +++ b/scripts/jax_likelihood_functions/multi/lp.py @@ -157,10 +157,7 @@ on the instance, and JAX pytree-flattens the whole instance and chokes on that non-registered leaf. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/mge.py b/scripts/jax_likelihood_functions/multi/mge.py index b0c558b..ad34bee 100644 --- a/scripts/jax_likelihood_functions/multi/mge.py +++ b/scripts/jax_likelihood_functions/multi/mge.py @@ -162,10 +162,7 @@ on the instance, and JAX pytree-flattens the whole instance and chokes on that non-registered leaf. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/mge_group.py b/scripts/jax_likelihood_functions/multi/mge_group.py index 9eeec76..a16bed8 100644 --- a/scripts/jax_likelihood_functions/multi/mge_group.py +++ b/scripts/jax_likelihood_functions/multi/mge_group.py @@ -199,10 +199,7 @@ """ __Path A: jit-wrap ``factor_graph.log_likelihood_function``__ """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging(dataset=dataset, use_jax=False) for dataset in dataset_list diff --git a/scripts/jax_likelihood_functions/multi/rectangular.py b/scripts/jax_likelihood_functions/multi/rectangular.py index d21b7ac..291e4c2 100644 --- a/scripts/jax_likelihood_functions/multi/rectangular.py +++ b/scripts/jax_likelihood_functions/multi/rectangular.py @@ -192,10 +192,7 @@ single-dataset autogalaxy ``imaging/rectangular.py`` and ``interferometer/rectangular.py``, so the rtol=1e-2 convention applies. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging( diff --git a/scripts/jax_likelihood_functions/multi/rectangular_mge.py b/scripts/jax_likelihood_functions/multi/rectangular_mge.py index 4d95635..7b2f97a 100644 --- a/scripts/jax_likelihood_functions/multi/rectangular_mge.py +++ b/scripts/jax_likelihood_functions/multi/rectangular_mge.py @@ -199,10 +199,7 @@ and ``interferometer/rectangular_mge.py``, so the rtol=1e-2 convention applies. """ -from autofit.jax.pytrees import enable_pytrees, register_model -enable_pytrees() -register_model(factor_graph.global_prior_model) analysis_np_list = [ ag.AnalysisImaging(