-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathinversion.py
More file actions
250 lines (218 loc) · 8.4 KB
/
inversion.py
File metadata and controls
250 lines (218 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
Standalone functions for plotting inversion / pixelization reconstructions.
Replaces the inversion-specific paths in ``MatPlot2D.plot_mapper``.
"""
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm, Normalize
from autoarray.plot.utils import apply_extent, apply_labels, conf_figsize, save_figure
def plot_inversion_reconstruction(
    pixel_values: np.ndarray,
    mapper,
    ax: Optional[plt.Axes] = None,
    # --- cosmetics --------------------------------------------------------------
    title: str = "Reconstruction",
    xlabel: str = 'x (")',
    ylabel: str = 'y (")',
    colormap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    use_log10: bool = False,
    zoom_to_brightest: bool = True,
    # --- overlays ---------------------------------------------------------------
    lines: Optional[List[np.ndarray]] = None,
    grid: Optional[np.ndarray] = None,
    # --- figure control (used only when ax is None) -----------------------------
    figsize: Optional[Tuple[int, int]] = None,
    output_path: Optional[str] = None,
    output_filename: str = "reconstruction",
    output_format: str = "png",
) -> None:
    """
    Plot an inversion reconstruction using the appropriate mapper type.

    Chooses between rectangular (``imshow``/``pcolormesh``) and Delaunay/KNN
    (``tripcolor``) rendering based on the type of ``mapper.interpolator``.
    If the interpolator type is not one of those, a ``UserWarning`` is issued
    and no reconstruction is drawn (overlays, labels and extent are still
    applied).

    Parameters
    ----------
    pixel_values
        1D array of reconstructed flux values, one per source pixel.
    mapper
        Autoarray mapper object exposing ``interpolator``, ``mesh_geometry``,
        ``source_plane_mesh_grid`` and ``extent_from``.
    ax
        Existing ``Axes``. ``None`` creates a new figure, which is then
        written out with ``save_figure``.
    title, xlabel, ylabel
        Text labels.
    colormap
        Matplotlib colormap name. ``None`` uses the package default.
    vmin, vmax
        Explicit colour scale limits.
    use_log10
        Apply ``LogNorm``. A missing, non-finite or non-positive ``vmin`` is
        replaced by ``1e-4``, because ``LogNorm`` needs a strictly positive
        lower limit. When ``vmax`` is not given, it is taken from the finite
        maximum of ``pixel_values``. If that is unavailable or not above the
        lower limit, ten times the lower limit is used instead.
    zoom_to_brightest
        Pass through to ``mapper.extent_from``.
    lines
        Line overlays (e.g. critical curves), each an ``(N, 2)`` array of
        ``(y, x)`` points.
    grid
        Scatter overlay (e.g. data-plane grid) of ``(y, x)`` points.
    figsize, output_path, output_filename, output_format
        Figure output controls (only used when ``ax`` is ``None``).
    """
    import warnings

    from autoarray.inversion.mesh.interpolator.rectangular import (
        InterpolatorRectangular,
    )
    from autoarray.inversion.mesh.interpolator.rectangular_uniform import (
        InterpolatorRectangularUniform,
    )
    from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
    from autoarray.inversion.mesh.interpolator.knn import InterpolatorKNearestNeighbor

    if colormap is None:
        from autoarray.plot.utils import _default_colormap

        colormap = _default_colormap()

    owns_figure = ax is None
    if owns_figure:
        figsize = figsize or conf_figsize("figures")
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    norm = _colour_norm_from(pixel_values, vmin=vmin, vmax=vmax, use_log10=use_log10)

    extent = mapper.extent_from(
        values=pixel_values, zoom_to_brightest=zoom_to_brightest
    )

    interpolator = mapper.interpolator
    if isinstance(
        interpolator, (InterpolatorRectangular, InterpolatorRectangularUniform)
    ):
        _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent)
    elif isinstance(
        interpolator, (InterpolatorDelaunay, InterpolatorKNearestNeighbor)
    ):
        _plot_delaunay(ax, pixel_values, mapper, norm, colormap)
    else:
        # Previously this fell through silently and produced an empty panel.
        warnings.warn(
            "plot_inversion_reconstruction: unsupported interpolator type "
            f"{type(interpolator).__name__}; the reconstruction was not drawn.",
            stacklevel=2,
        )

    # --- overlays --------------------------------------------------------------
    if lines is not None:
        for line in lines:
            if line is not None and len(line) > 0:
                # Coordinates are stored (y, x); matplotlib wants x first.
                ax.plot(line[:, 1], line[:, 0], linewidth=2)

    if grid is not None:
        ax.scatter(grid[:, 1], grid[:, 0], s=1, c="w", alpha=0.5)

    apply_extent(ax, extent)
    apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel)

    if owns_figure:
        save_figure(
            fig,
            path=output_path or "",
            filename=output_filename,
            format=output_format,
        )


def _colour_norm_from(pixel_values, vmin, vmax, use_log10):
    """Build the colour normalisation for ``plot_inversion_reconstruction``.

    Returns a ``LogNorm`` when ``use_log10`` is set. Otherwise it returns a
    ``Normalize`` when either limit is given, or ``None`` (matplotlib
    autoscaling) when neither is.
    """
    if not use_log10:
        if vmin is None and vmax is None:
            return None
        return Normalize(vmin=vmin, vmax=vmax)

    # LogNorm requires a finite, strictly positive lower limit; 0 or a
    # negative vmin would make matplotlib raise when the artist is drawn.
    if vmin is not None and np.isfinite(vmin) and vmin > 0:
        vmin_log = vmin
    else:
        vmin_log = 1e-4

    if vmax is not None and np.isfinite(vmax):
        return LogNorm(vmin=vmin_log, vmax=vmax)

    vmax_log = np.nan
    if pixel_values is not None:
        values = np.asarray(pixel_values)
        # nanmax raises on empty input and warns on all-NaN input; skip both.
        if values.size > 0 and not np.all(np.isnan(values)):
            with np.errstate(all="ignore"):
                vmax_log = float(np.nanmax(values))

    if not np.isfinite(vmax_log) or vmax_log <= vmin_log:
        vmax_log = vmin_log * 10.0

    return LogNorm(vmin=vmin_log, vmax=vmax_log)
def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
    """Render a rectangular pixelization reconstruction onto *ax*.

    Uses ``imshow`` for uniform rectangular grids
    (``InterpolatorRectangularUniform``) and ``pcolormesh`` for non-uniform
    rectangular grids. A colorbar is added in both cases.

    Parameters
    ----------
    ax
        Matplotlib ``Axes`` to draw onto.
    pixel_values
        1-D array of reconstructed flux values, one per source pixel, or an
        autoarray object exposing a ``.array`` attribute. ``None`` renders a
        zero-filled image.
    mapper
        Mapper object exposing ``interpolator`` and ``mesh_geometry``. For
        uniform grids, ``mesh_geometry`` must provide ``shape``,
        ``pixel_scales`` and ``origin``. Otherwise it must provide ``shape``
        and ``edges_transformed``.
    norm
        ``matplotlib.colors.Normalize`` (or ``LogNorm``) instance, or
        ``None`` for automatic scaling.
    colormap
        Matplotlib colormap name.
    extent
        ``[xmin, xmax, ymin, ymax]`` spatial extent. It is accepted for
        signature compatibility but not used for drawing. The uniform branch
        passes the mesh geometry's own extent to ``imshow``, and the
        non-uniform branch places cells by their edges.
    """
    from autoarray.inversion.mesh.interpolator.rectangular_uniform import (
        InterpolatorRectangularUniform,
    )
    from autoarray.plot.utils import _apply_colorbar

    shape_native = mapper.mesh_geometry.shape

    if pixel_values is None:
        pixel_values = np.zeros(shape_native[0] * shape_native[1])
    elif hasattr(pixel_values, "array"):
        # Unwrap autoarray containers to their raw data, as _plot_delaunay does.
        pixel_values = pixel_values.array

    if isinstance(mapper.interpolator, InterpolatorRectangularUniform):
        from autoarray.structures.arrays.uniform_2d import Array2D
        from autoarray.structures.arrays import array_2d_util

        # Scatter the 1-D solution onto the full (unmasked) 2-D native grid.
        solution_array_2d = array_2d_util.array_2d_native_from(
            array_2d_slim=pixel_values,
            mask_2d=np.full(fill_value=False, shape=shape_native),
        )
        pix_array = Array2D.no_mask(
            values=solution_array_2d,
            pixel_scales=mapper.mesh_geometry.pixel_scales,
            origin=mapper.mesh_geometry.origin,
        )
        mappable = ax.imshow(
            pix_array.native.array,
            cmap=colormap,
            norm=norm,
            extent=pix_array.geometry.extent,
            aspect="auto",
            origin="upper",
        )
    else:
        # Non-uniform cells: draw each cell between its transformed edges.
        y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T
        Y, X = np.meshgrid(y_edges, x_edges, indexing="ij")
        mappable = ax.pcolormesh(
            X,
            Y,
            np.reshape(pixel_values, shape_native),
            shading="flat",
            norm=norm,
            cmap=colormap,
        )

    _apply_colorbar(mappable, ax)
def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
    """Render a Delaunay or KNN pixelization reconstruction onto *ax*.

    Uses ``ax.tripcolor`` with Gouraud shading, so the reconstructed flux is
    interpolated smoothly across a triangulation of the source-plane mesh
    points. A colorbar is attached after rendering.

    Parameters
    ----------
    ax
        Matplotlib ``Axes`` to draw onto.
    pixel_values
        1-D array of reconstructed flux values (one per source-plane pixel),
        or an autoarray object exposing a ``.array`` attribute. ``None``
        renders a zero-filled reconstruction (matching ``_plot_rectangular``)
        instead of failing inside ``tripcolor``.
    mapper
        Mapper object exposing ``source_plane_mesh_grid``, an ``(N, 2)``
        array of ``(y, x)`` mesh-point coordinates.
    norm
        ``matplotlib.colors.Normalize`` (or ``LogNorm``) instance, or
        ``None`` for automatic scaling.
    colormap
        Matplotlib colormap name.
    """
    from autoarray.plot.utils import _apply_colorbar

    mesh_grid = mapper.source_plane_mesh_grid
    # Coordinates are stored (y, x); matplotlib wants x first.
    x = mesh_grid[:, 1]
    y = mesh_grid[:, 0]

    if pixel_values is None:
        vals = np.zeros(len(x))
    elif hasattr(pixel_values, "array"):
        vals = pixel_values.array
    else:
        vals = pixel_values

    tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud")
    _apply_colorbar(tc, ax)