diff --git a/autolens/interferometer/model/plotter.py b/autolens/interferometer/model/plotter.py index 53caf0572..04c83586d 100644 --- a/autolens/interferometer/model/plotter.py +++ b/autolens/interferometer/model/plotter.py @@ -13,6 +13,7 @@ from autolens.interferometer.plot.fit_interferometer_plots import ( subplot_fit, subplot_fit_dirty_images, + subplot_fit_interferometer_combined, subplot_fit_real_space, subplot_tracer_from_fit, _compute_critical_curve_lines, @@ -106,3 +107,40 @@ def should_plot(name): if should_plot("fits_dirty_images"): fits_dirty_images(fit=fit, output_path=self.image_path) + + def fit_interferometer_combined( + self, + fit_list, + quick_update: bool = False, + ): + """ + Output visualization of all `FitInterferometer` objects in a summed combined + analysis (e.g. an ALMA datacube modelled as a list of channels via + `af.FactorGraphModel`). + + Outputs ``fit_combined.png`` in the visualisation directory: a row-per-channel + subplot showing dirty image, dirty model image, source-plane reconstruction + and dirty normalised residual map for every channel side by side. + + Parameters + ---------- + fit_list + The list of interferometer fits which are visualized. + quick_update + If ``True``, only the combined dirty-image subplot is written (no extra + log-stretched variants), so this is safe to call from the search's + quick-update hook. + """ + def should_plot(name): + return plot_setting(section=["fit", "fit_interferometer"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_fit") or quick_update: + subplot_fit_interferometer_combined( + fit_list, + output_path=output_path, + output_format=fmt, + title_prefix=self.title_prefix, + ) diff --git a/autolens/interferometer/model/visualizer.py b/autolens/interferometer/model/visualizer.py index 7363b411b..3202c60e7 100644 --- a/autolens/interferometer/model/visualizer.py +++ b/autolens/interferometer/model/visualizer.py @@ -161,3 +161,56 @@ def visualize( ) except IndexError: pass + + @staticmethod + def visualize_combined( + analyses, + paths: af.AbstractPaths, + instance: af.ModelInstance, + during_analysis: bool, + quick_update: bool = False, + ): + """ + Performs visualization during the non-linear search of information that is + shared across all per-channel interferometer analyses, on a single multi-row + figure. Used for ALMA-style datacube fits where each channel is its own + ``Interferometer`` dataset wrapped in an ``af.AnalysisFactor``. + + Outputs ``fit_combined.png``: a row-per-channel subplot showing dirty image, + dirty model image, source-plane reconstruction and dirty normalised residual + map. The plot makes it easy to see how an emission line's source-plane + morphology shifts across the cube while the lens model stays fixed. + + Parameters + ---------- + analyses + The list of all per-channel ``AnalysisInterferometer`` objects. + paths + The paths object which manages where visualisation is written to. + instance + A ``Collection`` of per-factor model instances. Iterating it yields one + ``ModelInstance`` per channel, in the same order as ``analyses``. + during_analysis + ``True`` when called during the non-linear search, ``False`` when + called after the search completes. + quick_update + ``True`` when called from the search's quick-update hook between + iterations; only the headline combined plot is written in that case. + """ + + if analyses is None: + return + + plotter = PlotterInterferometer( + image_path=paths.image_path, title_prefix=analyses[0].title_prefix + ) + + fit_list = [ + analysis.fit_for_visualization(instance=single_instance) + for analysis, single_instance in zip(analyses, instance) + ] + + plotter.fit_interferometer_combined( + fit_list=fit_list, + quick_update=quick_update, + ) diff --git a/autolens/interferometer/plot/fit_interferometer_plots.py b/autolens/interferometer/plot/fit_interferometer_plots.py index 8b33ffe9d..7f906dde7 100644 --- a/autolens/interferometer/plot/fit_interferometer_plots.py +++ b/autolens/interferometer/plot/fit_interferometer_plots.py @@ -328,6 +328,102 @@ def subplot_fit_dirty_images( save_figure(fig, path=output_path, filename="fit_dirty_images", format=output_format) +def subplot_fit_interferometer_combined( + fit_list, + output_path: Optional[str] = None, + output_format: str = None, + colormap: Optional[str] = None, + title_prefix: str = None, +): + """ + Produce a combined multi-row subplot for a list of `FitInterferometer` objects. + + Each row corresponds to one channel of a datacube (or one dataset of a multi-band + interferometer fit) and contains four panels: + + * Dirty Image (data) + * Dirty Model Image (with critical curves) + * Source Plane (reconstruction) + * Dirty Normalised Residual Map + + The layout mirrors :func:`subplot_fit_combined` for imaging — same purpose, + different panel choice because interferometer fits are most informatively + visualised in dirty-image space. + + Parameters + ---------- + fit_list : list of FitInterferometer + The interferometer fits to display. Each fit occupies one row of the figure. + output_path : str, optional + Directory in which to save the figure. If ``None`` the figure is not saved. + output_format : str, optional + Image format passed to :func:`~autoarray.plot.utils.save_figure`. + colormap : str, optional + Matplotlib colormap name applied to all image panels. + title_prefix : str, optional + Optional prefix prepended to every panel title. + """ + n_fits = len(fit_list) + n_cols = 4 + fig, axes = subplots(n_fits, n_cols, figsize=conf_subplot_figsize(n_fits, n_cols)) + if n_fits == 1: + all_axes = [list(axes)] + else: + all_axes = [list(axes[i]) for i in range(n_fits)] + + final_plane_index = len(fit_list[0].tracer.planes) - 1 + + _pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t) + for row, fit in enumerate(fit_list): + row_axes = all_axes[row] + + tracer = fit.tracer_linear_light_profiles_to_light_profiles + cc_grid = fit.dataset.real_space_mask.derive_grid.all_false + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines( + tracer, cc_grid + ) + + plot_array( + array=fit.dirty_image, + ax=row_axes[0], + title=_pf(f"Dirty Image (ch {row})"), + colormap=colormap, + ) + + plot_array( + array=fit.dirty_model_image, + ax=row_axes[1], + title=_pf("Dirty Model Image"), + colormap=colormap, + lines=ip_lines, + line_colors=ip_colors, + ) + + try: + _plot_source_plane( + fit, + row_axes[2], + final_plane_index, + colormap=colormap, + title=_pf(f"Source Plane {final_plane_index}"), + lines=sp_lines, + line_colors=sp_colors, + ) + except Exception: + row_axes[2].axis("off") + + plot_array( + array=fit.dirty_normalized_residual_map, + ax=row_axes[3], + title=_pf("Dirty Norm Residual"), + colormap=colormap, + cb_unit=r"$\sigma$", + ) + + tight_layout() + save_figure(fig, path=output_path, filename="fit_combined", format=output_format) + + def subplot_fit_real_space( fit, output_path: Optional[str] = None,