Skip to content

Commit 41263af

Browse files
committed
Add QA plots for source injection to analysis_tools
1 parent e17ae19 commit 41263af

File tree

9 files changed

+589
-1
lines changed

9 files changed

+589
-1
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
description: |
2+
Tier1 plots and metrics to assess injected coadd quality
3+
tasks:
4+
injectedObjectAnalysis:
5+
class: lsst.analysis.tools.tasks.injectedObjectAnalysis.InjectedObjectAnalysisTask
6+
config:
7+
atools.completenessHist: CompletenessPurityTool
8+
atools.astromDiffRAScatterPlot: TargetInjectedCatDeltaRAScatterPlot
9+
atools.astromDiffDecScatterPlot: TargetInjectedCatDeltaDecScatterPlot
10+
atools.astromDiffMetrics: TargetInjectedCatDeltaMetrics
11+
atools.astromDiffMetrics.applyContext: CoaddContext
12+
atools.targetInjectedCatDeltaPsfScatterPlot: TargetInjectedCatDeltaPsfScatterPlot
13+
bands: ["g", "r", "i", "z", "y"]
14+
python: |
15+
from lsst.analysis.tools.atools import *
16+
from lsst.analysis.tools.contexts import *

python/lsst/analysis/tools/actions/plot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .barPlots import *
22
from .calculateRange import *
33
from .colorColorFitPlot import *
4+
from .completenessPlot import *
45
from .diaSkyPlot import *
56
from .focalPlanePlot import *
67
from .gridPlot import *
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# This file is part of analysis_tools.
2+
#
3+
# Developed for the LSST Data Management System.
4+
# This product includes software developed by the LSST Project
5+
# (https://www.lsst.org).
6+
# See the COPYRIGHT file at the top-level directory of this distribution
7+
# for details of code ownership.
8+
#
9+
# This program is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# This program is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
21+
22+
23+
from typing import Mapping
24+
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
from lsst.pex.config import Field
28+
from matplotlib.figure import Figure
29+
30+
from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
31+
from .plotUtils import addPlotInfo
32+
33+
__all__ = ("CompletenessHist",)
34+
35+
36+
class CompletenessHist(PlotAction):
    """Makes a completeness plot for injected sources: the fraction of
    injected sources that were recovered, binned by input magnitude, with
    the number of injected and recovered sources per bin overplotted.

    A source counts as recovered when its match distance is finite.
    """

    magKey = Field[str](doc="Name of the magnitude column.", default="mag")
    matchDistanceKey = Field[str](doc="Name of the match distance column.", default="matchDistance")
    xAxisLabel = Field[str](doc="Label for the x axis.", default="Input Magnitude (mag)")
    inputLabel = Field[str](doc="Label for the input source histogram.", default="Synthetic Inputs")
    outputLabel = Field[str](doc="Label for the recovered source histogram.", default="Synthetic Recovered")
    numBins = Field[int](doc="Number of bins to use for the histograms.", default=100)

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the columns this action reads: the magnitude column and
        the match distance column, both as `Vector`.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = [
            (self.magKey, Vector),
            (self.matchDistanceKey, Vector),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate the input columns and then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every required key is present and is not a Scalar
        where a Vector is required.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a
            column holds a Scalar where another type is required.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(self, data: KeyedData, plotInfo: Mapping, **kwargs) -> Figure:
        """Makes a plot showing the fraction of injected sources recovered by
        input magnitude.

        Parameters
        ----------
        data : `KeyedData`
            All the data
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``camera``
                The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
            ``"cameraName"``
                The name of camera used to take the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"ccdKey"``
                The ccd/detector key associated with this camera (`str`).
            ``"visit"``
                The visit of the data; only included if the data is from a
                single epoch dataset (`str`).
            ``"patch"``
                The patch that the data is from; only included if the data is
                from a coadd dataset (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"photoCalibDataset"``
                The dataset used for the calibration, e.g. "jointcal" or "fgcm"
                (`str`).
            ``"skyWcsDataset"``
                The sky Wcs dataset used (`str`).
            ``"rerun"``
                The rerun the data is stored in (`str`).

        Returns
        -------
        ``fig``
            The figure to be saved (`matplotlib.figure.Figure`).

        Notes
        -----
        Makes a histogram showing the fraction recovered in each magnitude
        bin with the number input and recovered overplotted. Bins that
        contain no injected sources have an undefined (NaN) recovered
        fraction and are left empty.
        """
        mags = data[self.magKey]
        matched = np.isfinite(data[self.matchDistanceKey])

        # Make plot showing the fraction recovered in magnitude bins
        fig, axLeft = plt.subplots(dpi=300)
        axLeft.tick_params(axis="y", labelcolor="C0")
        axLeft.set_xlabel(self.xAxisLabel)
        axLeft.set_ylabel("Fraction Recovered", color="C0")
        axRight = axLeft.twinx()
        axRight.set_ylabel("Number of Sources")

        nInput, bins, _ = axRight.hist(
            mags,
            range=(np.nanmin(mags), np.nanmax(mags)),
            bins=self.numBins,
            log=True,
            histtype="step",
            label=self.inputLabel,
            color="black",
        )
        # Reuse the input bin edges so the two histograms are directly
        # comparable. No ``range`` is passed: it has no effect once explicit
        # edges are given, and computing it from the matched subset raised
        # when no source was matched.
        nOutput, _, _ = axRight.hist(
            mags[matched],
            bins=bins,
            log=True,
            histtype="step",
            label=self.outputLabel,
            color="grey",
        )
        # Remember the limits set by the histograms; they are restored after
        # the guide lines are drawn.
        xlims = axRight.get_xlim()

        # Fraction recovered per bin; NaN where nothing was injected, which
        # avoids divide-by-zero warnings and never compares as "< 0.5".
        fracRecovered = np.divide(
            nOutput,
            nInput,
            out=np.full(len(nInput), np.nan),
            where=nInput > 0,
        )

        # TODO: put a box in the bottom corner for all the percentiles
        # Find bin where the fraction recovered first falls below 0.5
        lessThanHalf = np.where(fracRecovered < 0.5)[0]
        if len(lessThanHalf) == 0:
            mag50 = np.nan
        else:
            # bins is sorted, so the first index is the lowest magnitude edge
            mag50 = bins[lessThanHalf[0]]
            axLeft.plot([xlims[0], mag50], [0.5, 0.5], ls=":", color="grey")
            axLeft.plot([mag50, mag50], [0, 0.5], ls=":", color="grey")

        axLeft.set_xlim(xlims)
        axLeft.set_ylim(0, 1.05)
        axRight.legend(loc="lower left", ncol=2)
        axLeft.axhline(1, color="grey", ls="--")
        axLeft.bar(
            bins[:-1],
            fracRecovered,
            width=np.diff(bins),
            align="edge",
            color="C0",
            alpha=0.5,
            zorder=10,
        )
        bboxDict = dict(boxstyle="round", facecolor="white", alpha=0.75)

        info50 = f"Magnitude at 50% recovered: {mag50:0.2f}"
        axLeft.text(0.3, 0.2, info50, transform=fig.transFigure, bbox=bboxDict, zorder=11)

        # Add useful information to the plot
        addPlotInfo(fig, plotInfo)
        return fig

python/lsst/analysis/tools/actions/scalar/scalarActions.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,20 @@
4040
"IqrHistAction",
4141
"DivideScalar",
4242
"RmsAction",
43+
"MagPercentileAction",
4344
)
4445

