Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions autolens/imaging/model/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from autolens.imaging.fit_imaging import FitImaging
from autolens.imaging.plot.fit_imaging_plots import (
subplot_fit,
subplot_fit_quick,
subplot_fit_log10,
subplot_of_planes,
subplot_tracer_from_fit,
Expand Down Expand Up @@ -75,7 +76,16 @@ def should_plot(name):
source_plane_lines, source_plane_line_colors,
)

if should_plot("subplot_fit") or quick_update:
if quick_update:
subplot_fit_quick(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
title_prefix=self.title_prefix,
)
return

if should_plot("subplot_fit"):

if len(fit.tracer.planes) > 2:
for plane_index in plane_indexes_to_plot:
Expand All @@ -94,9 +104,6 @@ def should_plot(name):
title_prefix=self.title_prefix,
)

if quick_update:
return

if plot_setting(section="tracer", name="subplot_tracer"):
subplot_tracer_from_fit(
fit, output_path=output_path, output_format=fmt,
Expand Down
102 changes: 102 additions & 0 deletions autolens/imaging/plot/fit_imaging_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,108 @@ def subplot_fit(
save_figure(fig, path=output_path, filename=f"fit{plane_index_tag}", format=output_format)


def subplot_fit_quick(
fit,
output_path: Optional[str] = None,
output_format: str = None,
colormap: Optional[str] = None,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
title_prefix: str = None,
):
"""
Produce a 6-panel quick-update subplot summarising an imaging fit.

Arranges the following panels in a 2 × 3 grid:

* Data
* Model image
* Normalised residual map (symmetric scale)
* Lens-light-subtracted image
* Source model image
* Source plane image (mid zoom)

This is a lighter alternative to :func:`subplot_fit` (12 panels)
intended for the quick-update visualization path during sampling,
where render speed matters more than completeness.

For single-plane tracers the function delegates to
:func:`subplot_fit_x1_plane`.
"""
if len(fit.tracer.planes) == 1:
return subplot_fit_x1_plane(
fit, output_path=output_path,
output_format=output_format, colormap=colormap,
title_prefix=title_prefix,
)

final_plane_index = len(fit.tracer.planes) - 1
source_vmax = _get_source_vmax(fit)

_pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t)
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
axes_flat = list(axes.flatten())

# Top row: Data, Model Image, Normalized Residual Map
plot_array(
array=fit.data, ax=axes_flat[0], title=_pf("Data"), colormap=colormap,
)

plot_array(
array=fit.model_data, ax=axes_flat[1], title=_pf("Model Image"),
colormap=colormap, lines=image_plane_lines,
line_colors=image_plane_line_colors,
)

norm_resid = fit.normalized_residual_map
_abs_max = _symmetric_vmax(norm_resid)
plot_array(
array=norm_resid, ax=axes_flat[2], title=_pf("Normalized Residual Map"),
colormap=colormap, vmin=-_abs_max, vmax=_abs_max,
)

# Bottom row: Lens Light Subtracted, Source Model Image, Source Plane (Mid Zoom)
try:
subtracted_img = fit.subtracted_images_of_planes_list[final_plane_index]
except (IndexError, AttributeError):
subtracted_img = None
if subtracted_img is not None:
plot_array(
array=subtracted_img, ax=axes_flat[3],
title=_pf("Lens Light Subtracted"), colormap=colormap,
vmin=0.0 if source_vmax is not None else None, vmax=source_vmax,
)
else:
axes_flat[3].axis("off")

try:
source_model_img = fit.model_images_of_planes_list[final_plane_index]
except (IndexError, AttributeError):
source_model_img = None
if source_model_img is not None:
plot_array(
array=source_model_img, ax=axes_flat[4],
title=_pf("Source Model Image"), colormap=colormap,
vmax=source_vmax, lines=image_plane_lines,
line_colors=image_plane_line_colors,
)
else:
axes_flat[4].axis("off")

_plot_source_plane(
fit, axes_flat[5], final_plane_index, zoom_to_brightest=True,
colormap=colormap, title=_pf("Source Plane (Mid Zoom)"),
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax, zoom_extent_scale=2.0,
)

hide_unused_axes(axes_flat)
tight_layout()
save_figure(fig, path=output_path, filename="fit_quick", format=output_format, dpi=200)


def subplot_fit_x1_plane(
fit,
output_path: Optional[str] = None,
Expand Down
72 changes: 72 additions & 0 deletions autolens/lens/substructure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,75 @@ def simulate_substructure(
image_2d = image_2d - background_sky_level

return image_2d


def los_realizations_to_arrays(
realization_galaxies,
plane_redshifts,
max_n,
profile_cls,
):
import jax.numpy as jnp

all_params = []
all_masks = []
all_kappas = []

for galaxies in realization_galaxies:
params, mask, kappas = galaxies_to_halo_arrays(
galaxies=galaxies,
plane_redshifts=plane_redshifts,
max_n=max_n,
profile_cls=profile_cls,
)
all_params.append(params)
all_masks.append(mask)
all_kappas.append(kappas)

return jnp.stack(all_params), jnp.stack(all_masks), jnp.stack(all_kappas)


def batched_simulate_substructure(
grid,
image_shape,
halo_params_batch,
halo_mask_batch,
scaling_matrix,
macro_deflections_fn,
macro_plane_mask,
sheet_kappas_batch,
source_image_fn,
psf_kernel,
exposure_time,
background_sky_level,
prng_keys,
halo_profile_cls,
):
import jax
import functools

single_fn = functools.partial(
simulate_substructure,
grid=grid,
image_shape=image_shape,
scaling_matrix=scaling_matrix,
macro_deflections_fn=macro_deflections_fn,
macro_plane_mask=macro_plane_mask,
source_image_fn=source_image_fn,
psf_kernel=psf_kernel,
exposure_time=exposure_time,
background_sky_level=background_sky_level,
halo_profile_cls=halo_profile_cls,
)

def call(halo_params, halo_mask, sheet_kappas, prng_key):
return single_fn(
halo_params=halo_params,
halo_mask=halo_mask,
sheet_kappas=sheet_kappas,
prng_key=prng_key,
)

return jax.vmap(call)(
halo_params_batch, halo_mask_batch, sheet_kappas_batch, prng_keys,
)
Loading