Skip to content

Commit d5efc44

Browse files
Jammy2211claude
authored andcommitted
Plot improvements: cb_unit for residuals, subplot_tracer in plotter, strip subplot_ from output filenames, image_with_positions title
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
1 parent e1d4678 commit d5efc44

8 files changed

Lines changed: 63 additions & 52 deletions

File tree

autolens/analysis/plotter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from autolens.lens.tracer import Tracer
1212
from autolens.lens.plot.tracer_plots import (
13+
subplot_tracer,
1314
subplot_galaxies_images,
1415
save_tracer_fits,
1516
save_source_plane_images_fits,
@@ -54,6 +55,14 @@ def should_plot(name):
5455
output_path = str(self.image_path)
5556
fmt = self.fmt
5657

58+
if should_plot("subplot_tracer"):
59+
subplot_tracer(
60+
tracer=tracer,
61+
grid=grid,
62+
output_path=output_path,
63+
output_format=fmt,
64+
)
65+
5766
if should_plot("subplot_galaxies_images"):
5867
subplot_galaxies_images(
5968
tracer=tracer,
@@ -95,6 +104,7 @@ def should_plot(name):
95104
plot_array(
96105
array=image,
97106
positions=[pos_arr],
107+
title="Image With Positions",
98108
output_path=str(self.image_path),
99109
output_filename="image_with_positions",
100110
output_format=fmt,

autolens/imaging/plot/fit_imaging_plots.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import autogalaxy as ag
77

88
from autoarray.plot.array import plot_array, _zoom_array_2d
9-
from autoarray.plot.utils import save_figure, hide_unused_axes
9+
from autoarray.plot.utils import save_figure, hide_unused_axes, conf_subplot_figsize
1010
from autoarray.plot.utils import numpy_lines as _to_lines
1111
from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from
1212

@@ -144,7 +144,7 @@ def subplot_fit(
144144

145145
source_vmax = _get_source_vmax(fit)
146146

147-
fig, axes = plt.subplots(3, 4, figsize=(28, 21))
147+
fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
148148
axes_flat = list(axes.flatten())
149149

150150
plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap)
@@ -200,23 +200,23 @@ def subplot_fit(
200200
norm_resid = fit.normalized_residual_map
201201
_abs_max = _symmetric_vmax(norm_resid)
202202
plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map",
203-
colormap=colormap, vmin=-_abs_max, vmax=_abs_max)
203+
colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$")
204204

205205
# Normalized residual map clipped to [-1, 1]
206206
plot_array(array=norm_resid, ax=axes_flat[9],
207207
title=r"Normalized Residual Map $1\sigma$",
208-
colormap=colormap, vmin=-1.0, vmax=1.0)
208+
colormap=colormap, vmin=-1.0, vmax=1.0, cb_unit=r"$\sigma$")
209209

210210
plot_array(array=fit.chi_squared_map, ax=axes_flat[10],
211-
title="Chi-Squared Map", colormap=colormap)
211+
title="Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$")
212212

213213
# Source plane not zoomed
214214
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
215215
colormap=colormap)
216216

217217
hide_unused_axes(axes_flat)
218218
plt.tight_layout()
219-
save_figure(fig, path=output_path, filename=f"subplot_fit{plane_index_tag}", format=output_format)
219+
save_figure(fig, path=output_path, filename=f"fit{plane_index_tag}", format=output_format)
220220

221221