4546
import operator
4647
from math import nan
4748
from typing import cast
4849

4950
import numpy as np
51+
from astropy import units as u
5052
from lsst.pex.config import ChoiceField, Field
5153
from lsst.pex.config.configurableActions import ConfigurableActionField
5254

5355
from ...interfaces import KeyedData, KeyedDataSchema, Scalar, ScalarAction, Vector
54-
from ...math import nanMax, nanMean, nanMedian, nanMin, nanSigmaMad, nanStd
56+
from ...math import fluxToMag, isPercent, nanMax, nanMean, nanMedian, nanMin, nanSigmaMad, nanStd
5557

5658

5759
class ScalarFromVectorAction(ScalarAction):
@@ -434,3 +436,39 @@ def __call__(self, data: KeyedData, **kwargs) -> Scalar:
434436
if scalarB == 0:
435437
raise ValueError("Denominator is zero!")
436438
return scalarA / scalarB
439+
440+
441+
class MagPercentileAction(ScalarFromVectorAction):
    """Calculates the magnitude at the given percentile for completeness.

    The fluxes in ``vectorKey`` are converted to magnitudes and binned. A
    source counts as recovered when its match distance is finite. The
    result is the lower edge of the lowest-magnitude bin whose recovered
    fraction is below ``percentile`` / 100.
    """

    matchDistanceKey = Field[str](doc="Match distance Vector")
    fluxUnits = Field[str](doc="Units for the column.", default="nanojansky")
    percentile = Field[float](doc="The percentile to find the magnitude at.", default=50.0, check=isPercent)
    numBins = Field[int](doc="Number of magnitude bins used to measure completeness.", default=100)

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the columns this action reads: the match distance column
        and the flux column, both as `Vector`.
        """
        return (
            (self.matchDistanceKey, Vector),
            (self.vectorKey, Vector),
        )

    def __call__(self, data: KeyedData, **kwargs) -> Scalar:
        """Compute the magnitude at which completeness drops below the
        configured percentile.

        Parameters
        ----------
        data : `KeyedData`
            Must contain the match distance and flux columns.
        **kwargs
            Values used to format the column keys.

        Returns
        -------
        mag : `Scalar`
            Lower edge of the first magnitude bin whose recovered fraction
            is below ``percentile`` / 100. NaN if no bin falls below that
            fraction or if there are no finite magnitudes.
        """
        # Format both keys the same way; the match distance key was
        # previously looked up unformatted.
        matched = np.isfinite(data[self.matchDistanceKey.format(**kwargs)])
        fluxValues = data[self.vectorKey.format(**kwargs)]
        values = fluxToMag(fluxValues, flux_unit=u.Unit(self.fluxUnits))

        # Only finite magnitudes can define the histogram range or be binned.
        finite = np.isfinite(values)
        if not np.any(finite):
            return np.nan
        finiteValues = values[finite]

        nInput, bins = np.histogram(
            finiteValues,
            range=(np.min(finiteValues), np.max(finiteValues)),
            bins=self.numBins,
        )
        # Reuse the input bin edges. No ``range`` is passed: it has no
        # effect once explicit edges are given, and computing it from an
        # empty matched subset raised.
        nOutput, _ = np.histogram(values[matched & finite], bins=bins)

        # Fraction recovered per bin; NaN where nothing was injected, which
        # avoids divide-by-zero warnings and never compares as below the
        # threshold.
        fraction = np.divide(
            nOutput,
            nInput,
            out=np.full(len(nInput), np.nan),
            where=nInput > 0,
        )

        # Find bin where the fraction recovered first falls below the
        # requested percentile
        belowPercentile = np.where(fraction < self.percentile / 100)[0]
        if len(belowPercentile) == 0:
            return np.nan
        # bins is sorted, so the first index is the lowest magnitude edge
        return bins[belowPercentile[0]]

python/lsst/analysis/tools/atools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@
3232
from .skyFluxStatisticMetrics import *
3333
from .skyObject import *
3434
from .skySource import *
35+
from .sourceInjectionPlots import *
3536
from .sources import *
3637
from .stellarLocus import *

0 commit comments

Comments
 (0)