Skip to content

Commit 2d6848d

Browse files
committed
Add QA plots for source injection to analysis_tools
1 parent ea42295 commit 2d6848d

File tree

10 files changed

+644
-3
lines changed

10 files changed

+644
-3
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
description: |
2+
Tier1 plots and metrics to assess injected coadd quality
3+
tasks:
4+
injectedObjectAnalysis:
5+
class: lsst.analysis.tools.tasks.injectedObjectAnalysis.InjectedObjectAnalysisTask
6+
config:
7+
atools.completenessHist: CompletenessPurityTool
8+
atools.astromDiffRAScatterPlot: TargetInjectedCatDeltaRAScatterPlot
9+
atools.astromDiffDecScatterPlot: TargetInjectedCatDeltaDecScatterPlot
10+
atools.astromDiffMetrics: TargetInjectedCatDeltaMetrics
11+
atools.astromDiffMetrics.applyContext: CoaddContext
12+
atools.targetInjectedCatDeltaPsfScatterPlot: TargetInjectedCatDeltaPsfScatterPlot
13+
bands: ["g", "r", "i", "z", "y"]
14+
python: |
15+
from lsst.analysis.tools.atools import *
16+
from lsst.analysis.tools.contexts import *

python/lsst/analysis/tools/actions/plot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .barPlots import *
22
from .calculateRange import *
33
from .colorColorFitPlot import *
4+
from .completenessPlot import *
45
from .diaSkyPlot import *
56
from .focalPlanePlot import *
67
from .gridPlot import *
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# This file is part of analysis_tools.
2+
#
3+
# Developed for the LSST Data Management System.
4+
# This product includes software developed by the LSST Project
5+
# (https://www.lsst.org).
6+
# See the COPYRIGHT file at the top-level directory of this distribution
7+
# for details of code ownership.
8+
#
9+
# This program is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# This program is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
21+
22+
23+
from typing import Mapping
24+
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
from lsst.pex.config import Field, ListField
28+
from matplotlib.figure import Figure
29+
30+
from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
31+
from .plotUtils import addPlotInfo
32+
33+
__all__ = ("CompletenessHist",)
34+
35+
36+
class CompletenessHist(PlotAction):
37+
"""Makes a scatter plot of the data with a marginal
38+
histogram for each axis.
39+
"""
40+
41+
magKey = Field[str](doc="Name of the magnitude column.", default="mag")
42+
matchDistanceKey = Field[str](doc="Name of the match distance column.", default="matchDistance")
43+
xAxisLabel = Field[str](doc="Label for the x axis.", default="Input Magnitude (mag)")
44+
inputLabel = Field[str](doc="Label for the input source histogram.", default="Synthetic Inputs")
45+
outputLabel = Field[str](doc="Label for the recovered source histogram.", default="Synthetic Recovered")
46+
numBins = Field[int](doc="Number of bins to use for the histograms.", default=100)
47+
completenessPercentiles = ListField[float](doc="Record the magnitudes at these percentiles", default=[16.0, 50.0, 84.0])
48+
49+
def getInputSchema(self) -> KeyedDataSchema:
50+
base: list[tuple[str, type[Vector] | ScalarType]] = []
51+
base.append((self.magKey, Vector))
52+
base.append((self.matchDistanceKey, Vector))
53+
return base
54+
55+
def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
56+
self._validateInput(data, **kwargs)
57+
return self.makePlot(data, **kwargs)
58+
59+
def _validateInput(self, data: KeyedData, **kwargs) -> None:
60+
"""NOTE currently can only check that something is not a Scalar, not
61+
check that the data is consistent with Vector
62+
"""
63+
needed = self.getFormattedInputSchema(**kwargs)
64+
if remainder := {key.format(**kwargs) for key, _ in needed} - {
65+
key.format(**kwargs) for key in data.keys()
66+
}:
67+
raise ValueError(f"Task needs keys {remainder} but they were not found in input")
68+
for name, typ in needed:
69+
isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
70+
if isScalar and typ != Scalar:
71+
raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")
72+
73+
def makePlot(self, data, plotInfo, **kwargs):
74+
"""Makes a plot showing the fraction of injected sources recovered by
75+
input magnitude.
76+
77+
Parameters
78+
----------
79+
data : `KeyedData`
80+
All the data
81+
plotInfo : `dict`
82+
A dictionary of information about the data being plotted with keys:
83+
``camera``
84+
The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
85+
``"cameraName"``
86+
The name of camera used to take the data (`str`).
87+
``"filter"``
88+
The filter used for this data (`str`).
89+
``"ccdKey"``
90+
The ccd/dectector key associated with this camera (`str`).
91+
``"visit"``
92+
The visit of the data; only included if the data is from a
93+
single epoch dataset (`str`).
94+
``"patch"``
95+
The patch that the data is from; only included if the data is
96+
from a coadd dataset (`str`).
97+
``"tract"``
98+
The tract that the data comes from (`str`).
99+
``"photoCalibDataset"``
100+
The dataset used for the calibration, e.g. "jointcal" or "fgcm"
101+
(`str`).
102+
``"skyWcsDataset"``
103+
The sky Wcs dataset used (`str`).
104+
``"rerun"``
105+
The rerun the data is stored in (`str`).
106+
107+
Returns
108+
------
109+
``fig``
110+
The figure to be saved (`matplotlib.figure.Figure`).
111+
112+
Notes
113+
-----
114+
Makes a histogram showing the fraction recovered in each magnitude
115+
bin with the number input and recovered overplotted.
116+
"""
117+
118+
# Make plot showing the fraction recovered in magnitude bins
119+
fig, axLeft = plt.subplots(dpi=300)
120+
axLeft.tick_params(axis="y", labelcolor="C0")
121+
axLeft.set_xlabel(self.xAxisLabel)
122+
axLeft.set_ylabel("Fraction Recovered", color="C0")
123+
axRight = axLeft.twinx()
124+
axRight.set_ylabel("Number of Sources")
125+
matched = np.isfinite(data[self.matchDistanceKey])
126+
nInput, bins, _ = axRight.hist(
127+
data[self.magKey],
128+
range=(np.nanmin(data[self.magKey]), np.nanmax(data[self.magKey])),
129+
bins=self.numBins,
130+
log=True,
131+
histtype="step",
132+
label=self.inputLabel,
133+
color="black",
134+
)
135+
nOutput, _, _ = axRight.hist(
136+
data[self.magKey][matched],
137+
range=(np.nanmin(data[self.magKey][matched]), np.nanmax(data[self.magKey][matched])),
138+
bins=bins,
139+
log=True,
140+
histtype="step",
141+
label=self.outputLabel,
142+
color="grey",
143+
)
144+
xlims = plt.gca().get_xlim()
145+
# TODO: put a box in the bottom corner for all the percentiles
146+
# Find bin where the fraction recovered falls below a given percentile.
147+
percentileInfo = []
148+
for pct in self.completenessPercentiles:
149+
pct /= 100
150+
magArray = np.where((nOutput / nInput < pct))[0]
151+
if len(magArray) == 0:
152+
mag = np.nan
153+
else:
154+
mag = np.min(bins[magArray])
155+
axLeft.plot([xlims[0], mag], [pct, pct], ls=":", color="grey")
156+
axLeft.plot([mag, mag], [0, pct], ls=":", color="grey")
157+
percentileInfo.append("Magnitude at {}% recovered: {:0.2f}".format(pct * 100, mag))
158+
plt.xlim(xlims)
159+
axLeft.set_ylim(0, 1.05)
160+
axRight.legend(loc="lower left", ncol=2)
161+
axLeft.axhline(1, color="grey", ls="--")
162+
axLeft.bar(
163+
bins[:-1],
164+
nOutput / nInput,
165+
width=np.diff(bins),
166+
align="edge",
167+
color="C0",
168+
alpha=0.5,
169+
zorder=10,
170+
)
171+
bboxDict = dict(boxstyle="round", facecolor="white", alpha=0.75)
172+
173+
spacing = 0
174+
for info in percentileInfo:
175+
axLeft.text(0.3, 0.2 + spacing, info, transform=fig.transFigure, bbox=bboxDict, zorder=11)
176+
spacing += 0.1
177+
178+
# Add useful information to the plot
179+
fig = plt.gcf()
180+
addPlotInfo(fig, plotInfo)
181+
return fig

