|
| 1 | +# This file is part of analysis_tools. |
| 2 | +# |
| 3 | +# Developed for the LSST Data Management System. |
| 4 | +# This product includes software developed by the LSST Project |
| 5 | +# (https://www.lsst.org). |
| 6 | +# See the COPYRIGHT file at the top-level directory of this distribution |
| 7 | +# for details of code ownership. |
| 8 | +# |
| 9 | +# This program is free software: you can redistribute it and/or modify |
| 10 | +# it under the terms of the GNU General Public License as published by |
| 11 | +# the Free Software Foundation, either version 3 of the License, or |
| 12 | +# (at your option) any later version. |
| 13 | +# |
| 14 | +# This program is distributed in the hope that it will be useful, |
| 15 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 16 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 17 | +# GNU General Public License for more details. |
| 18 | +# |
| 19 | +# You should have received a copy of the GNU General Public License |
| 20 | +# along with this program. If not, see <https://www.gnu.org/licenses/>. |
| 21 | + |
| 22 | + |
| 23 | +from typing import Mapping |
| 24 | + |
| 25 | +import matplotlib.pyplot as plt |
| 26 | +import numpy as np |
| 27 | +from lsst.pex.config import Field, ListField |
| 28 | +from matplotlib.figure import Figure |
| 29 | + |
| 30 | +from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector |
| 31 | +from .plotUtils import addPlotInfo |
| 32 | + |
| 33 | +__all__ = ("CompletenessHist",) |
| 34 | + |
| 35 | + |
class CompletenessHist(PlotAction):
    """Plot the fraction of injected sources recovered per magnitude bin.

    The recovered fraction is drawn as a bar chart on the left-hand axis,
    with log-scaled step histograms of the number of input and recovered
    sources overplotted on a twinned right-hand axis. A source counts as
    recovered when its match distance is finite.
    """

    magKey = Field[str](doc="Name of the magnitude column.", default="mag")
    matchDistanceKey = Field[str](doc="Name of the match distance column.", default="matchDistance")
    xAxisLabel = Field[str](doc="Label for the x axis.", default="Input Magnitude (mag)")
    inputLabel = Field[str](doc="Label for the input source histogram.", default="Synthetic Inputs")
    outputLabel = Field[str](doc="Label for the recovered source histogram.", default="Synthetic Recovered")
    numBins = Field[int](doc="Number of bins to use for the histograms.", default=100)
    completenessPercentiles = ListField[float](
        doc="Record the magnitudes at these percentiles", default=[84.0, 50.0, 16.0]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the columns this action reads.

        Returns
        -------
        schema : `KeyedDataSchema`
            The magnitude and match-distance keys, both required as vectors.
        """
        schema: list[tuple[str, type[Vector] | ScalarType]] = [
            (self.magKey, Vector),
            (self.matchDistanceKey, Vector),
        ]
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every required key is present and is not a scalar.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a column
            holds a scalar where the schema does not ask for one.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(self, data, plotInfo, **kwargs):
        """Makes a plot showing the fraction of injected sources recovered by
        input magnitude.

        Parameters
        ----------
        data : `KeyedData`
            All the data
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``camera``
                The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
            ``"cameraName"``
                The name of camera used to take the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"ccdKey"``
                The ccd/detector key associated with this camera (`str`).
            ``"visit"``
                The visit of the data; only included if the data is from a
                single epoch dataset (`str`).
            ``"patch"``
                The patch that the data is from; only included if the data is
                from a coadd dataset (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"photoCalibDataset"``
                The dataset used for the calibration, e.g. "jointcal" or "fgcm"
                (`str`).
            ``"skyWcsDataset"``
                The sky Wcs dataset used (`str`).
            ``"rerun"``
                The rerun the data is stored in (`str`).
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure to be saved.

        Notes
        -----
        Makes a histogram showing the fraction recovered in each magnitude
        bin with the number input and recovered overplotted.
        """
        # Left axis: recovered fraction. Right (twinned) axis: source counts.
        fig, axLeft = plt.subplots(dpi=300)
        axLeft.tick_params(axis="y", labelcolor="C0")
        axLeft.set_xlabel(self.xAxisLabel)
        axLeft.set_ylabel("Fraction Recovered", color="C0")
        axRight = axLeft.twinx()
        axRight.set_ylabel("Number of Sources")

        mags = data[self.magKey]
        matched = np.isfinite(data[self.matchDistanceKey])
        nInput, bins, _ = axRight.hist(
            mags,
            range=(np.nanmin(mags), np.nanmax(mags)),
            bins=self.numBins,
            log=True,
            histtype="step",
            label=self.inputLabel,
            color="black",
        )
        # Reuse the input bin edges so the two histograms are directly
        # comparable. No ``range`` is passed: it has no effect when explicit
        # bin edges are given, and computing it from the matched subset would
        # raise when nothing was matched.
        nOutput, _, _ = axRight.hist(
            mags[matched],
            bins=bins,
            log=True,
            histtype="step",
            label=self.outputLabel,
            color="grey",
        )

        # Bins with no input sources give 0/0 = NaN; that is intended (they
        # are neither plotted nor counted as "below" a threshold), so silence
        # the floating point warnings rather than changing the values.
        with np.errstate(divide="ignore", invalid="ignore"):
            fracRecovered = nOutput / nInput

        # Find bin where the fraction recovered falls below a given percentile.
        percentileInfo = []
        xlims = axLeft.get_xlim()
        for pct in self.completenessPercentiles:
            threshold = pct / 100
            belowThreshold = np.where(fracRecovered < threshold)[0]
            if len(belowThreshold) == 0:
                mag = np.nan
            else:
                # Left edge of the brightest bin that drops below the threshold.
                mag = np.min(bins[belowThreshold])
            axLeft.plot([xlims[0], mag], [threshold, threshold], ls=":", color="grey")
            axLeft.plot([mag, mag], [0, threshold], ls=":", color="grey")
            # Report the configured percentile itself; dividing by 100 and
            # multiplying back can introduce float round-off in the label.
            percentileInfo.append(f"Magnitude at {pct}% recovered: {mag:0.2f}")
        # Drawing the guide lines may autoscale x; restore the histogram limits.
        axLeft.set_xlim(xlims)
        axLeft.set_ylim(0, 1.05)
        axRight.legend(loc="lower left", ncol=2)
        axLeft.axhline(1, color="grey", ls="--")
        axLeft.bar(
            bins[:-1],
            fracRecovered,
            width=np.diff(bins),
            align="edge",
            color="C0",
            alpha=0.5,
            zorder=10,
        )

        # Add useful information to the plot
        addPlotInfo(fig, plotInfo)
        statsText = "\n".join(percentileInfo)
        bbox = dict(edgecolor="grey", linestyle=":", facecolor="none")
        fig.text(0.7, 0.075, statsText, bbox=bbox, transform=fig.transFigure, fontsize=6)
        fig.subplots_adjust(bottom=0.2)
        return fig
0 commit comments