Skip to content

Commit cbdbb69

Browse files
authored
Merge pull request #547 from PyAutoLabs/feature/fast-quick-render
Fast subplot_fit_quick: sub-second rendering for quick updates
2 parents 0b45703 + ea3af0a commit cbdbb69

1 file changed

Lines changed: 152 additions & 57 deletions

File tree

autolens/imaging/plot/fit_imaging_plots.py

Lines changed: 152 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,42 @@ def subplot_fit(
358358
save_figure(fig, path=output_path, filename=f"fit{plane_index_tag}", format=output_format)
359359

360360

361+
def _to_native_np(array):
362+
"""Convert an autoarray Array2D to a plain numpy 2D array.
363+
364+
Uses direct boolean-mask indexing instead of the numba-accelerated
365+
``array_2d_via_indexes_from`` which has unexpected overhead (~0.4s
366+
per call for 15k pixels).
367+
"""
368+
try:
369+
mask = array.mask
370+
slim = np.asarray(array.array)
371+
native = np.zeros(mask.shape_native)
372+
native[~np.asarray(mask)] = slim
373+
return native
374+
except AttributeError:
375+
arr = np.asarray(array)
376+
if arr.ndim == 2:
377+
return arr
378+
return arr
379+
380+
381+
def _quick_imshow(ax, array_2d, title, extent, cmap, vmin=None, vmax=None):
382+
"""Minimal imshow for quick-update panels — no overlays, no colorbars."""
383+
import matplotlib.pyplot as plt
384+
385+
if array_2d is None:
386+
ax.axis("off")
387+
return
388+
im = ax.imshow(
389+
array_2d, cmap=cmap, vmin=vmin, vmax=vmax,
390+
extent=extent, aspect="auto", origin="lower",
391+
)
392+
ax.set_title(title, fontsize=8)
393+
ax.set_xticks([])
394+
ax.set_yticks([])
395+
396+
361397
def subplot_fit_quick(
362398
fit,
363399
output_path: Optional[str] = None,
@@ -379,15 +415,18 @@ def subplot_fit_quick(
379415
* Normalised residual map (symmetric scale)
380416
* Lens-light-subtracted image
381417
* Source model image
382-
* Source plane image (mid zoom)
418+
* Source plane image
383419
384-
This is a lighter alternative to :func:`subplot_fit` (12 panels)
385-
intended for the quick-update visualization path during sampling,
386-
where render speed matters more than completeness.
420+
Uses raw ``imshow`` calls on pre-converted numpy arrays to bypass
421+
the autoarray/autogalaxy plotting pipeline — the repeated
422+
``array.native`` conversions in that pipeline cost ~4s for pixelized
423+
source fits. This path renders in <1s.
387424
388425
For single-plane tracers the function delegates to
389426
:func:`subplot_fit_x1_plane`.
390427
"""
428+
import matplotlib.pyplot as plt
429+
391430
if len(fit.tracer.planes) == 1:
392431
return subplot_fit_x1_plane(
393432
fit, output_path=output_path,
@@ -396,68 +435,124 @@ def subplot_fit_quick(
396435
)
397436

398437
final_plane_index = len(fit.tracer.planes) - 1
399-
source_vmax = _get_source_vmax(fit)
400438

401-
_pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t)
402-
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
403-
axes_flat = list(axes.flatten())
439+
# Pre-convert all arrays to numpy 2D. model_data and
440+
# subtracted_images_of_planes_list are @property (not cached) — each
441+
# access recomputes the entire inversion. Access model_data ONCE and
442+
# derive everything else from the cached per-plane images.
443+
mask_bool = ~np.asarray(fit.mask)
444+
shape_native = fit.mask.shape_native
445+
446+
def _fill(slim):
447+
out = np.zeros(shape_native)
448+
out[mask_bool] = np.asarray(slim)
449+
return out
450+
451+
data_slim = np.asarray(fit.data)
452+
noise_slim = np.asarray(fit.noise_map)
453+
data = _fill(data_slim)
454+
455+
# model_images_of_planes_list triggers model_data once internally
456+
# and caches the per-plane images — much cheaper than accessing
457+
# model_data + subtracted_images separately (each recomputes).
458+
try:
459+
plane_images = fit.model_images_of_planes_list
460+
lens_model_slim = np.asarray(plane_images[0])
461+
source_model_slim = np.asarray(plane_images[final_plane_index])
462+
model_slim = lens_model_slim + source_model_slim
463+
except (IndexError, AttributeError):
464+
model_slim = np.asarray(fit.model_data)
465+
lens_model_slim = None
466+
source_model_slim = None
404467

405-
# Top row: Data, Model Image, Normalized Residual Map
406-
plot_array(
407-
array=fit.data, ax=axes_flat[0], title=_pf("Data"), colormap=colormap,
408-
)
468+
model = _fill(model_slim)
409469

410-
plot_array(
411-
array=fit.model_data, ax=axes_flat[1], title=_pf("Model Image"),
412-
colormap=colormap, lines=image_plane_lines,
413-
line_colors=image_plane_line_colors,
414-
)
470+
# Compute subtracted = data - lens_model (no property access)
471+
if lens_model_slim is not None:
472+
subtracted = _fill(data_slim - lens_model_slim)
473+
else:
474+
subtracted = None
415475

416-
norm_resid = fit.normalized_residual_map
417-
_abs_max = _symmetric_vmax(norm_resid)
418-
plot_array(
419-
array=norm_resid, ax=axes_flat[2], title=_pf("Normalized Residual Map"),
420-
colormap=colormap, vmin=-_abs_max, vmax=_abs_max,
421-
)
476+
source_model = _fill(source_model_slim) if source_model_slim is not None else None
477+
source_vmax = float(np.max(source_model)) if source_model is not None else None
422478

423-
# Bottom row: Lens Light Subtracted, Source Model Image, Source Plane (Mid Zoom)
424-
try:
425-
subtracted_img = fit.subtracted_images_of_planes_list[final_plane_index]
426-
except (IndexError, AttributeError):
427-
subtracted_img = None
428-
if subtracted_img is not None:
429-
plot_array(
430-
array=subtracted_img, ax=axes_flat[3],
431-
title=_pf("Lens Light Subtracted"), colormap=colormap,
432-
vmin=0.0 if source_vmax is not None else None, vmax=source_vmax,
433-
)
434-
else:
435-
axes_flat[3].axis("off")
479+
# Normalized residual from slim arrays (no .native conversion)
480+
resid_slim = data_slim - model_slim
481+
with np.errstate(divide="ignore", invalid="ignore"):
482+
norm_resid_slim = resid_slim / noise_slim
483+
norm_resid_slim = np.where(np.isfinite(norm_resid_slim), norm_resid_slim, 0.0)
484+
norm_resid = _fill(norm_resid_slim)
485+
extent = fit.mask.geometry.extent
436486

437-
try:
438-
source_model_img = fit.model_images_of_planes_list[final_plane_index]
439-
except (IndexError, AttributeError):
440-
source_model_img = None
441-
if source_model_img is not None:
442-
plot_array(
443-
array=source_model_img, ax=axes_flat[4],
444-
title=_pf("Source Model Image"), colormap=colormap,
445-
vmax=source_vmax, lines=image_plane_lines,
446-
line_colors=image_plane_line_colors,
447-
)
448-
else:
449-
axes_flat[4].axis("off")
487+
if colormap is None:
488+
try:
489+
from autoarray.plot.utils import _default_colormap
490+
colormap = _default_colormap()
491+
except Exception:
492+
colormap = "default"
450493

451-
_plot_source_plane(
452-
fit, axes_flat[5], final_plane_index, zoom_to_brightest=True,
453-
colormap=colormap, title=_pf("Source Plane (Mid Zoom)"),
454-
lines=source_plane_lines, line_colors=source_plane_line_colors,
455-
vmax=source_vmax, zoom_extent_scale=2.0,
494+
_pf = (lambda t: f"{title_prefix.rstrip()} {t}") if title_prefix else (lambda t: t)
495+
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
496+
axes_flat = list(axes.flatten())
497+
498+
# Top row: Data, Model Image, Normalized Residual Map
499+
_quick_imshow(axes_flat[0], data, _pf("Data"), extent, colormap)
500+
_quick_imshow(axes_flat[1], model, _pf("Model Image"), extent, colormap)
501+
502+
finite = norm_resid[np.isfinite(norm_resid)]
503+
abs_max = float(np.max(np.abs(finite))) if len(finite) > 0 else 1.0
504+
_quick_imshow(axes_flat[2], norm_resid, _pf("Normalized Residual"),
505+
extent, colormap, vmin=-abs_max, vmax=abs_max)
506+
507+
# Bottom row: Lens Light Subtracted, Source Model Image, Source Plane
508+
_quick_imshow(axes_flat[3], subtracted, _pf("Lens Light Subtracted"),
509+
extent, colormap,
510+
vmin=0.0 if source_vmax else None, vmax=source_vmax)
511+
512+
_quick_imshow(axes_flat[4], source_model, _pf("Source Model Image"),
513+
extent, colormap, vmax=source_vmax)
514+
515+
# Source plane: fast path for parametric, inversion fallback for pixelized
516+
tracer_viz = fit.tracer_linear_light_profiles_to_light_profiles
517+
source_galaxies = tracer_viz.planes[final_plane_index]
518+
has_pixelization = any(
519+
hasattr(g, "pixelization") and g.pixelization is not None
520+
for g in source_galaxies
456521
)
457522

458-
hide_unused_axes(axes_flat)
459-
tight_layout()
460-
save_figure(fig, path=output_path, filename="fit_quick", format=output_format, dpi=200)
523+
if not has_pixelization:
524+
try:
525+
quick_grid = aa.Grid2D.uniform(
526+
shape_native=(50, 50),
527+
pixel_scales=fit.mask.pixel_scales,
528+
origin=fit.mask.origin,
529+
)
530+
source_img = plane_image_from(
531+
galaxies=source_galaxies,
532+
grid=quick_grid,
533+
zoom_to_brightest=False,
534+
)
535+
src_np = _to_native_np(source_img)
536+
_quick_imshow(axes_flat[5], src_np, _pf("Source Plane"),
537+
quick_grid.geometry.extent, colormap, vmax=source_vmax)
538+
except Exception:
539+
axes_flat[5].axis("off")
540+
else:
541+
try:
542+
inversion = fit.inversion
543+
mapper_list = inversion.cls_list_from(cls=Mapper)
544+
mapper = mapper_list[final_plane_index - 1] if final_plane_index > 0 else mapper_list[0]
545+
pixel_values = inversion.reconstruction_dict[mapper]
546+
plot_mapper(
547+
mapper, solution_vector=pixel_values, ax=axes_flat[5],
548+
title=_pf("Source Reconstruction"), colormap=colormap,
549+
vmax=source_vmax, zoom_to_brightest=False,
550+
)
551+
except Exception:
552+
axes_flat[5].axis("off")
553+
554+
fig.tight_layout(pad=0.5)
555+
save_figure(fig, path=output_path, filename="fit_quick", format=output_format, dpi=100)
461556

462557

463558
def subplot_fit_x1_plane(

0 commit comments

Comments
 (0)