222222
def subplot_fit_x1_plane(
@@ -252,7 +252,7 @@ def subplot_fit_x1_plane(
252252
colormap : str, optional
253253
Matplotlib colormap name applied to all image panels.
254254
"""
255-
fig, axes = plt.subplots(2, 3, figsize=(21, 14))
255+
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
256256
axes_flat = list(axes.flatten())
257257

258258
try:
@@ -270,17 +270,17 @@ def subplot_fit_x1_plane(
270270

271271
norm_resid = fit.normalized_residual_map
272272
plot_array(array=norm_resid, ax=axes_flat[3], title="Lens Light Subtracted",
273-
colormap=colormap)
273+
colormap=colormap, cb_unit=r"$\sigma$")
274274

275275
plot_array(array=norm_resid, ax=axes_flat[4], title="Subtracted Image Zero Minimum",
276-
colormap=colormap, vmin=0.0)
276+
colormap=colormap, vmin=0.0, cb_unit=r"$\sigma$")
277277

278278
_abs_max = _symmetric_vmax(norm_resid)
279279
plot_array(array=norm_resid, ax=axes_flat[5], title="Normalized Residual Map",
280-
colormap=colormap, vmin=-_abs_max, vmax=_abs_max)
280+
colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$")
281281

282282
plt.tight_layout()
283-
save_figure(fig, path=output_path, filename="subplot_fit_x1_plane", format=output_format)
283+
save_figure(fig, path=output_path, filename="fit_x1_plane", format=output_format)
284284

285285

286286
def subplot_fit_log10(
@@ -328,7 +328,7 @@ def subplot_fit_log10(
328328

329329
source_vmax = _get_source_vmax(fit)
330330

331-
fig, axes = plt.subplots(3, 4, figsize=(28, 21))
331+
fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
332332
axes_flat = list(axes.flatten())
333333

334334
plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap,
@@ -376,20 +376,20 @@ def subplot_fit_log10(
376376
norm_resid = fit.normalized_residual_map
377377
_abs_max = _symmetric_vmax(norm_resid)
378378
plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map",
379-
colormap=colormap, vmin=-_abs_max, vmax=_abs_max)
379+
colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$")
380380

381381
plot_array(array=norm_resid, ax=axes_flat[9],
382382
title=r"Normalized Residual Map $1\sigma$",
383-
colormap=colormap, vmin=-1.0, vmax=1.0)
383+
colormap=colormap, vmin=-1.0, vmax=1.0, cb_unit=r"$\sigma$")
384384

385385
plot_array(array=fit.chi_squared_map, ax=axes_flat[10], title="Chi-Squared Map",
386-
colormap=colormap, use_log10=True)
386+
colormap=colormap, use_log10=True, cb_unit=r"$\chi^2$")
387387

388388
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
389389
colormap=colormap, use_log10=True)
390390

391391
plt.tight_layout()
392-
save_figure(fig, path=output_path, filename=f"subplot_fit_log10{plane_index_tag}", format=output_format)
392+
save_figure(fig, path=output_path, filename=f"fit_log10{plane_index_tag}", format=output_format)
393393

394394

395395
def subplot_fit_log10_x1_plane(
@@ -420,7 +420,7 @@ def subplot_fit_log10_x1_plane(
420420
colormap : str, optional
421421
Matplotlib colormap name applied to all image panels.
422422
"""
423-
fig, axes = plt.subplots(2, 3, figsize=(21, 14))
423+
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
424424
axes_flat = list(axes.flatten())
425425

426426
try:
@@ -442,15 +442,15 @@ def subplot_fit_log10_x1_plane(
442442

443443
norm_resid = fit.normalized_residual_map
444444
plot_array(array=norm_resid, ax=axes_flat[3], title="Lens Light Subtracted",
445-
colormap=colormap)
445+
colormap=colormap, cb_unit=r"$\sigma$")
446446
_abs_max = _symmetric_vmax(norm_resid)
447447
plot_array(array=norm_resid, ax=axes_flat[4], title="Normalized Residual Map",
448-
colormap=colormap, vmin=-_abs_max, vmax=_abs_max)
448+
colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$")
449449
plot_array(array=fit.chi_squared_map, ax=axes_flat[5], title="Chi-Squared Map",
450-
colormap=colormap, use_log10=True)
450+
colormap=colormap, use_log10=True, cb_unit=r"$\chi^2$")
451451

452452
plt.tight_layout()
453-
save_figure(fig, path=output_path, filename="subplot_fit_log10", format=output_format)
453+
save_figure(fig, path=output_path, filename="fit_log10", format=output_format)
454454

455455

456456
def subplot_of_planes(
@@ -496,7 +496,7 @@ def subplot_of_planes(
496496
plane_indexes = [plane_index]
497497

498498
for pidx in plane_indexes:
499-
fig, axes = plt.subplots(1, 4, figsize=(28, 7))
499+
fig, axes = plt.subplots(1, 4, figsize=conf_subplot_figsize(1, 4))
500500
axes_flat = list(axes.flatten())
501501

502502
plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap)
@@ -518,7 +518,7 @@ def subplot_of_planes(
518518
_plot_source_plane(fit, axes_flat[3], pidx, colormap=colormap)
519519

520520
plt.tight_layout()
521-
save_figure(fig, path=output_path, filename=f"subplot_of_plane_{pidx}", format=output_format)
521+
save_figure(fig, path=output_path, filename=f"of_plane_{pidx}", format=output_format)
522522

523523

524524
def subplot_tracer_from_fit(
@@ -556,7 +556,7 @@ def subplot_tracer_from_fit(
556556
"""
557557
final_plane_index = len(fit.tracer.planes) - 1
558558

559-
fig, axes = plt.subplots(3, 3, figsize=(21, 21))
559+
fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3))
560560
axes_flat = list(axes.flatten())
561561

562562
tracer = fit.tracer_linear_light_profiles_to_light_profiles
@@ -594,7 +594,7 @@ def subplot_tracer_from_fit(
594594
axes_flat[i].axis("off")
595595

596596
plt.tight_layout()
597-
save_figure(fig, path=output_path, filename="subplot_tracer", format=output_format)
597+
save_figure(fig, path=output_path, filename="tracer", format=output_format)
598598

599599

600600
def subplot_fit_combined(
@@ -633,7 +633,7 @@ def subplot_fit_combined(
633633
"""
634634
n_fits = len(fit_list)
635635
n_cols = 6
636-
fig, axes = plt.subplots(n_fits, n_cols, figsize=(7 * n_cols, 7 * n_fits))
636+
fig, axes = plt.subplots(n_fits, n_cols, figsize=conf_subplot_figsize(n_fits, n_cols))
637637
if n_fits == 1:
638638
all_axes = [list(axes)]
639639
else:
@@ -673,10 +673,10 @@ def subplot_fit_combined(
673673
row_axes[4].axis("off")
674674

675675
plot_array(array=fit.normalized_residual_map, ax=row_axes[5],
676-
title="Normalized Residual Map", colormap=colormap)
676+
title="Normalized Residual Map", colormap=colormap, cb_unit=r"$\sigma$")
677677

678678
plt.tight_layout()
679-
save_figure(fig, path=output_path, filename="subplot_fit_combined", format=output_format)
679+
save_figure(fig, path=output_path, filename="fit_combined", format=output_format)
680680

681681

682682
def subplot_fit_combined_log10(
@@ -707,7 +707,7 @@ def subplot_fit_combined_log10(
707707
"""
708708
n_fits = len(fit_list)
709709
n_cols = 6
710-
fig, axes = plt.subplots(n_fits, n_cols, figsize=(7 * n_cols, 7 * n_fits))
710+
fig, axes = plt.subplots(n_fits, n_cols, figsize=conf_subplot_figsize(n_fits, n_cols))
711711
if n_fits == 1:
712712
all_axes = [list(axes)]
713713
else:
@@ -749,7 +749,7 @@ def subplot_fit_combined_log10(
749749
row_axes[4].axis("off")
750750

751751
plot_array(array=fit.normalized_residual_map, ax=row_axes[5],
752-
title="Normalized Residual Map", colormap=colormap)
752+
title="Normalized Residual Map", colormap=colormap, cb_unit=r"$\sigma$")
753753

754754
plt.tight_layout()
755755
save_figure(fig, path=output_path, filename="fit_combined_log10", format=output_format)

autolens/interferometer/plot/fit_interferometer_plots.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def subplot_fit(
170170
_plot_source_plane(fit, axes_flat[7], final_plane_index, colormap=colormap)
171171

172172
plot_array(array=fit.dirty_normalized_residual_map, ax=axes_flat[8],
173-
title="Dirty Normalized Residual Map", colormap=colormap)
173+
title="Dirty Normalized Residual Map", colormap=colormap, cb_unit=r"$\sigma$")
174174

175175
# Panel 9: clipped to [-1, 1]
176176
plot_array(
@@ -179,18 +179,19 @@ def subplot_fit(
179179
title=r"Normalized Residual Map $1\sigma$",
180180
colormap=colormap,
181181
use_log10=False,
182-
vmin=-1.0, vmax=1.0
182+
vmin=-1.0, vmax=1.0,
183+
cb_unit=r"$\sigma$",
183184
)
184185

185186
plot_array(array=fit.dirty_chi_squared_map, ax=axes_flat[10],
186-
title="Dirty Chi-Squared Map", colormap=colormap)
187+
title="Dirty Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$")
187188

188189
# Panel 11: source plane not zoomed
189190
_plot_source_plane(fit, axes_flat[11], final_plane_index,
190191
zoom_to_brightest=False, colormap=colormap)
191192

192193
plt.tight_layout()
193-
save_figure(fig, path=output_path, filename="subplot_fit", format=output_format)
194+
save_figure(fig, path=output_path, filename="fit", format=output_format)
194195

195196

196197
def subplot_fit_real_space(
@@ -257,4 +258,4 @@ def subplot_fit_real_space(
257258
axes_flat[1].set_title("Source Reconstruction")
258259

259260
plt.tight_layout()
260-
save_figure(fig, path=output_path, filename="subplot_fit_real_space", format=output_format)
261+
save_figure(fig, path=output_path, filename="fit_real_space", format=output_format)

autolens/lens/plot/sensitivity_plots.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def subplot_tracer_images(
112112
colormap=colormap, use_log10=use_log10, lines=no_perturb_cc_lines)
113113

114114
plt.tight_layout()
115-
save_figure(fig, path=output_path, filename="subplot_lensed_images", format=output_format)
115+
save_figure(fig, path=output_path, filename="lensed_images", format=output_format)
116116

117117

118118
def subplot_sensitivity(
@@ -241,7 +241,7 @@ def subplot_sensitivity(
241241
pass
242242

243243
plt.tight_layout()
244-
save_figure(fig, path=output_path, filename="subplot_sensitivity", format=output_format)
244+
save_figure(fig, path=output_path, filename="sensitivity", format=output_format)
245245

246246

247247
def subplot_figures_of_merit_grid(

autolens/lens/plot/subhalo_plots.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def subplot_detection_imaging(
9595
)
9696

9797
plt.tight_layout()
98-
save_figure(fig, path=output_path, filename="subplot_detection_imaging", format=output_format)
98+
save_figure(fig, path=output_path, filename="detection_imaging", format=output_format)
9999

100100

101101
def subplot_detection_fits(
@@ -174,4 +174,4 @@ def subplot_detection_fits(
174174
colormap=colormap)
175175

176176
plt.tight_layout()
177-
save_figure(fig, path=output_path, filename="subplot_detection_fits", format=output_format)
177+
save_figure(fig, path=output_path, filename="detection_fits", format=output_format)

autolens/lens/plot/tracer_plots.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import autogalaxy as ag
77

88
from autoarray.plot.array import plot_array
9-
from autoarray.plot.utils import save_figure, hide_unused_axes
9+
from autoarray.plot.utils import save_figure, hide_unused_axes, conf_subplot_figsize
1010
from autoarray.plot.utils import numpy_lines as _to_lines, numpy_positions as _to_positions
1111
from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from
1212

@@ -62,7 +62,7 @@ def subplot_tracer(
6262

6363
magnification = LensCalc.from_mass_obj(tracer).magnification_2d_from(grid=grid)
6464

65-
fig, axes = plt.subplots(3, 3, figsize=(21, 21))
65+
fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3))
6666
axes_flat = list(axes.flatten())
6767

6868
plot_array(array=image, ax=axes_flat[0], title="Image",
@@ -87,7 +87,7 @@ def subplot_tracer(
8787

8888
hide_unused_axes(axes_flat)
8989
plt.tight_layout()
90-
save_figure(fig, path=output_path, filename="subplot_tracer", format=output_format)
90+
save_figure(fig, path=output_path, filename="tracer", format=output_format)
9191

9292

9393
def subplot_lensed_images(
@@ -126,7 +126,7 @@ def subplot_lensed_images(
126126
traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
127127
n = tracer.total_planes
128128

129-
fig, axes = plt.subplots(1, n, figsize=(7 * n, 7))
129+
fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n))
130130
axes_flat = [axes] if n == 1 else list(np.array(axes).flatten())
131131

132132
for plane_index in range(n):
@@ -141,7 +141,7 @@ def subplot_lensed_images(
141141
)
142142

143143
plt.tight_layout()
144-
save_figure(fig, path=output_path, filename="subplot_lensed_images", format=output_format)
144+
save_figure(fig, path=output_path, filename="lensed_images", format=output_format)
145145

146146

147147
def subplot_galaxies_images(
@@ -186,7 +186,7 @@ def subplot_galaxies_images(
186186
traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
187187
n = 2 * tracer.total_planes - 1
188188

189-
fig, axes = plt.subplots(1, n, figsize=(7 * n, 7))
189+
fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n))
190190
axes_flat = [axes] if n == 1 else list(np.array(axes).flatten())
191191

192192
idx = 0
@@ -228,7 +228,7 @@ def subplot_galaxies_images(
228228
idx += 1
229229

230230
plt.tight_layout()
231-
save_figure(fig, path=output_path, filename="subplot_galaxies_images", format=output_format)
231+
save_figure(fig, path=output_path, filename="galaxies_images", format=output_format)
232232

233233

234234
def save_tracer_fits(

0 commit comments

Comments
 (0)