Skip to content

Commit c525364

Browse files
authored
Merge pull request #364 from Jammy2211/feature/jax_simplify_visualization
feature/jax_simplify_visualization
2 parents 3084bfd + a147da3 commit c525364

36 files changed

Lines changed: 583 additions & 978 deletions

README.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ PyAutoLens: Open-Source Strong Lensing
2626
.. |arXiv| image:: https://img.shields.io/badge/arXiv-1708.07377-blue
2727
:target: https://arxiv.org/abs/1708.07377
2828

29+
.. image:: https://www.repostatus.org/badges/latest/active.svg
30+
:target: https://www.repostatus.org/#active
31+
:alt: Project Status: Active
32+
33+
.. image:: https://img.shields.io/pypi/pyversions/autolens
34+
:target: https://pypi.org/project/autolens/
35+
:alt: Python Versions
36+
37+
.. image:: https://img.shields.io/pypi/v/autolens.svg
38+
:target: https://pypi.org/project/autolens/
39+
:alt: PyPI Version
40+
2941
|binder| |RTD| |Tests| |Build| |code-style| |JOSS| |arXiv|
3042

3143
`Installation Guide <https://pyautolens.readthedocs.io/en/latest/installation/overview.html>`_ |

autolens/analysis/plotter_interface.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ast
22
import numpy as np
3+
from typing import Optional
34

45
from autoconf import conf
56
from autoconf.fitsable import hdu_list_for_output_from
@@ -31,7 +32,12 @@ class PlotterInterface(AgPlotterInterface):
3132
The path on the hard-disk to the `image` folder of the non-linear searches results.
3233
"""
3334

34-
def tracer(self, tracer: Tracer, grid: aa.type.Grid2DLike):
35+
def tracer(
36+
self,
37+
tracer: Tracer,
38+
grid: aa.type.Grid2DLike,
39+
visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None,
40+
):
3541
"""
3642
Visualizes a `Tracer` object.
3743
@@ -63,7 +69,7 @@ def should_plot(name):
6369
tracer=tracer,
6470
grid=grid,
6571
mat_plot_2d=mat_plot_2d,
66-
include_2d=self.include_2d,
72+
visuals_2d_of_planes_list=visuals_2d_of_planes_list,
6773
)
6874

6975
if should_plot("subplot_galaxies_images"):
@@ -169,7 +175,6 @@ def should_plot(name):
169175
image_plotter = aplt.Array2DPlotter(
170176
array=image,
171177
mat_plot_2d=mat_plot_2d,
172-
include_2d=self.include_2d,
173178
visuals_2d=visuals_2d,
174179
)
175180

autolens/imaging/model/plotter_interface.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
import autoarray.plot as aplt
44

@@ -19,7 +19,7 @@ class PlotterInterfaceImaging(PlotterInterface):
1919
imaging_combined = AgPlotterInterfaceImaging.imaging_combined
2020

2121
def fit_imaging(
22-
self, fit: FitImaging,
22+
self, fit: FitImaging, visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None
2323
):
2424
"""
2525
Visualizes a `FitImaging` object, which fits an imaging dataset.
@@ -45,7 +45,7 @@ def fit_imaging(
4545
mat_plot_2d = self.mat_plot_2d_from()
4646

4747
fit_plotter = FitImagingPlotter(
48-
fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d
48+
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
4949
)
5050

5151
fit_plotter.subplot_tracer()
@@ -56,7 +56,7 @@ def should_plot(name):
5656
mat_plot_2d = self.mat_plot_2d_from()
5757

5858
fit_plotter = FitImagingPlotter(
59-
fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d
59+
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
6060
)
6161

6262
plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0]
@@ -72,6 +72,7 @@ def should_plot(name):
7272
fit_plotter.subplot_fit()
7373

7474
if should_plot("subplot_fit_log10"):
75+
7576
try:
7677
if len(fit.tracer.planes) > 2:
7778
for plane_index in plane_indexes_to_plot:
@@ -81,6 +82,7 @@ def should_plot(name):
8182
except ValueError:
8283
pass
8384

85+
8486
if should_plot("subplot_of_planes"):
8587
fit_plotter.subplot_of_planes()
8688

@@ -92,7 +94,7 @@ def should_plot(name):
9294

9395
fits_to_fits(should_plot=should_plot, image_path=self.image_path, fit=fit)
9496

95-
def fit_imaging_combined(self, fit_list: List[FitImaging]):
97+
def fit_imaging_combined(self, fit_list: List[FitImaging], visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None):
9698
"""
9799
Output visualization of all `FitImaging` objects in a summed combined analysis, typically during or after a
98100
model-fit is performed.
@@ -119,7 +121,7 @@ def should_plot(name):
119121

120122
fit_plotter_list = [
121123
FitImagingPlotter(
122-
fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d
124+
fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list,
123125
)
124126
for fit in fit_list
125127
]

autolens/imaging/model/visualizer.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import autogalaxy as ag
55

66
from autolens.imaging.model.plotter_interface import PlotterInterfaceImaging
7+
8+
from autolens.lens import tracer_util
79
from autolens import exc
810

911

@@ -109,15 +111,6 @@ def visualize(
109111
except exc.InversionException:
110112
return
111113

112-
plotter_interface = PlotterInterfaceImaging(
113-
image_path=paths.image_path, title_prefix=analysis.title_prefix
114-
)
115-
116-
try:
117-
plotter_interface.fit_imaging(fit=fit)
118-
except exc.InversionException:
119-
pass
120-
121114
tracer = fit.tracer_linear_light_profiles_to_light_profiles
122115

123116
zoom = ag.Zoom2D(mask=fit.mask)
@@ -127,8 +120,28 @@ def visualize(
127120

128121
grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native)
129122

123+
visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from(
124+
tracer=fit.tracer,
125+
grid=fit.grids.lp.mask.derive_grid.all_false
126+
)
127+
128+
plotter_interface = PlotterInterfaceImaging(
129+
image_path=paths.image_path,
130+
title_prefix=analysis.title_prefix,
131+
)
132+
133+
try:
134+
plotter_interface.fit_imaging(
135+
fit=fit,
136+
visuals_2d_of_planes_list=visuals_2d_of_planes_list
137+
)
138+
except exc.InversionException:
139+
pass
140+
130141
plotter_interface.tracer(
131-
tracer=tracer, grid=grid,
142+
tracer=tracer,
143+
grid=grid,
144+
visuals_2d_of_planes_list=visuals_2d_of_planes_list
132145
)
133146
plotter_interface.galaxies(
134147
galaxies=tracer.galaxies,

0 commit comments

Comments
 (0)