diff --git a/autolens/imaging/model/plotter.py b/autolens/imaging/model/plotter.py index 061cbad73..a3b7722ae 100644 --- a/autolens/imaging/model/plotter.py +++ b/autolens/imaging/model/plotter.py @@ -13,6 +13,7 @@ from autolens.imaging.fit_imaging import FitImaging from autolens.imaging.plot.fit_imaging_plots import ( subplot_fit, + subplot_fit_quick, subplot_fit_log10, subplot_of_planes, subplot_tracer_from_fit, @@ -75,7 +76,16 @@ def should_plot(name): source_plane_lines, source_plane_line_colors, ) - if should_plot("subplot_fit") or quick_update: + if quick_update: + subplot_fit_quick( + fit, output_path=output_path, output_format=fmt, + image_plane_lines=ip_lines, image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, source_plane_line_colors=sp_colors, + title_prefix=self.title_prefix, + ) + return + + if should_plot("subplot_fit"): if len(fit.tracer.planes) > 2: for plane_index in plane_indexes_to_plot: @@ -94,9 +104,6 @@ def should_plot(name): title_prefix=self.title_prefix, ) - if quick_update: - return - if plot_setting(section="tracer", name="subplot_tracer"): subplot_tracer_from_fit( fit, output_path=output_path, output_format=fmt, diff --git a/autolens/imaging/plot/fit_imaging_plots.py b/autolens/imaging/plot/fit_imaging_plots.py index b9692cf93..9fd147be6 100644 --- a/autolens/imaging/plot/fit_imaging_plots.py +++ b/autolens/imaging/plot/fit_imaging_plots.py @@ -358,6 +358,108 @@ def subplot_fit( save_figure(fig, path=output_path, filename=f"fit{plane_index_tag}", format=output_format) +def subplot_fit_quick( + fit, + output_path: Optional[str] = None, + output_format: str = None, + colormap: Optional[str] = None, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, + title_prefix: str = None, +): + """ + Produce a 6-panel quick-update subplot summarising an imaging fit. + + Arranges the following panels in a 2 × 3 grid: + + * Data + * Model image + * Normalised residual map (symmetric scale) + * Lens-light-subtracted image + * Source model image + * Source plane image (mid zoom) + + This is a lighter alternative to :func:`subplot_fit` (12 panels) + intended for the quick-update visualization path during sampling, + where render speed matters more than completeness. + + For single-plane tracers the function delegates to + :func:`subplot_fit_x1_plane`. + """ + if len(fit.tracer.planes) == 1: + return subplot_fit_x1_plane( + fit, output_path=output_path, + output_format=output_format, colormap=colormap, + title_prefix=title_prefix, + ) + + final_plane_index = len(fit.tracer.planes) - 1 + source_vmax = _get_source_vmax(fit) + + _pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t) + fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3)) + axes_flat = list(axes.flatten()) + + # Top row: Data, Model Image, Normalized Residual Map + plot_array( + array=fit.data, ax=axes_flat[0], title=_pf("Data"), colormap=colormap, + ) + + plot_array( + array=fit.model_data, ax=axes_flat[1], title=_pf("Model Image"), + colormap=colormap, lines=image_plane_lines, + line_colors=image_plane_line_colors, + ) + + norm_resid = fit.normalized_residual_map + _abs_max = _symmetric_vmax(norm_resid) + plot_array( + array=norm_resid, ax=axes_flat[2], title=_pf("Normalized Residual Map"), + colormap=colormap, vmin=-_abs_max, vmax=_abs_max, + ) + + # Bottom row: Lens Light Subtracted, Source Model Image, Source Plane (Mid Zoom) + try: + subtracted_img = fit.subtracted_images_of_planes_list[final_plane_index] + except (IndexError, AttributeError): + subtracted_img = None + if subtracted_img is not None: + plot_array( + array=subtracted_img, ax=axes_flat[3], + title=_pf("Lens Light Subtracted"), colormap=colormap, + vmin=0.0 if source_vmax is not None else None, vmax=source_vmax, + ) + else: + axes_flat[3].axis("off") + + try: + source_model_img = fit.model_images_of_planes_list[final_plane_index] + except (IndexError, AttributeError): + source_model_img = None + if source_model_img is not None: + plot_array( + array=source_model_img, ax=axes_flat[4], + title=_pf("Source Model Image"), colormap=colormap, + vmax=source_vmax, lines=image_plane_lines, + line_colors=image_plane_line_colors, + ) + else: + axes_flat[4].axis("off") + + _plot_source_plane( + fit, axes_flat[5], final_plane_index, zoom_to_brightest=True, + colormap=colormap, title=_pf("Source Plane (Mid Zoom)"), + lines=source_plane_lines, line_colors=source_plane_line_colors, + vmax=source_vmax, zoom_extent_scale=2.0, + ) + + hide_unused_axes(axes_flat) + tight_layout() + save_figure(fig, path=output_path, filename="fit_quick", format=output_format, dpi=200) + + def subplot_fit_x1_plane( fit, output_path: Optional[str] = None, diff --git a/autolens/lens/substructure_util.py b/autolens/lens/substructure_util.py index 564231639..9555b4d69 100644 --- a/autolens/lens/substructure_util.py +++ b/autolens/lens/substructure_util.py @@ -168,3 +168,75 @@ def simulate_substructure( image_2d = image_2d - background_sky_level return image_2d + + +def los_realizations_to_arrays( + realization_galaxies, + plane_redshifts, + max_n, + profile_cls, +): + import jax.numpy as jnp + + all_params = [] + all_masks = [] + all_kappas = [] + + for galaxies in realization_galaxies: + params, mask, kappas = galaxies_to_halo_arrays( + galaxies=galaxies, + plane_redshifts=plane_redshifts, + max_n=max_n, + profile_cls=profile_cls, + ) + all_params.append(params) + all_masks.append(mask) + all_kappas.append(kappas) + + return jnp.stack(all_params), jnp.stack(all_masks), jnp.stack(all_kappas) + + +def batched_simulate_substructure( + grid, + image_shape, + halo_params_batch, + halo_mask_batch, + scaling_matrix, + macro_deflections_fn, + macro_plane_mask, + sheet_kappas_batch, + source_image_fn, + psf_kernel, + exposure_time, + background_sky_level, + prng_keys, + halo_profile_cls, +): + import jax + import functools + + single_fn = functools.partial( + simulate_substructure, + grid=grid, + image_shape=image_shape, + scaling_matrix=scaling_matrix, + macro_deflections_fn=macro_deflections_fn, + macro_plane_mask=macro_plane_mask, + source_image_fn=source_image_fn, + psf_kernel=psf_kernel, + exposure_time=exposure_time, + background_sky_level=background_sky_level, + halo_profile_cls=halo_profile_cls, + ) + + def call(halo_params, halo_mask, sheet_kappas, prng_key): + return single_fn( + halo_params=halo_params, + halo_mask=halo_mask, + sheet_kappas=sheet_kappas, + prng_key=prng_key, + ) + + return jax.vmap(call)( + halo_params_batch, halo_mask_batch, sheet_kappas_batch, prng_keys, + )