python/lsst/analysis/tools/actions/scalar/scalarActions.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,20 @@
4040
"IqrHistAction",
4141
"DivideScalar",
4242
"RmsAction",
43+
"MagPercentileAction",
4344
)
4445

4546
import operator
4647
from math import nan
4748
from typing import cast
4849

4950
import numpy as np
51+
from astropy import units as u
5052
from lsst.pex.config import ChoiceField, Field
5153
from lsst.pex.config.configurableActions import ConfigurableActionField
5254

5355
from ...interfaces import KeyedData, KeyedDataSchema, Scalar, ScalarAction, Vector
54-
from ...math import nanMax, nanMean, nanMedian, nanMin, nanSigmaMad, nanStd
56+
from ...math import fluxToMag, isPercent, nanMax, nanMean, nanMedian, nanMin, nanSigmaMad, nanStd
5557

5658

5759
class ScalarFromVectorAction(ScalarAction):
@@ -434,3 +436,39 @@ def __call__(self, data: KeyedData, **kwargs) -> Scalar:
434436
if scalarB == 0:
435437
raise ValueError("Denominator is zero!")
436438
return scalarA / scalarB
439+
440+
441+
class MagPercentileAction(ScalarFromVectorAction):
442+
"""Calculates the magnitude at the given percentile for completeness"""
443+
444+
matchDistanceKey = Field[str]("Match distance Vector")
445+
fluxUnits = Field[str](doc="Units for the column.", default="nanojansky")
446+
percentile = Field[float](doc="The percentile to find the magnitude at.", default=50.0, check=isPercent)
447+
448+
def getInputSchema(self) -> KeyedDataSchema:
449+
return (
450+
(self.matchDistanceKey, Vector),
451+
(self.vectorKey, Vector),
452+
)
453+
454+
def __call__(self, data: KeyedData, **kwargs) -> Scalar:
455+
matched = np.isfinite(data[self.matchDistanceKey])
456+
fluxValues = data[self.vectorKey.format(**kwargs)]
457+
values = fluxToMag(fluxValues, flux_unit=u.Unit(self.fluxUnits))
458+
nInput, bins = np.histogram(
459+
values,
460+
range=(np.nanmin(values), np.nanmax(values)),
461+
bins=100,
462+
)
463+
nOutput, _ = np.histogram(
464+
values[matched],
465+
range=(np.nanmin(values[matched]), np.nanmax(values[matched])),
466+
bins=bins,
467+
)
468+
# Find bin where the fraction recovered first falls below 0.5
469+
belowPercentile = np.where((nOutput / nInput < self.percentile / 100))[0]
470+
if len(belowPercentile) == 0:
471+
mag = np.nan
472+
else:
473+
mag = np.min(bins[belowPercentile])
474+
return mag

