2222
2323from typing import Mapping
2424
25- import matplotlib .pyplot as plt
2625import numpy as np
2726from lsst .pex .config import ChoiceField , Field
2827from lsst .pex .config .configurableActions import ConfigurableActionField
29- from lsst .utils .plotting import set_rubin_plotstyle
28+ from lsst .utils .plotting import make_figure , set_rubin_plotstyle
3029from matplotlib .figure import Figure
3130
3231from ...actions .keyedData import CalcCompletenessHistogramAction
3938class CompletenessHist (PlotAction ):
4039 """Makes plots of completeness and purity."""
4140
41+ label_shift = Field [float ](
42+ doc = "Fraction of plot width to shift completeness/purity labels by."
43+ "Ignored if percentiles_style is not 'below_line'" ,
44+ default = - 0.1 ,
45+ )
4246 action = ConfigurableActionField [CalcCompletenessHistogramAction ](
4347 doc = "Action to compute completeness/purity" ,
4448 )
45- mag_ref_label = Field [str ](doc = "Label for the completeness x axis." , default = "Reference magnitude" )
46- mag_target_label = Field [str ](doc = "Label for the purity x axis." , default = "Measured magnitude" )
49+ color_counts = Field [str ](doc = "Color for the line showing object counts" , default = "#029E73" )
50+ color_right = Field [str ](
51+ doc = "Color for the line showing the correctly classified fraction" , default = "#949494"
52+ )
53+ color_wrong = Field [str ](
54+ doc = "Color for the line showing the wrongly classified fraction" , default = "#DE8F05"
55+ )
56+ legendLocation = Field [str ](doc = "Legend position within main plot" , default = "lower left" )
57+ mag_ref_label = Field [str ](doc = "Label for the completeness x axis." , default = "Reference Magnitude" )
58+ mag_target_label = Field [str ](doc = "Label for the purity x axis." , default = "Measured Magnitude" )
4759 percentiles_style = ChoiceField [str ](
4860 doc = "Style and locations for completeness threshold percentile labels" ,
4961 allowed = {
@@ -52,6 +64,8 @@ class CompletenessHist(PlotAction):
5264 },
5365 default = "below_line" ,
5466 )
67+ publicationStyle = Field [bool ](doc = "Make a publication-style of plot" , default = False )
68+ show_purity = Field [bool ](doc = "Whether to include a purity plot below completness" , default = True )
5569
5670 def getInputSchema (self ) -> KeyedDataSchema :
5771 yield from self .action .getOutputSchema ()
@@ -132,10 +146,12 @@ def makePlot(self, data, plotInfo, **kwargs):
132146
133147 # Make plot showing the fraction recovered in magnitude bins
134148 set_rubin_plotstyle ()
135- fig , axes = plt .subplots (dpi = 300 , nrows = 2 , figsize = (8 , 8 ))
136- color_counts = "purple"
137- color_wrong = "firebrick"
138- color_right = "teal"
149+ n_sub = 1 + self .show_purity
150+ fig = make_figure (dpi = 300 , figsize = (8 , 4 * n_sub ))
151+ if self .show_purity :
152+ axes = (fig .add_subplot (2 , 1 , 1 ), fig .add_subplot (2 , 1 , 2 ))
153+ else :
154+ axes = [fig .add_axes ([0.1 , 0.15 , 0.8 , 0.75 ])]
139155 max_left = 1.05
140156
141157 band = kwargs .get ("band" )
@@ -167,28 +183,37 @@ def makePlot(self, data, plotInfo, **kwargs):
167183
168184 counts_all = data [names ["count" ]]
169185
186+ if self .publicationStyle :
187+ lineTuples = (
188+ (data [names ["completeness" ]], False , "k" , "Completeness" ),
189+ (data [names ["completeness_bad_match" ]], False , self .color_wrong , "Incorrect Class" ),
190+ )
191+ else :
192+ lineTuples = (
193+ (data [names ["completeness" ]], True , "k" , "Completeness" ),
194+ (data [names ["completeness_bad_match" ]], False , self .color_wrong , "Incorrect class" ),
195+ (data [names ["completeness_good_match" ]], False , self .color_right , "Correct Class" ),
196+ )
197+
170198 plots = {
171199 "Completeness" : {
172200 "count_type" : "Reference" ,
173201 "counts" : data [names ["count_ref" ]],
174- "lines" : (
175- (data [names ["completeness" ]], True , "k" , "completeness" ),
176- (data [names ["completeness_bad_match" ]], False , color_wrong , "wrong class" ),
177- (data [names ["completeness_good_match" ]], False , color_right , "right class" ),
178- ),
202+ "lines" : lineTuples ,
179203 "xlabel" : self .mag_ref_label ,
180204 },
181- "Purity" : {
205+ }
206+ if self .show_purity :
207+ plots ["Purity" ] = {
182208 "count_type" : "Object" ,
183209 "counts" : data [names ["count_target" ]],
184210 "lines" : (
185- (data [names ["purity" ]], True , "k" , None ),
186- (data [names ["purity_bad_match" ]], False , color_wrong , "wrong class" ),
187- (data [names ["purity_good_match" ]], False , color_right , "right class" ),
211+ (data [names ["purity" ]], True , "k" , "Purity" ),
212+ (data [names ["purity_bad_match" ]], False , self . color_wrong , "Incorrect class" ),
213+ (data [names ["purity_good_match" ]], False , self . color_right , "Correct class" ),
188214 ),
189215 "xlabel" : self .mag_target_label ,
190- },
191- }
216+ }
192217
193218 # idx == 0 should be completeness; update this if that assumption
194219 # is changed
@@ -203,9 +228,10 @@ def makePlot(self, data, plotInfo, **kwargs):
203228 xticks = np .arange (round (xlim [0 ]), round (xlim [1 ])),
204229 yticks = np .linspace (0 , 1 , 11 ),
205230 )
206- axes_idx .grid (color = "lightgrey" , ls = "-" )
231+ if not self .publicationStyle :
232+ axes_idx .grid (color = "lightgrey" , ls = "-" )
207233 ax_right = axes_idx .twinx ()
208- ax_right .set_ylabel (f"{ plot_data ['count_type' ]} counts/mag " )
234+ ax_right .set_ylabel (f"{ plot_data ['count_type' ]} Counts/Magnitude" , color = "k " )
209235 ax_right .set_yscale ("log" )
210236
211237 for y , do_err , color , label in plot_data ["lines" ]:
@@ -214,26 +240,43 @@ def makePlot(self, data, plotInfo, **kwargs):
214240 y = y ,
215241 xerr = x_err if do_err else None ,
216242 yerr = 1.0 / np .sqrt (counts_all + 1 ) if do_err else None ,
243+ capsize = 0 ,
217244 color = color ,
218245 label = label ,
219246 )
220247 y = plot_data ["counts" ] / interval
221248 # It should be unusual for np.max(y) to be zero; nonetheless...
249+ lines_left , labels_left = axes_idx .get_legend_handles_labels ()
222250 ax_right .step (
223251 [x [0 ] - interval ] + list (x ) + [x [- 1 ] + interval ],
224252 [0 ] + list (y ) + [0 ],
225253 where = "mid" ,
226- color = color_counts ,
227- label = "counts " ,
254+ color = self . color_counts ,
255+ label = "Counts " ,
228256 )
257+
258+ # Force the inputs counts histogram to the back
259+ ax_right .zorder = 1
260+ axes_idx .zorder = 2
261+ axes_idx .patch .set_visible (False )
262+
229263 ax_right .set_ylim (0.999 , 10 ** (max_left * np .log10 (max (np .nanmax (y ), 2 ))))
230- ax_right .tick_params (axis = "y" , labelcolor = color_counts )
231- lines_left , labels_left = axes_idx .get_legend_handles_labels ()
264+ ax_right .tick_params (axis = "y" , labelcolor = self .color_counts )
232265 lines_right , labels_right = ax_right .get_legend_handles_labels ()
233- axes_idx .legend (lines_left + lines_right , labels_left + labels_right , loc = "lower left" , ncol = 2 )
266+
267+ # Using fig for legend
268+ (axes_idx if self .show_purity else fig ).legend (
269+ lines_left + lines_right ,
270+ labels_left + labels_right ,
271+ loc = self .legendLocation ,
272+ ncol = 2 ,
273+ )
234274
235275 if idx == 0 :
236- percentiles = self .action .config_metrics .completeness_percentiles
276+ if not self .publicationStyle :
277+ percentiles = self .action .config_metrics .completeness_percentiles
278+ else :
279+ percentiles = [90.0 , 50.0 ]
237280 if percentiles :
238281 above_plot = self .percentiles_style == "above_plot"
239282 below_line = self .percentiles_style == "below_line"
@@ -242,7 +285,7 @@ def makePlot(self, data, plotInfo, **kwargs):
242285 if above_plot :
243286 texts = []
244287 elif below_line :
245- offset = 0.1 * (xlims [1 ] - xlims [0 ])
288+ offset = self . label_shift * (xlims [1 ] - xlims [0 ])
246289 else :
247290 raise RuntimeError (f"Unimplemented { self .percentiles_style = } " )
248291 for pct in percentiles :
@@ -260,13 +303,22 @@ def makePlot(self, data, plotInfo, **kwargs):
260303 if above_plot :
261304 texts .append (text )
262305 elif below_line :
263- axes_idx .text (mag_completeness - offset , pct , text , ha = "right" , va = "top" )
306+ axes_idx .text (
307+ mag_completeness + offset ,
308+ pct - 0.02 ,
309+ text ,
310+ ha = "right" ,
311+ va = "top" ,
312+ fontsize = 12 ,
313+ )
264314 if above_plot :
265315 texts = f"Thresholds: { '; ' .join (texts )} "
266316 axes_idx .text (xlims [0 ], max_left , texts , ha = "left" , va = "bottom" )
267317
268318 # Add useful information to the plot
269- addPlotInfo (fig , plotInfo )
270- fig .tight_layout ()
271- fig .subplots_adjust (top = 0.90 )
319+ if not self .publicationStyle :
320+ addPlotInfo (fig , plotInfo )
321+ if self .show_purity :
322+ fig .tight_layout ()
323+ fig .subplots_adjust (top = 0.90 )
272324 return fig
0 commit comments