Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions autolens/interferometer/model/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from autolens.interferometer.plot.fit_interferometer_plots import (
subplot_fit,
subplot_fit_dirty_images,
subplot_fit_interferometer_combined,
subplot_fit_real_space,
subplot_tracer_from_fit,
_compute_critical_curve_lines,
Expand Down Expand Up @@ -106,3 +107,40 @@ def should_plot(name):

if should_plot("fits_dirty_images"):
fits_dirty_images(fit=fit, output_path=self.image_path)

def fit_interferometer_combined(
self,
fit_list,
quick_update: bool = False,
):
"""
Output visualization of all `FitInterferometer` objects in a summed combined
analysis (e.g. an ALMA datacube modelled as a list of channels via
`af.FactorGraphModel`).

Outputs ``fit_combined.png`` in the visualisation directory: a row-per-channel
subplot showing dirty image, dirty model image, source-plane reconstruction
and dirty normalised residual map for every channel side by side.

Parameters
----------
fit_list
The list of interferometer fits which are visualized.
quick_update
If ``True``, only the combined dirty-image subplot is written (no extra
log-stretched variants), so this is safe to call from the search's
quick-update hook.
"""
def should_plot(name):
return plot_setting(section=["fit", "fit_interferometer"], name=name)

output_path = str(self.image_path)
fmt = self.fmt

if should_plot("subplot_fit") or quick_update:
subplot_fit_interferometer_combined(
fit_list,
output_path=output_path,
output_format=fmt,
title_prefix=self.title_prefix,
)
53 changes: 53 additions & 0 deletions autolens/interferometer/model/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,56 @@ def visualize(
)
except IndexError:
pass

@staticmethod
def visualize_combined(
analyses,
paths: af.AbstractPaths,
instance: af.ModelInstance,
during_analysis: bool,
quick_update: bool = False,
):
"""
Performs visualization during the non-linear search of information that is
shared across all per-channel interferometer analyses, on a single multi-row
figure. Used for ALMA-style datacube fits where each channel is its own
``Interferometer`` dataset wrapped in an ``af.AnalysisFactor``.

Outputs ``fit_combined.png``: a row-per-channel subplot showing dirty image,
dirty model image, source-plane reconstruction and dirty normalised residual
map. The plot makes it easy to see how an emission line's source-plane
morphology shifts across the cube while the lens model stays fixed.

Parameters
----------
analyses
The list of all per-channel ``AnalysisInterferometer`` objects.
paths
The paths object which manages where visualisation is written to.
instance
A ``Collection`` of per-factor model instances. Iterating it yields one
``ModelInstance`` per channel, in the same order as ``analyses``.
during_analysis
``True`` when called during the non-linear search, ``False`` when
called after the search completes.
quick_update
``True`` when called from the search's quick-update hook between
iterations; only the headline combined plot is written in that case.
"""

if analyses is None:
return

plotter = PlotterInterferometer(
image_path=paths.image_path, title_prefix=analyses[0].title_prefix
)

fit_list = [
analysis.fit_for_visualization(instance=single_instance)
for analysis, single_instance in zip(analyses, instance)
]

plotter.fit_interferometer_combined(
fit_list=fit_list,
quick_update=quick_update,
)
96 changes: 96 additions & 0 deletions autolens/interferometer/plot/fit_interferometer_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,102 @@ def subplot_fit_dirty_images(
save_figure(fig, path=output_path, filename="fit_dirty_images", format=output_format)


def subplot_fit_interferometer_combined(
fit_list,
output_path: Optional[str] = None,
output_format: str = None,
colormap: Optional[str] = None,
title_prefix: str = None,
):
"""
Produce a combined multi-row subplot for a list of `FitInterferometer` objects.

Each row corresponds to one channel of a datacube (or one dataset of a multi-band
interferometer fit) and contains four panels:

* Dirty Image (data)
* Dirty Model Image (with critical curves)
* Source Plane (reconstruction)
* Dirty Normalised Residual Map

The layout mirrors :func:`subplot_fit_combined` for imaging — same purpose,
different panel choice because interferometer fits are most informatively
visualised in dirty-image space.

Parameters
----------
fit_list : list of FitInterferometer
The interferometer fits to display. Each fit occupies one row of the figure.
output_path : str, optional
Directory in which to save the figure. If ``None`` the figure is not saved.
output_format : str, optional
Image format passed to :func:`~autoarray.plot.utils.save_figure`.
colormap : str, optional
Matplotlib colormap name applied to all image panels.
title_prefix : str, optional
Optional prefix prepended to every panel title.
"""
n_fits = len(fit_list)
n_cols = 4
fig, axes = subplots(n_fits, n_cols, figsize=conf_subplot_figsize(n_fits, n_cols))
if n_fits == 1:
all_axes = [list(axes)]
else:
all_axes = [list(axes[i]) for i in range(n_fits)]

final_plane_index = len(fit_list[0].tracer.planes) - 1

_pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t)
for row, fit in enumerate(fit_list):
row_axes = all_axes[row]

tracer = fit.tracer_linear_light_profiles_to_light_profiles
cc_grid = fit.dataset.real_space_mask.derive_grid.all_false
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(
tracer, cc_grid
)

plot_array(
array=fit.dirty_image,
ax=row_axes[0],
title=_pf(f"Dirty Image (ch {row})"),
colormap=colormap,
)

plot_array(
array=fit.dirty_model_image,
ax=row_axes[1],
title=_pf("Dirty Model Image"),
colormap=colormap,
lines=ip_lines,
line_colors=ip_colors,
)

try:
_plot_source_plane(
fit,
row_axes[2],
final_plane_index,
colormap=colormap,
title=_pf(f"Source Plane {final_plane_index}"),
lines=sp_lines,
line_colors=sp_colors,
)
except Exception:
row_axes[2].axis("off")

plot_array(
array=fit.dirty_normalized_residual_map,
ax=row_axes[3],
title=_pf("Dirty Norm Residual"),
colormap=colormap,
cb_unit=r"$\sigma$",
)

tight_layout()
save_figure(fig, path=output_path, filename="fit_combined", format=output_format)


def subplot_fit_real_space(
fit,
output_path: Optional[str] = None,
Expand Down
Loading