python/lsst/analysis/tools/actions/vector/vectorActions.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"ResidualWithPerGroupStatistic",
3636
"RAcosDec",
3737
"AngularSeparation",
38+
"MagPercentileAction",
3839
)
3940

4041
import logging
@@ -44,11 +45,11 @@
4445
import pandas as pd
4546
from astropy import units as u
4647
from astropy.coordinates import SkyCoord
47-
from lsst.pex.config import DictField, Field
48+
from lsst.pex.config import DictField, Field, ListField
4849
from lsst.pex.config.configurableActions import ConfigurableActionField, ConfigurableActionStructField
4950

5051
from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
51-
from ...math import divide, fluxToMag, log10
52+
from ...math import divide, fluxToMag, isPercent, log10
5253
from .selectors import VectorSelector
5354

5455
_LOG = logging.getLogger(__name__)
@@ -404,3 +405,42 @@ def __call__(self, data: KeyedData, **kwargs) -> Vector:
404405

405406
result = joinedDf["value_individual"] - joinedDf["value_group"]
406407
return np.array(result)
408+
409+
410+
class MagPercentileAction(VectorAction):
411+
"""Calculates the magnitude at the given percentile for completeness"""
412+
413+
matchDistanceKey = Field[str]("Match distance Vector")
414+
vectorKey = Field[str](doc="Key of vector which should be loaded")
415+
fluxUnits = Field[str](doc="Units for the column.", default="nanojansky")
416+
percentiles = ListField[float](doc="The percentiles to find the magnitude at.", default=[16.0, 50.0, 84.0], itemCheck=isPercent)
417+
418+
def getInputSchema(self) -> KeyedDataSchema:
419+
return (
420+
(self.matchDistanceKey, Vector),
421+
(self.vectorKey, Vector),
422+
)
423+
424+
def __call__(self, data: KeyedData, **kwargs) -> Scalar:
425+
matched = np.isfinite(data[self.matchDistanceKey])
426+
fluxValues = data[self.vectorKey.format(**kwargs)]
427+
values = fluxToMag(fluxValues, flux_unit=u.Unit(self.fluxUnits))
428+
nInput, bins = np.histogram(
429+
values,
430+
range=(np.nanmin(values), np.nanmax(values)),
431+
bins=100,
432+
)
433+
nOutput, _ = np.histogram(
434+
values[matched],
435+
range=(np.nanmin(values[matched]), np.nanmax(values[matched])),
436+
bins=bins,
437+
)
438+
# Find bin where the fraction recovered first falls below 0.5
439+
mags = []
440+
for pct in self.percentiles:
441+
belowPercentile = np.where((nOutput / nInput < pct / 100))[0]
442+
if len(belowPercentile) == 0:
443+
mags.append(np.nan)
444+
else:
445+
mags.append(np.min(bins[belowPercentile]))
446+
return np.array(mags)

python/lsst/analysis/tools/atools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .skyFluxStatisticMetrics import *
3434
from .skyObject import *
3535
from .skySource import *
36+
from .sourceInjectionPlots import *
3637
from .sources import *
3738
from .stellarLocus import *
3839
from .wholeSkyPlotTool import *

0 commit comments

Comments
 (0)