Skip to content

Commit 1e99735

Browse files
Jammy2211claude
authored andcommitted
Overhaul plot styling and extract fits_* output functions
- All colormap defaults changed from "jet" to None (resolves to the configured default colormap via autoarray) - save_tracer_fits() and save_source_plane_images_fits() extracted to tracer_plots.py; plotter.py reduced to one-liner delegate calls - Unused imports (ast, numpy, conf, hdu_list_for_output_from) removed from plotter.py Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
1 parent be84d80 commit 1e99735

6 files changed

Lines changed: 135 additions & 91 deletions

File tree

autolens/analysis/plotter.py

Lines changed: 7 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
import ast
21
import numpy as np
32
from typing import Optional
43

5-
from autoconf import conf
6-
from autoconf.fitsable import hdu_list_for_output_from
7-
84
import autoarray as aa
95
import autogalaxy as ag
106

@@ -13,7 +9,11 @@
139
from autogalaxy.analysis.plotter import Plotter as AgPlotter
1410

1511
from autolens.lens.tracer import Tracer
16-
from autolens.lens.plot.tracer_plots import subplot_galaxies_images
12+
from autolens.lens.plot.tracer_plots import (
13+
subplot_galaxies_images,
14+
save_tracer_fits,
15+
save_source_plane_images_fits,
16+
)
1717
from autoarray.plot.array import plot_array
1818

1919

@@ -63,72 +63,10 @@ def should_plot(name):
6363
)
6464

6565
if should_plot("fits_tracer"):
66-
67-
zoom = aa.Zoom2D(mask=grid.mask)
68-
mask = zoom.mask_2d_from(buffer=1)
69-
grid_zoom = aa.Grid2D.from_mask(mask=mask)
70-
71-
image_list = [
72-
tracer.convergence_2d_from(grid=grid_zoom).native,
73-
tracer.potential_2d_from(grid=grid_zoom).native,
74-
tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 0],
75-
tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 1],
76-
]
77-
78-
hdu_list = hdu_list_for_output_from(
79-
values_list=[image_list[0].mask.astype("float")] + image_list,
80-
ext_name_list=[
81-
"mask",
82-
"convergence",
83-
"potential",
84-
"deflections_y",
85-
"deflections_x",
86-
],
87-
header_dict=grid_zoom.mask.header_dict,
88-
)
89-
90-
hdu_list.writeto(self.image_path / "tracer.fits", overwrite=True)
66+
save_tracer_fits(tracer=tracer, grid=grid, output_path=self.image_path)
9167

9268
if should_plot("fits_source_plane_images"):
93-
94-
shape_native = conf.instance["visualize"]["plots"]["tracer"][
95-
"fits_source_plane_shape"
96-
]
97-
shape_native = ast.literal_eval(shape_native)
98-
99-
zoom = aa.Zoom2D(mask=grid.mask)
100-
mask = zoom.mask_2d_from(buffer=1)
101-
grid_source_plane = aa.Grid2D.from_extent(
102-
extent=mask.geometry.extent, shape_native=tuple(shape_native)
103-
)
104-
105-
image_list = [grid_source_plane.mask.astype("float")]
106-
ext_name_list = ["mask"]
107-
108-
for i, plane in enumerate(tracer.planes[1:]):
109-
110-
if plane.has(cls=ag.LightProfile):
111-
112-
image = plane.image_2d_from(
113-
grid=grid_source_plane,
114-
).native
115-
116-
else:
117-
118-
image = np.zeros(grid_source_plane.shape_native)
119-
120-
image_list.append(image)
121-
ext_name_list.append(f"source_plane_image_{i+1}")
122-
123-
hdu_list = hdu_list_for_output_from(
124-
values_list=image_list,
125-
ext_name_list=ext_name_list,
126-
header_dict=grid_source_plane.mask.header_dict,
127-
)
128-
129-
hdu_list.writeto(
130-
self.image_path / "source_plane_images.fits", overwrite=True
131-
)
69+
save_source_plane_images_fits(tracer=tracer, grid=grid, output_path=self.image_path)
13270

13371
def image_with_positions(self, image: aa.Array2D, positions: aa.Grid2DIrregular):
13472
"""

autolens/imaging/plot/fit_imaging_plots.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import autogalaxy as ag
77

88
from autoarray.plot.array import plot_array, _zoom_array_2d
9-
from autoarray.plot.utils import save_figure
9+
from autoarray.plot.utils import save_figure, hide_unused_axes
1010
from autoarray.plot.utils import numpy_lines as _to_lines
1111
from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from
1212

@@ -39,7 +39,7 @@ def _get_source_vmax(fit):
3939

4040

4141
def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
42-
colormap="jet", use_log10=False):
42+
colormap=None, use_log10=False):
4343
"""
4444
Plot the source-plane image (or a blank inversion placeholder) into an axes.
4545
@@ -94,7 +94,7 @@ def subplot_fit(
9494
fit,
9595
output_path: Optional[str] = None,
9696
output_format: str = "png",
97-
colormap: str = "jet",
97+
colormap: Optional[str] = None,
9898
plane_index: Optional[int] = None,
9999
):
100100
"""
@@ -214,6 +214,7 @@ def subplot_fit(
214214
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
215215
colormap=colormap)
216216

217+
hide_unused_axes(axes_flat)
217218
plt.tight_layout()
218219
save_figure(fig, path=output_path, filename=f"subplot_fit{plane_index_tag}", format=output_format)
219220

@@ -222,7 +223,7 @@ def subplot_fit_x1_plane(
222223
fit,
223224
output_path: Optional[str] = None,
224225
output_format: str = "png",
225-
colormap: str = "jet",
226+
colormap: Optional[str] = None,
226227
):
227228
"""
228229
Produce a 6-panel subplot for a single-plane tracer imaging fit.
@@ -286,7 +287,7 @@ def subplot_fit_log10(
286287
fit,
287288
output_path: Optional[str] = None,
288289
output_format: str = "png",
289-
colormap: str = "jet",
290+
colormap: Optional[str] = None,
290291
plane_index: Optional[int] = None,
291292
):
292293
"""
@@ -395,7 +396,7 @@ def subplot_fit_log10_x1_plane(
395396
fit,
396397
output_path: Optional[str] = None,
397398
output_format: str = "png",
398-
colormap: str = "jet",
399+
colormap: Optional[str] = None,
399400
):
400401
"""
401402
Produce a 6-panel log10 subplot for a single-plane tracer imaging fit.
@@ -456,7 +457,7 @@ def subplot_of_planes(
456457
fit,
457458
output_path: Optional[str] = None,
458459
output_format: str = "png",
459-
colormap: str = "jet",
460+
colormap: Optional[str] = None,
460461
plane_index: Optional[int] = None,
461462
):
462463
"""
@@ -524,7 +525,7 @@ def subplot_tracer_from_fit(
524525
fit,
525526
output_path: Optional[str] = None,
526527
output_format: str = "png",
527-
colormap: str = "jet",
528+
colormap: Optional[str] = None,
528529
):
529530
"""
530531
Produce a 9-panel tracer subplot derived from a `FitImaging` object.
@@ -600,7 +601,7 @@ def subplot_fit_combined(
600601
fit_list: List,
601602
output_path: Optional[str] = None,
602603
output_format: str = "png",
603-
colormap: str = "jet",
604+
colormap: Optional[str] = None,
604605
):
605606
"""
606607
Produce a combined multi-row subplot for a list of `FitImaging` objects.
@@ -682,7 +683,7 @@ def subplot_fit_combined_log10(
682683
fit_list: List,
683684
output_path: Optional[str] = None,
684685
output_format: str = "png",
685-
colormap: str = "jet",
686+
colormap: Optional[str] = None,
686687
):
687688
"""
688689
Produce a combined log10 multi-row subplot for a list of `FitImaging` objects.

autolens/interferometer/plot/fit_interferometer_plots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _plot_yx(y, x, ax, title, xlabel="", ylabel=""):
3737

3838

3939
def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
40-
colormap="jet", use_log10=False):
40+
colormap=None, use_log10=False):
4141
"""
4242
Plot the source-plane image (or a blank inversion placeholder) into an axes.
4343
@@ -88,7 +88,7 @@ def subplot_fit(
8888
fit,
8989
output_path: Optional[str] = None,
9090
output_format: str = "png",
91-
colormap: str = "jet",
91+
colormap: Optional[str] = None,
9292
):
9393
"""
9494
Produce a 12-panel subplot summarising an interferometer fit.
@@ -197,7 +197,7 @@ def subplot_fit_real_space(
197197
fit,
198198
output_path: Optional[str] = None,
199199
output_format: str = "png",
200-
colormap: str = "jet",
200+
colormap: Optional[str] = None,
201201
):
202202
"""
203203
Produce a real-space subplot for an interferometer fit.

autolens/lens/plot/sensitivity_plots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def subplot_tracer_images(
1616
source_image,
1717
output_path: Optional[str] = None,
1818
output_format: str = "png",
19-
colormap: str = "jet",
19+
colormap: Optional[str] = None,
2020
use_log10: bool = False,
2121
):
2222
"""
@@ -120,7 +120,7 @@ def subplot_sensitivity(
120120
data_subtracted,
121121
output_path: Optional[str] = None,
122122
output_format: str = "png",
123-
colormap: str = "jet",
123+
colormap: Optional[str] = None,
124124
use_log10: bool = False,
125125
):
126126
"""
@@ -248,7 +248,7 @@ def subplot_figures_of_merit_grid(
248248
result,
249249
output_path: Optional[str] = None,
250250
output_format: str = "png",
251-
colormap: str = "jet",
251+
colormap: Optional[str] = None,
252252
use_log_evidences: bool = True,
253253
remove_zeros: bool = True,
254254
):

autolens/lens/plot/subhalo_plots.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def subplot_detection_imaging(
1212
fit_imaging_with_subhalo,
1313
output_path: Optional[str] = None,
1414
output_format: str = "png",
15-
colormap: str = "jet",
15+
colormap: Optional[str] = None,
1616
use_log10: bool = False,
1717
use_log_evidences: bool = True,
1818
relative_to_value: float = 0.0,
@@ -103,7 +103,7 @@ def subplot_detection_fits(
103103
fit_imaging_with_subhalo,
104104
output_path: Optional[str] = None,
105105
output_format: str = "png",
106-
colormap: str = "jet",
106+
colormap: Optional[str] = None,
107107
):
108108
"""
109109
Produce a 6-panel subplot comparing imaging fits with and without a subhalo.

0 commit comments

Comments
 (0)