diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py
index f337d225..b2306b55 100644
--- a/python/lsst/ap/association/__init__.py
+++ b/python/lsst/ap/association/__init__.py
@@ -30,3 +30,4 @@ from .version import *
 from .mpSkyEphemerisQuery import *
 from .ssSingleFrameAssociation import *
+from .testApdb import *
diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py
index 5856f77a..c62ea19c 100644
--- a/python/lsst/ap/association/diaPipe.py
+++ b/python/lsst/ap/association/diaPipe.py
@@ -789,7 +789,8 @@ def writeToApdb(self, updatedDiaObjects, associatedDiaSources, diaForcedSources)
             diaForcedSourceStore)
         self.log.info("APDB updated.")
 
-    def testDataFrameIndex(self, df):
+    @staticmethod
+    def testDataFrameIndex(df):
         """Test the sorted DataFrame index for duplicates.
 
         Wrapped as a separate function to allow for mocking of this task
@@ -922,7 +923,8 @@ def mergeCatalogs(self, originalCatalog, newCatalog, catalogName):
             mergedCatalog = pd.concat([originalCatalog], sort=True)
         return mergedCatalog.loc[:, originalCatalog.columns]
 
-    def updateObjectTable(self, diaObjects, diaSources):
+    @staticmethod
+    def updateObjectTable(diaObjects, diaSources):
         """Update the diaObject table with the new diaSource records.
 
         Parameters
diff --git a/python/lsst/ap/association/testApdb.py b/python/lsst/ap/association/testApdb.py
new file mode 100644
index 00000000..feccc3f2
--- /dev/null
+++ b/python/lsst/ap/association/testApdb.py
@@ -0,0 +1,778 @@
+#
+# LSST Data Management System
+# Copyright 2008-2016 AURA/LSST.
+#
+# This product includes software developed by the
+# LSST Project (http://www.lsst.org/).
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the LSST License Statement and
+# the GNU General Public License along with this program.  If not,
+# see <http://www.lsstcorp.org/LegalNotices/>.
+#
+
+"""Standalone pipeline task to populate an APDB with simulated data.
+"""
+
+__all__ = ("TestApdbConfig",
+           "TestApdbTask",
+           )
+
+
+import numpy as np
+import pandas as pd
+import time
+
+from lsst.daf.base import DateTime
+import lsst.dax.apdb as daxApdb
+import lsst.geom
+from lsst.meas.base import DetectorVisitIdGeneratorConfig, IdGenerator
+import lsst.pex.config as pexConfig
+import lsst.pipe.base as pipeBase
+from lsst.ap.association.loadDiaCatalogs import LoadDiaCatalogsTask
+from lsst.ap.association.diaPipe import DiaPipelineTask
+from lsst.ap.association.utils import (
+    convertTableToSdmSchema,
+    readSchemaFromApdb,
+    column_dtype,
+    make_empty_catalog,
+    dropEmptyColumns,
+)
+import lsst.sphgeom
+from lsst.utils.timer import timeMethod
+
+
+class TestApdbConnections(
+        pipeBase.PipelineTaskConnections,
+        dimensions=("instrument", "visit", "detector")):
+    """Butler connections for TestApdbTask.
+    """
+
+    apdbTestMarker = pipeBase.connectionTypes.Output(
+        doc="Marker dataset storing the configuration of the Apdb for each "
+            "visit/detector. Used to signal the completion of the pipeline.",
+        name="apdbTest_marker",
+        storageClass="Config",
+        dimensions=("instrument", "visit", "detector"),
+    )
+
+
+class TestApdbConfig(pipeBase.PipelineTaskConfig,
+                     pipelineConnections=TestApdbConnections):
+    """Config for TestApdbTask.
+    """
+    apdb_config_url = pexConfig.Field(
+        dtype=str,
+        default=None,
+        optional=False,
+        doc="A config file specifying the APDB and its connection parameters, "
+            "typically written by the apdb-cli command-line utility. "
+            "The database must already be initialized.",
+    )
+    survey_area = pexConfig.Field(
+        dtype=float,
+        default=20000,
+        doc="Area (in square degrees) of the simulated survey.",
+    )
+    fov = pexConfig.Field(
+        dtype=float,
+        default=9.6,
+        doc="Field of view of the camera, in square degrees.",
+    )
+    stellar_density = pexConfig.Field(
+        dtype=float,
+        default=1750,
+        doc="Average number of real transient and variable objects"
+            " detected per square degree. For this simulation, these will"
+            " always be detected, and detected in the same location."
+            " The default is chosen such that:"
+            " Density x Rubin FoV (9.6) x # of visits per night (~600) ~ 10M",
+    )
+    false_positive_ratio = pexConfig.Field(
+        dtype=float,
+        default=2,
+        doc="Average ratio of false detections to real sources."
+            " These will be detected in random locations.",
+    )
+    false_positive_variability = pexConfig.Field(
+        dtype=float,
+        default=100,
+        doc="Parameter characterizing the variability in the rate of false"
+            " positives.",
+    )
+    sky_seed = pexConfig.Field(
+        dtype=int,
+        default=37,
+        doc="Seed used to simulate the real sources.",
+    )
+    historyThreshold = pexConfig.Field(
+        dtype=int,
+        doc="Minimum number of detections of a diaObject required "
+            "to run forced photometry. Set to 1 to include all diaObjects.",
+        default=2,
+    )
+    objId_start = pexConfig.Field(
+        dtype=int,
+        default=1,
+        doc="Starting diaObject ID number of real objects.",
+    )
+    maximum_table_length = pexConfig.Field(
+        dtype=int,
+        default=65535,
+        doc="Maximum length of tables allowed to be written in one operation"
+            " to the Cassandra APDB.",
+    )
+    raise_on_associated_fakes = pexConfig.Field(
+        dtype=bool,
+        default=True,
+        doc="Raise a RuntimeError if any fake sources are associated"
+            " with existing DiaObjects. This is likely due to restarting"
+            " the simulation after partial writes to the APDB.",
+    )
+
+    idGenerator = DetectorVisitIdGeneratorConfig.make_field()
+    idGeneratorFakes = DetectorVisitIdGeneratorConfig.make_field()
+    idGeneratorForced = DetectorVisitIdGeneratorConfig.make_field()
+
+    def setDefaults(self):
+        self.idGenerator = DetectorVisitIdGeneratorConfig(release_id=0, n_releases=3)
+        self.idGeneratorFakes = DetectorVisitIdGeneratorConfig(release_id=1, n_releases=3)
+        self.idGeneratorForced = DetectorVisitIdGeneratorConfig(release_id=2, n_releases=3)
+
+
+class TestApdbTask(LoadDiaCatalogsTask):
+    """Task for simulating, associating, and storing Difference Image
+    Analysis (DIA) Objects and Sources in a test APDB.
+ """ + ConfigClass = TestApdbConfig + _DefaultName = "apdbTest" + + def __init__(self, initInputs=None, **kwargs): + super().__init__(**kwargs) + self.apdb = daxApdb.Apdb.from_uri(self.config.apdb_config_url) + self.schema = readSchemaFromApdb(self.apdb) + self.prepareSurvey() + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["visit"] = butlerQC.quantum.dataId["visit"] + inputs["detector"] = butlerQC.quantum.dataId["detector"] + inputs["idGenerator"] = self.config.idGenerator.apply(butlerQC.quantum.dataId) + inputs["idGeneratorFakes"] = self.config.idGeneratorFakes.apply(butlerQC.quantum.dataId) + inputs["idGeneratorForced"] = self.config.idGeneratorForced.apply(butlerQC.quantum.dataId) + self.run(**inputs) + # Note that the commented-out code below is intentionally omitted + # We don't want this to write anything to the Butler! + butlerQC.put(pipeBase.Struct(), outputRefs) + + def prepareSurvey(self): + """Prepare survey parameters common to all observations + """ + deg2ToSteradian = np.deg2rad(1)**2 # Conversion factor from square degrees to steradians + sphereDegrees = 4*np.pi/deg2ToSteradian # area of a unit sphere in square degrees + sphereFraction = self.config.survey_area/sphereDegrees + survey_area = self.config.survey_area*deg2ToSteradian + # Avoid the branch cut of the arcsin function at the equator + # Calculate the angle from the pole needed to cover `survey_area` square + # degrees of a sphere + if sphereFraction < 0.5: + declinationMax = np.arcsin(1 - survey_area/(2*np.pi)) + else: + declinationMax = np.pi - np.arcsin(1 - (4*np.pi - survey_area)/(2*np.pi)) + self.radiusMax = stereographicRaDec2XY(0, declinationMax)[0] + self.nReal = int(self.config.survey_area*self.config.stellar_density) + # Radius of the focal plane, in radians + self.fpRadius = np.radians(np.sqrt(self.config.fov/np.pi)) + + def run(self, visit, detector, + idGenerator=IdGenerator(), + idGeneratorFakes=IdGenerator(), + idGeneratorForced=IdGenerator()): + """Generate a full focal plane simulation of real and fake sources, + run association and write to the APDB + + Parameters + ---------- + visit : `int` + Visit ID. + detector : `int` + Detector number, used for the ID generator. + Note that each instance simulates an entire focal plane. + The detector dimension is used as a convenience to allow re-using + code and infrastructure. + idGenerator : `lsst.meas.base,IdGenerator`, optional + ID generator used for "real" sources. + idGeneratorFakes : `lsst.meas.base,IdGenerator`, optional + ID generator for random "fake" sources. + Separate ID generators are needed to avoid overflows between visits, + since an entire focal plane is simulated per instance. + idGeneratorForced : `lsst.meas.base,IdGenerator`, optional + ID generator for forced sources at existing diaObject locations. + """ + t_sim0 = time.time() + idGen = idGenerator.make_table_id_factory() + idGenFakes = idGeneratorFakes.make_table_id_factory() + idGenForced = idGeneratorForced.make_table_id_factory() + seed = int(visit*1000 + detector) + x, y = randomCircleXY(self.radiusMax, seed) + ra, dec = stereographicXY2RaDec(x, y) + if ra < 0: + ra += 2*np.pi + delta0, _ = stereographicRaDec2XY(0, dec - self.fpRadius) + delta1, _ = stereographicRaDec2XY(0, dec + self.fpRadius) + delta = abs(delta1 - delta0)/2. 
+ x0 = x - delta + x1 = x + delta + y0 = y - delta + y1 = y + delta + ra0, dec0 = stereographicXY2RaDec(x0, y0) + ra1, dec1 = stereographicXY2RaDec(x1, y1) + region = lsst.sphgeom.ConvexPolygon([lsst.geom.SpherePoint(ra0, dec0, lsst.geom.radians).getVector(), + lsst.geom.SpherePoint(ra0, dec1, lsst.geom.radians).getVector(), + lsst.geom.SpherePoint(ra1, dec1, lsst.geom.radians).getVector(), + lsst.geom.SpherePoint(ra1, dec0, lsst.geom.radians).getVector()] + ) + self.log.info(f"Simulating {self.nReal} sources around RA={np.degrees(ra)}, Dec={np.degrees(dec)}") + xS, yS = randomCircleXY(self.radiusMax, self.config.sky_seed, n=self.nReal) + diaObjIds = np.arange(self.config.objId_start, self.nReal + self.config.objId_start) + inds = (xS > x0) & (xS < x1) & (yS > y0) & (yS < y1) + nSim = np.sum(inds) + diaObjIds = diaObjIds[inds] + self.log.info(f"{nSim} sources made the spatial cut") + diaSourcesReal = self.createDiaSources(*stereographicXY2RaDec(xS[inds], yS[inds]), + idGenerator=idGen, + visit=visit, + detector=detector, + diaObjectIds=diaObjIds) + + rng = np.random.RandomState(seed) + scale = nSim*self.config.false_positive_ratio/self.config.false_positive_variability + nBogus = int(rng.standard_gamma(scale)*self.config.false_positive_variability) + + diaSourcesBogus = self.createDiaSources(*self.generateFalseDetections(x0, x1, y0, y1, nBogus, seed), + visit=visit, + detector=detector, + idGenerator=idGenFakes) + diaSourcesRaw = pd.concat([diaSourcesReal, diaSourcesBogus]) + + diaSources = convertTableToSdmSchema(self.schema, diaSourcesRaw, tableName="DiaSource") + t_sim1 = time.time() + self.log.info("Simulated source timing: %.2fs", t_sim1 - t_sim0) + + t_load0 = time.time() + diaObjects = self.loadDiaObjects(region.getBoundingCircle(), self.schema) + t_load1 = time.time() + self.log.info("diaObject load timing: %.2fs", t_load1 - t_load0) + + if diaObjects.empty: + self.log.info("diaObjects contain 0 diaSources (empty)") + else: + nDiaSources = diaObjects.nDiaSources + self.log.info(f"diaObjects contain {np.min(nDiaSources)} to {np.max(nDiaSources)} diaSources") + + # Associate DiaSources with DiaObjects + associatedDiaSources, newDiaObjects, associatedDiaObjects = self.associateDiaSources(diaSources, + diaObjects, + diaSourcesReal, + diaSourcesBogus, + ) + # Merge new and preloaded diaObjects + mergedDiaObjects = self.mergeAssociatedCatalogs(associatedDiaObjects, newDiaObjects) + + nObj = len(mergedDiaObjects) + nSrc = len(associatedDiaSources) + dateTime = DateTime.now().toAstropy() + ind = 0 + t_write0 = time.time() + # Note that nObj must always be equal to or greater than nSrc + for start in range(0, nObj, self.config.maximum_table_length): + end = min(start + self.config.maximum_table_length, nObj) + diaObjectsChunk = mergedDiaObjects.iloc[start:end] + nObjChunk = len(diaObjectsChunk) + self.log.info(f"Writing diaObject chunk {ind} of length {nObjChunk} (of {nObj}) to the APDB") + srcEnd = min(start + self.config.maximum_table_length, nSrc) + if srcEnd <= start: + finalDiaSources = None + else: + diaSourcesChunk = associatedDiaSources.iloc[start:srcEnd] + finalDiaSources = convertTableToSdmSchema(self.schema, diaSourcesChunk, tableName="DiaSource") + nSrcChunk = len(diaSourcesChunk) + self.log.info(f"Writing diaSource chunk {ind} of length {nSrcChunk} (of {nSrc}) to the APDB") + diaForcedSources = self.runForcedMeasurement(diaObjectsChunk, idGenForced, visit, detector) + + finalDiaObjects = convertTableToSdmSchema(self.schema, diaObjectsChunk, tableName="DiaObject") + 
finalDiaForcedSources = convertTableToSdmSchema(self.schema, diaForcedSources,
+                                                            tableName="DiaForcedSource")
+            self.log.info(f"Writing forced source chunk {ind} of length"
+                          f" {len(finalDiaForcedSources)} to the APDB")
+            self.writeToApdb(finalDiaObjects, finalDiaSources, finalDiaForcedSources, dateTime)
+            ind += 1
+        t_write1 = time.time()
+        self.log.info("APDB write timing: %.2fs", t_write1 - t_write0)
+        marker = pexConfig.Config()
+        return pipeBase.Struct(apdbTestMarker=marker)
+
+    def createDiaSources(self, raVals, decVals, idGenerator, visit, detector, diaObjectIds=None):
+        """Create diaSources with the supplied coordinates.
+
+        Parameters
+        ----------
+        raVals : `numpy.ndarray`
+            Right Ascension (RA) values of the desired sources.
+        decVals : `numpy.ndarray`
+            Declination (Dec) values of the desired sources.
+        idGenerator : `lsst.meas.base.IdGenerator`
+            ID generator used for the source IDs.
+        visit : `int`
+            Visit ID.
+        detector : `int`
+            Detector number.
+        diaObjectIds : `numpy.ndarray`, optional
+            diaObjectIds to assign to the sources. If not supplied, the object
+            ID is set to the diaSourceId.
+
+        Returns
+        -------
+        diaSources : `pandas.DataFrame`
+            Table of sources with the supplied coordinates and IDs assigned.
+        """
+        n = len(raVals)
+        raVals[raVals < 0] += 2*np.pi
+        ra = pd.Series(np.degrees(raVals), name='ra')
+        dec = pd.Series(np.degrees(decVals), name='dec')
+        diaSourceId = pd.Series([idGenerator() for i in range(n)], name='diaSourceId')
+        if diaObjectIds is None:
+            diaObjectId = pd.Series(diaSourceId, name='diaObjectId')
+        else:
+            diaObjectId = pd.Series(diaObjectIds, name='diaObjectId')
+        baseSources = pd.concat([diaSourceId, diaObjectId, ra, dec], axis=1)
+        baseSources["visit"] = visit
+        baseSources["detector"] = detector
+
+        baseSources["band"] = 'A'
+        preserveColumns = ["diaSourceId", "diaObjectId", "ra", "dec", "visit", "detector", "band"]
+
+        diaSources = fillRandomTable(self.schema, baseSources,
+                                     tableName="DiaSource",
+                                     preserveColumns=preserveColumns)
+        # Do *not* set the index to the diaSourceId. It will need to be
+        # diaObjectId for matching (later).
+        # diaSources.set_index("diaSourceId", inplace=True)
+        return diaSources
+
+    def generateFalseDetections(self, x0, x1, y0, y1, nBogus, seed):
+        """Generate random coordinates within the ranges provided.
+
+        Parameters
+        ----------
+        x0 : `float`
+            Minimum projected x coordinate.
+        x1 : `float`
+            Maximum projected x coordinate.
+        y0 : `float`
+            Minimum projected y coordinate.
+        y1 : `float`
+            Maximum projected y coordinate.
+        nBogus : `int`
+            Number of "fake" sources within the region.
+        seed : `int`
+            Seed value for the random number generator to provide repeatable
+            results.
+
+        Returns
+        -------
+        ra, dec : `numpy.ndarray`
+            Coordinates matching the randomly generated locations.
+        """
+        self.log.info(f"Simulating {nBogus} false detections within region.")
+        rng = np.random.RandomState(seed)
+        x = rng.random_sample(nBogus)*(x1 - x0) + x0
+        y = rng.random_sample(nBogus)*(y1 - y0) + y0
+        return stereographicXY2RaDec(x, y)
+
+    def simpleMatch(self, diaSourceTable, diaObjects, updateNdiaSources=True):
+        """Match diaSources to diaObjects by their pre-defined IDs.
+
+        Parameters
+        ----------
+        diaSourceTable : `pandas.DataFrame`
+            New DIASources to be associated with existing DIAObjects.
+        diaObjects : `pandas.DataFrame`
+            Existing diaObjects from the Apdb.
+        updateNdiaSources : `bool`, optional
+            If True, increment the ``nDiaSources`` counter of matched
+            diaObjects in place.
+
+        Returns
+        -------
+        result : `lsst.pipe.base.Struct`
+            Results struct with components.
+
+            - ``matchedDiaSources`` : DiaSources that were matched. Matched
+              sources have their diaObjectId set to the id of the
+              diaObject they were matched to.
+              (`pandas.DataFrame`)
+            - ``unAssocDiaSources`` : DiaSources that were not matched.
+              Unassociated sources have their diaObjectId set to 0 as they
+              were not associated with any existing DiaObjects.
+              (`pandas.DataFrame`)
+            - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
+              matched to new DiaSources. (`int`)
+            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
+              not matched to a new DiaSource. (`int`)
+            - ``matchedDiaObjects`` : DiaObjects that were matched to new
+              DiaSources. (`pandas.DataFrame`)
+        """
+        # Only use diaObjectId as the index here, and do not update the index
+        # of the original diaSourceTable
+        diaSourceTable.set_index("diaObjectId", inplace=True)
+        matchedDiaSources = diaSourceTable.loc[diaSourceTable.index.intersection(diaObjects.index)]
+        unAssocDiaSources = diaSourceTable.loc[diaSourceTable.index.difference(diaObjects.index)]
+
+        matchedDiaObjectInds = diaObjects.index.intersection(diaSourceTable.index)
+        if not diaObjects.empty and updateNdiaSources:
+            nDiaSources = diaObjects.loc[matchedDiaObjectInds, "nDiaSources"] + 1
+            diaObjects.loc[matchedDiaObjectInds, "nDiaSources"] = nDiaSources
+
+        # Reset the index of the diaSource dataframes, so diaObjectId is still
+        # a valid column
+        return pipeBase.Struct(unAssocDiaSources=unAssocDiaSources.reset_index(),
+                               matchedDiaSources=matchedDiaSources.reset_index(),
+                               nUpdatedDiaObjects=len(matchedDiaSources),
+                               nUnassociatedDiaObjects=len(unAssocDiaSources),
+                               matchedDiaObjects=diaObjects.loc[matchedDiaObjectInds],
+                               )
+
+    def associateDiaSources(self, diaSourceTable, diaObjects, diaSourcesReal, diaSourcesBogus):
+        """Associate DiaSources with DiaObjects.
+
+        Parameters
+        ----------
+        diaSourceTable : `pandas.DataFrame`
+            Newly detected DiaSources.
+        diaObjects : `pandas.DataFrame`
+            Table of DiaObjects from preloaded DiaObjects.
+        diaSourcesReal : `pandas.DataFrame`
+            The subset of simulated "real" DiaSources, used to report
+            association statistics.
+        diaSourcesBogus : `pandas.DataFrame`
+            The subset of simulated false detections, used to verify that no
+            fakes are associated with existing DiaObjects.
+
+        Returns
+        -------
+        associatedDiaSources : `pandas.DataFrame`
+            Associated DiaSources with DiaObjects.
+        newDiaObjects : `pandas.DataFrame`
+            Table of new DiaObjects after association.
+        associatedDiaObjects : `pandas.DataFrame`
+            Preloaded DiaObjects that were matched to new DiaSources.
+        """
+        # Associate new DiaSources with existing DiaObjects.
+        assocResults = self.simpleMatch(diaSourceTable, diaObjects, updateNdiaSources=True)
+        assocReal = self.simpleMatch(diaSourcesReal, diaObjects, updateNdiaSources=False)
+        assocBogus = self.simpleMatch(diaSourcesBogus, diaObjects, updateNdiaSources=False)
+
+        toAssociate = []
+
+        # Create new DiaObjects from unassociated diaSources.
+        createResults = self.createNewDiaObjects(assocResults.unAssocDiaSources)
+        if len(assocResults.matchedDiaSources) > 0:
+            toAssociate.append(assocResults.matchedDiaSources)
+        toAssociate.append(createResults.diaSources)
+        associatedDiaSources = pd.concat(toAssociate)
+
+        self.log.info("%i updated and %i unassociated diaSources. Creating %i new diaObjects",
+                      assocResults.nUpdatedDiaObjects,
+                      assocResults.nUnassociatedDiaObjects,
+                      createResults.nNewDiaObjects,
+                      )
+        self.log.info(f"{assocReal.nUpdatedDiaObjects} real sources associated"
+                      f" and {assocReal.nUnassociatedDiaObjects} not associated")
+        self.log.info(f"{assocBogus.nUpdatedDiaObjects} fake sources associated"
+                      f" and {assocBogus.nUnassociatedDiaObjects} not associated")
+
+        if self.config.raise_on_associated_fakes and assocBogus.nUpdatedDiaObjects > 0:
+            raise RuntimeError("Fake sources were associated with real DiaObjects."
+                               " This is likely due to restarting the simulation"
+                               " after partial writes to the APDB.")
+        # Index the DiaSource catalog for this visit after all associations
+        # have been made.
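+        # The (diaObjectId, band, diaSourceId) multi-index matches the
+        # indexing used for DiaSource catalogs elsewhere in ap_association.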
+        associatedDiaSources.set_index(["diaObjectId",
+                                        "band",
+                                        "diaSourceId"],
+                                       drop=False,
+                                       inplace=True)
+        return (associatedDiaSources, createResults.newDiaObjects, assocResults.matchedDiaObjects)
+
+    def createNewDiaObjects(self, unAssocDiaSources):
+        """Loop through the set of DiaSources and create new DiaObjects
+        for unassociated DiaSources.
+
+        Parameters
+        ----------
+        unAssocDiaSources : `pandas.DataFrame`
+            Set of DiaSources to create new DiaObjects from.
+
+        Returns
+        -------
+        results : `lsst.pipe.base.Struct`
+            Results struct containing:
+
+            - diaSources : `pandas.DataFrame`
+                DiaSource catalog, with the custom diaObjectIds preserved.
+            - newDiaObjects : `pandas.DataFrame`
+                Newly created DiaObjects from the unassociated DiaSources.
+            - nNewDiaObjects : `int`
+                Number of newly created diaObjects.
+        """
+        if len(unAssocDiaSources) == 0:
+            newDiaObjects = make_empty_catalog(self.schema, tableName="DiaObject")
+        else:
+            # Do *not* set the diaObjectId to the diaSourceId.
+            # For this simulation we are using custom diaObjectIds,
+            # and need to preserve them.
+            # unAssocDiaSources["diaObjectId"] = unAssocDiaSources["diaSourceId"]
+            preserveColumns = ["diaObjectId", "nearbyLowzGal"]
+            # Needs to be set for correct database formatting
+            unAssocDiaSources["nearbyLowzGal"] = None
+
+            # Fill the new diaObjects with random data
+            newDiaObjects = fillRandomTable(self.schema, unAssocDiaSources,
+                                            tableName="DiaObject",
+                                            preserveColumns=preserveColumns)
+            newDiaObjects.nDiaSources = 1
+        return pipeBase.Struct(diaSources=unAssocDiaSources,
+                               newDiaObjects=newDiaObjects,
+                               nNewDiaObjects=len(newDiaObjects))
+
+    def mergeAssociatedCatalogs(self, diaObjects, newDiaObjects):
+        """Merge the new diaObjects with the preloaded diaObjects.
+        Also updates the index of newDiaObjects in place.
+
+        Parameters
+        ----------
+        diaObjects : `pandas.DataFrame`
+            Table of DiaObjects from preloaded DiaObjects.
+        newDiaObjects : `pandas.DataFrame`
+            Table of new DiaObjects after association.
+
+        Returns
+        -------
+        mergedDiaObjects : `pandas.DataFrame`
+            Table of new DiaObjects merged with their history.
+
+        Raises
+        ------
+        RuntimeError
+            Raised if duplicate DiaObjects are found.
+        """
+
+        # Append new DiaObjects to their previous history.
+        # Do not modify diaSources
+        newDiaObjects.set_index("diaObjectId", drop=False, inplace=True)
+        if diaObjects.empty:
+            mergedDiaObjects = newDiaObjects
+        elif not newDiaObjects.empty:
+            mergedDiaObjects = pd.concat([diaObjects, newDiaObjects], sort=True)
+        else:
+            mergedDiaObjects = diaObjects
+        if DiaPipelineTask.testDataFrameIndex(mergedDiaObjects):
+            raise RuntimeError("Duplicate DiaObjects created after association.")
+        return mergedDiaObjects
+
+    def runForcedMeasurement(self, diaObjects, idGenerator, visit, detector):
+        """Simulate forced source measurement.
+
+        Generate random forced-photometry records at the new and updated
+        DiaObject locations, in place of forced photometry on the difference
+        and calibrated exposures.
+
+        Parameters
+        ----------
+        diaObjects : `pandas.DataFrame`
+            Catalog of DiaObjects.
+        idGenerator : `lsst.meas.base.IdGenerator`
+            Object that generates source IDs and random number generator
+            seeds.
+        visit : `int`
+            Visit ID.
+        detector : `int`
+            Detector number, used for the ID generator.
+            Note that each instance simulates an entire focal plane.
+            The detector dimension is used as a convenience to allow re-using
+            code and infrastructure.
+
+        Returns
+        -------
+        diaForcedSources : `pandas.DataFrame`
+            Catalog of simulated forced photometry fluxes at the DiaObject
+            locations.
+        """
+        # Restrict forced source measurement to objects with sufficient
+        # history to be reliable.
+        objectTable = diaObjects.query(f'nDiaSources >= {self.config.historyThreshold}')
+        preserveColumns = ["diaObjectId", "ra", "dec"]
+        baseForcedSources = objectTable[preserveColumns].copy()
+        baseForcedSources["visit"] = visit
+        preserveColumns.append("visit")
+        baseForcedSources["detector"] = detector
+        preserveColumns.append("detector")
+        baseForcedSources["band"] = 'A'
+        preserveColumns.append("band")
+        baseForcedSources["diaForcedSourceId"] = [idGenerator() for i in range(len(baseForcedSources))]
+        preserveColumns.append("diaForcedSourceId")
+        # Fill the forced sources for each diaObject with random data
+        diaForcedSources = fillRandomTable(self.schema, baseForcedSources,
+                                           tableName="DiaForcedSource",
+                                           preserveColumns=preserveColumns)
+        self.log.info(f"Updating {len(diaForcedSources)} diaForcedSources in the APDB")
+        return diaForcedSources
+
+    @timeMethod
+    def writeToApdb(self, updatedDiaObjects, associatedDiaSources, diaForcedSources, dateTime):
+        """Write to the Alert Production Database (Apdb).
+
+        Store DiaSources, updated DiaObjects, and DiaForcedSources in the
+        Alert Production Database (Apdb).
+
+        Parameters
+        ----------
+        updatedDiaObjects : `pandas.DataFrame`
+            Catalog of updated DiaObjects.
+        associatedDiaSources : `pandas.DataFrame` or `None`
+            Associated DiaSources with DiaObjects.
+        diaForcedSources : `pandas.DataFrame`
+            Catalog of simulated forced photometry fluxes at the DiaObject
+            locations.
+        dateTime : `astropy.time.Time`
+            Visit date and time, passed to the APDB store call.
+        """
+        # Drop empty columns that are nullable in the APDB.
+        diaObjectStore = dropEmptyColumns(self.schema, updatedDiaObjects, tableName="DiaObject")
+        if associatedDiaSources is None:
+            diaSourceStore = None
+        else:
+            diaSourceStore = dropEmptyColumns(self.schema, associatedDiaSources, tableName="DiaSource")
+        diaForcedSourceStore = dropEmptyColumns(self.schema, diaForcedSources, tableName="DiaForcedSource")
+        self.apdb.store(
+            dateTime,
+            diaObjectStore,
+            diaSourceStore,
+            diaForcedSourceStore)
+        self.log.info("APDB updated.")
+
+
+def randomCircleXY(radius, seed, n=None):
+    """Generate random x, y coordinates within a circular region.
+
+    Parameters
+    ----------
+    radius : `float`
+        Radius of the circular region.
+    seed : `int`
+        Seed for the random number generator.
+    n : `int`, optional
+        Number of points to generate.
+
+    Returns
+    -------
+    x, y : `float`, or `numpy.ndarray`
+        Random coordinates inside the region.
+    """
+    # Sample r**2 and phi uniformly so the points are distributed uniformly
+    # over the area of the disk.
+    rng = np.random.RandomState(seed)
+    r2 = rng.random_sample(n)*radius**2
+    phi = rng.random_sample(n)*2*np.pi
+    x = np.sqrt(r2)*np.cos(phi)
+    y = np.sqrt(r2)*np.sin(phi)
+    return (x, y)
+
+
+def stereographicXY2RaDec(x, y):
+    """Convert from a grid-like stereographic projection to RA and Dec.
+
+    Notes
+    -----
+    A stereographic projection centered on the pole, from Eq. 56 of
+    Calabretta and Greisen (2002),
+    "Representations of celestial coordinates in FITS".
+
+    Parameters
+    ----------
+    x : `float`
+        Projected x coordinate on the stereographic grid.
+    y : `float`
+        Projected y coordinate on the stereographic grid.
+ + Returns + ------- + ra, dec : `float` + Right Ascension and Declination. + """ + r = np.sqrt(x**2 + y**2) + dec = np.pi/2 - 2*np.arctan(r/2) + ra = np.arctan2(y, x) + return (ra, dec) + + +def stereographicRaDec2XY(ra, dec): + """Convert from Ra, Dec to a grid-like stereographic projection + + Notes + ----- + A stereographic projection centered on the pole + From Eq 56 of Calabretta and Greisen (2002) + "Representations of celestial coordinates in FITS" + + Parameters + ---------- + ra : `float` + Right Ascension (RA) in radians. + dec : `float` + Declination (Dec) in radians. + + Returns + ------- + x, y : `float` + Grid coordinates of the stereographic projection + """ + r = 2*np.cos(dec)/(1 + np.sin(dec)) + x = r*np.cos(ra) + y = r*np.sin(ra) + return (x, y) + + +def fillRandomTable(apdbSchema, sourceTable, tableName, preserveColumns=None): + """Force a table to conform to the schema defined by the APDB. + + This method uses the table definitions in ``sdm_schemas`` to + load the schema of the APDB, and does not actually connect to the APDB. + + Parameters + ---------- + apdbSchema : `dict` [`str`, `lsst.dax.apdb.schema_model.Table`] + Schema from ``sdm_schemas`` containing the table definition to use. + sourceTable : `pandas.DataFrame` + The input table to convert. + tableName : `str` + Name of the table in the schema to use. + preserveColumns : `list` of `str`, optional + List of columns to copy from the input sourceTable. + + Returns + ------- + `pandas.DataFrame` + A table with the correct schema for the APDB and data copied from + the input ``sourceTable``. + """ + table = apdbSchema[tableName] + + data = {} + nSrc = len(sourceTable) + rng = np.random.default_rng() + if preserveColumns is None: + preserveColumns = [] + + for columnDef in table.columns: + if columnDef.name in preserveColumns: + data[columnDef.name] = sourceTable[columnDef.name] + dtype = column_dtype(columnDef.datatype) + if columnDef.name in sourceTable.columns: + data[columnDef.name] = pd.Series(sourceTable[columnDef.name], dtype=dtype) + else: + try: + dataInit = rng.random(nSrc).astype(dtype) + except (TypeError, ValueError): + dataInit = np.zeros(nSrc, dtype=column_dtype(columnDef.datatype)) + data[columnDef.name] = pd.Series(dataInit, index=sourceTable.index) + + df = convertTableToSdmSchema(apdbSchema, pd.DataFrame(data), tableName=tableName) + return df diff --git a/tests/test_packageAlerts.py b/tests/test_packageAlerts.py deleted file mode 100644 index a77f5187..00000000 --- a/tests/test_packageAlerts.py +++ /dev/null @@ -1,654 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
- -import io -import os - -import numpy as np -import pandas as pd -import tempfile -import unittest -from unittest.mock import patch, Mock -from astropy import wcs -from astropy.nddata import CCDData -import fastavro -try: - import confluent_kafka - from confluent_kafka import KafkaException -except ImportError: - confluent_kafka = None - -import lsst.alert.packet as alertPack -from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask -from lsst.afw.cameraGeom.testUtils import DetectorWrapper -import lsst.afw.image as afwImage -import lsst.daf.base as dafBase -from lsst.dax.apdb import Apdb, ApdbSql -import lsst.geom as geom -import lsst.meas.base.tests -from lsst.sphgeom import Box -import lsst.utils.tests -import utils_tests - - -def _roundTripThroughApdb(objects, sources, forcedSources, dateTime): - """Run object and source catalogs through the Apdb to get the correct - table schemas. - - Parameters - ---------- - objects : `pandas.DataFrame` - Set of test DiaObjects to round trip. - sources : `pandas.DataFrame` - Set of test DiaSources to round trip. - forcedSources : `pandas.DataFrame` - Set of test DiaForcedSources to round trip. - dateTime : `astropy.time.Time` - Time for the Apdb. - - Returns - ------- - objects : `pandas.DataFrame` - Round tripped objects. - sources : `pandas.DataFrame` - Round tripped sources. - """ - with tempfile.NamedTemporaryFile() as tmpFile: - apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name) - apdb = Apdb.from_config(apdbConfig) - - wholeSky = Box.full() - loadedObjects = apdb.getDiaObjects(wholeSky) - if loadedObjects.empty: - diaObjects = objects - else: - diaObjects = pd.concat([loadedObjects, objects]) - loadedDiaSources = apdb.getDiaSources(wholeSky, [], dateTime) - if loadedDiaSources.empty: - diaSources = sources - else: - diaSources = pd.concat([loadedDiaSources, sources]) - loadedDiaForcedSources = apdb.getDiaForcedSources(wholeSky, [], dateTime) - if loadedDiaForcedSources.empty: - diaForcedSources = forcedSources - else: - diaForcedSources = pd.concat([loadedDiaForcedSources, forcedSources]) - - apdb.store(dateTime, diaObjects, diaSources, diaForcedSources) - - diaObjects = apdb.getDiaObjects(wholeSky) - diaSources = apdb.getDiaSources(wholeSky, - np.unique(diaObjects["diaObjectId"]), - dateTime) - diaForcedSources = apdb.getDiaForcedSources( - wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime) - - diaObjects.set_index("diaObjectId", drop=False, inplace=True) - diaSources.set_index(["diaObjectId", "band", "diaSourceId"], - drop=False, - inplace=True) - diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True) - - return (diaObjects, diaSources, diaForcedSources) - - -VISIT = 2 -DETECTOR = 42 - - -def mock_alert(alert_id): - """Generate a minimal mock alert. - """ - return { - "alertId": alert_id, - "diaSource": { - "midpointMjdTai": 5, - "diaSourceId": 1234, - "visit": VISIT, - "detector": DETECTOR, - "band": 'g', - "ra": 12.5, - "dec": -16.9, - # These types are 32-bit floats in the avro schema, so we have to - # make them that type here, so that they round trip appropriately. - "x": np.float32(15.7), - "y": np.float32(89.8), - "apFlux": np.float32(54.85), - "apFluxErr": np.float32(70.0), - "snr": np.float32(6.7), - "psfFlux": np.float32(700.0), - "psfFluxErr": np.float32(90.0), - } - } - - -def _deserialize_alert(alert_bytes): - """Deserialize an alert message from Kafka. 
- - Parameters - ---------- - alert_bytes : `bytes` - Binary-encoding serialized Avro alert, including Confluent Wire - Format prefix. - - Returns - ------- - alert : `dict` - An alert payload. - """ - schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema())) - content_bytes = io.BytesIO(alert_bytes[5:]) - - return fastavro.schemaless_reader(content_bytes, schema.definition) - - -class TestPackageAlerts(lsst.utils.tests.TestCase): - def setUp(self): - # Create an instance of random generator with fixed seed. - rng = np.random.default_rng(1234) - - patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password", - "AP_KAFKA_PRODUCER_USERNAME": "fake_username", - "AP_KAFKA_SERVER": "fake_server", - "AP_KAFKA_TOPIC": "fake_topic"}) - self.environ = patcher.start() - self.addCleanup(patcher.stop) - self.cutoutSize = 35 - self.center = lsst.geom.Point2D(50.1, 49.8) - self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30), - lsst.geom.Extent2I(140, 160)) - self.dataset = lsst.meas.base.tests.TestDataset(self.bbox) - self.dataset.addSource(100000.0, self.center) - exposure, catalog = self.dataset.realize( - 10.0, - self.dataset.makeMinimalSchema(), - randomSeed=1234) - self.exposure = exposure - detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector - self.exposure.setDetector(detector) - - visit = afwImage.VisitInfo( - id=VISIT, - exposureTime=200., - date=dafBase.DateTime("2014-05-13T17:00:00.000000000", - dafBase.DateTime.Timescale.TAI)) - self.exposure.info.id = 1234 - self.exposure.info.setVisitInfo(visit) - - self.exposure.setFilter( - afwImage.FilterLabel(band='g', physical="g.MP9401")) - - diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng) - diaSourceHistory = utils_tests.makeDiaSources( - 10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng) - diaForcedSources = utils_tests.makeDiaForcedSources( - 10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng) - self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb( - diaObjects, - diaSourceHistory, - diaForcedSources, - self.exposure.visitInfo.date.toAstropy()) - diaSourceHistory["programId"] = 0 - - self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :] - self.diaSources["bboxSize"] = self.cutoutSize - self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9), - (2, "g", 10)]) - - self.cutoutWcs = wcs.WCS(naxis=2) - self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]] - self.cutoutWcs.wcs.crval = [ - self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(), - self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()] - self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix() - self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] - - def testCreateExtentMinimum(self): - """Test the extent creation for the cutout bbox returns a cutout with - the minimum cutouut size. - """ - packConfig = PackageAlertsConfig() - # Just create a minimum less than the default cutout size. - packConfig.minCutoutSize = self.cutoutSize - 5 - packageAlerts = PackageAlertsTask(config=packConfig) - extent = packageAlerts.createDiaSourceExtent( - packConfig.minCutoutSize - 5) - self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize, - packConfig.minCutoutSize)) - # Test that the cutout size is correctly increased. 
- extent = packageAlerts.createDiaSourceExtent(self.cutoutSize) - self.assertTrue(extent == geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - - def testCreateExtentMaximum(self): - """Test the extent creation for the cutout bbox returns a cutout with - the maximum cutout size. - """ - packConfig = PackageAlertsConfig() - # Just create a maximum more than the default cutout size. - packConfig.maxCutoutSize = self.cutoutSize + 5 - packageAlerts = PackageAlertsTask(config=packConfig) - extent = packageAlerts.createDiaSourceExtent( - packConfig.maxCutoutSize + 5) - self.assertTrue(extent == geom.Extent2I(packConfig.maxCutoutSize, - packConfig.maxCutoutSize)) - # Test that the cutout size is correctly reduced. - extent = packageAlerts.createDiaSourceExtent(self.cutoutSize) - self.assertTrue(extent == geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - - def testCreateCcdDataCutout(self): - """Test that the data is being extracted into the CCDData cutout - correctly. - """ - packageAlerts = PackageAlertsTask() - - diaSrcId = 1234 - ccdData = packageAlerts.createCcdDataCutout( - self.exposure, - self.exposure.getWcs().getSkyOrigin(), - self.exposure.getWcs().getPixelOrigin(), - self.exposure.getBBox().getDimensions(), - self.exposure.getPhotoCalib(), - diaSrcId) - calibExposure = self.exposure.getPhotoCalib().calibrateImage( - self.exposure.getMaskedImage()) - - self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd, - self.cutoutWcs.wcs.cd) - self.assertFloatsAlmostEqual(ccdData.data, - calibExposure.getImage().array) - self.assertFloatsAlmostEqual(ccdData.psf, - self.exposure.psf.computeKernelImage(self.center).array) - - ccdData = packageAlerts.createCcdDataCutout( - self.exposure, - geom.SpherePoint(0, 0, geom.degrees), - geom.Point2D(-100, -100), - self.exposure.getBBox().getDimensions(), - self.exposure.getPhotoCalib(), - diaSrcId) - self.assertTrue(ccdData is None) - - def testMakeLocalTransformMatrix(self): - """Test that the local WCS approximation is correct. - """ - packageAlerts = PackageAlertsTask() - - sphPoint = self.exposure.getWcs().pixelToSky(self.center) - cutout = self.exposure.getCutout(sphPoint, - geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - cd = packageAlerts.makeLocalTransformMatrix( - cutout.getWcs(), self.center, sphPoint) - self.assertFloatsAlmostEqual( - cd, - cutout.getWcs().getCdMatrix(), - rtol=1e-11, - atol=1e-11) - - def testStreamCcdDataToBytes(self): - """Test round tripping an CCDData cutout to bytes and back. - """ - packageAlerts = PackageAlertsTask() - - sphPoint = self.exposure.getWcs().pixelToSky(self.center) - cutout = self.exposure.getCutout(sphPoint, - geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - cutoutCcdData = CCDData( - data=cutout.getImage().array, - wcs=self.cutoutWcs, - unit="adu") - - cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData) - with io.BytesIO(cutoutBytes) as bytesIO: - cutoutFromBytes = CCDData.read(bytesIO, format="fits") - self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data) - - def testMakeAlertDict(self): - """Test stripping data from the various data products and into a - dictionary "alert". 
- """ - packageAlerts = PackageAlertsTask() - alertId = 1234 - - for srcIdx, diaSource in self.diaSources.iterrows(): - sphPoint = geom.SpherePoint(diaSource["ra"], - diaSource["dec"], - geom.degrees) - pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"]) - cutout = self.exposure.getCutout(sphPoint, - geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - ccdCutout = packageAlerts.createCcdDataCutout( - cutout, - sphPoint, - pixelPoint, - geom.Extent2I(self.cutoutSize, self.cutoutSize), - cutout.getPhotoCalib(), - 1234) - cutoutBytes = packageAlerts.streamCcdDataToBytes( - ccdCutout) - objSources = self.diaSourceHistory.loc[srcIdx[0]] - objForcedSources = self.diaForcedSources.loc[srcIdx[0]] - alert = packageAlerts.makeAlertDict( - alertId, - diaSource, - self.diaObjects.loc[srcIdx[0]], - objSources, - objForcedSources, - ccdCutout, - ccdCutout, - ccdCutout) - self.assertEqual(len(alert), 10) - - self.assertEqual(alert["alertId"], alertId) - self.assertEqual(alert["diaSource"], diaSource.to_dict()) - self.assertEqual(alert["cutoutDifference"], - cutoutBytes) - self.assertEqual(alert["cutoutScience"], - cutoutBytes) - self.assertEqual(alert["cutoutTemplate"], - cutoutBytes) - - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_empty_password(self): - """ Test that produceAlerts raises if the password is empty or missing. - """ - self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = "" - with self.assertRaisesRegex(ValueError, "Kafka password"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - del self.environ['AP_KAFKA_PRODUCER_PASSWORD'] - with self.assertRaisesRegex(ValueError, "Kafka password"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_empty_username(self): - """ Test that produceAlerts raises if the username is empty or missing. - """ - self.environ['AP_KAFKA_PRODUCER_USERNAME'] = "" - with self.assertRaisesRegex(ValueError, "Kafka username"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - del self.environ['AP_KAFKA_PRODUCER_USERNAME'] - with self.assertRaisesRegex(ValueError, "Kafka username"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_empty_server(self): - """ Test that produceAlerts raises if the server is empty or missing. - """ - self.environ['AP_KAFKA_SERVER'] = "" - with self.assertRaisesRegex(ValueError, "Kafka server"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - del self.environ['AP_KAFKA_SERVER'] - with self.assertRaisesRegex(ValueError, "Kafka server"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_empty_topic(self): - """ Test that produceAlerts raises if the topic is empty or missing. 
- """ - self.environ['AP_KAFKA_TOPIC'] = "" - with self.assertRaisesRegex(ValueError, "Kafka topic"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - del self.environ['AP_KAFKA_TOPIC'] - with self.assertRaisesRegex(ValueError, "Kafka topic"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - @patch('confluent_kafka.Producer') - @patch.object(PackageAlertsTask, '_server_check') - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_success(self, mock_server_check, mock_producer): - """ Test that produceAlerts calls the producer on all provided alerts - when the alerts are all under the batch size limit. - """ - packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) - alerts = [mock_alert(1), mock_alert(2)] - - # Create a variable and assign it an instance of the patched kafka producer - producer_instance = mock_producer.return_value - producer_instance.produce = Mock() - producer_instance.flush = Mock() - unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix - exposure_time = self.exposure.visitInfo.exposureTime - packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time) - - self.assertEqual(mock_server_check.call_count, 1) - self.assertEqual(producer_instance.produce.call_count, len(alerts)) - self.assertEqual(producer_instance.flush.call_count, len(alerts)+1) - - @patch('confluent_kafka.Producer') - @patch.object(PackageAlertsTask, '_server_check') - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_produceAlerts_one_failure(self, mock_server_check, mock_producer): - """ Test that produceAlerts correctly fails on one alert - and is writing the failure to disk. - """ - counter = 0 - - def mock_produce(*args, **kwargs): - nonlocal counter - counter += 1 - if counter == 2: - raise KafkaException - else: - return - - packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) - - patcher = patch("builtins.open") - patch_open = patcher.start() - alerts = [mock_alert(1), mock_alert(2), mock_alert(3)] - unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix - exposure_time = self.exposure.visitInfo.exposureTime - - producer_instance = mock_producer.return_value - producer_instance.produce = Mock(side_effect=mock_produce) - producer_instance.flush = Mock() - packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time) - - self.assertEqual(mock_server_check.call_count, 1) - self.assertEqual(producer_instance.produce.call_count, len(alerts)) - self.assertEqual(patch_open.call_count, 1) - self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0]) - # Because one produce raises, we call flush one fewer times than in the success - # test above. - self.assertEqual(producer_instance.flush.call_count, len(alerts)) - patcher.stop() - - @patch.object(PackageAlertsTask, '_server_check') - def testRun_without_produce(self, mock_server_check): - """Test the run method of package alerts with produce set to False and - doWriteAlerts set to true. 
- """ - packConfig = PackageAlertsConfig(doWriteAlerts=True) - with tempfile.TemporaryDirectory(prefix='alerts') as tempdir: - packConfig.alertWriteLocation = tempdir - packageAlerts = PackageAlertsTask(config=packConfig) - - packageAlerts.run(self.diaSources, - self.diaObjects, - self.diaSourceHistory, - self.diaForcedSources, - self.exposure, - self.exposure, - self.exposure) - - self.assertEqual(mock_server_check.call_count, 0) - - with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f: - writer_schema, data_stream = \ - packageAlerts.alertSchema.retrieve_alerts(f) - data = list(data_stream) - - self.assertEqual(len(data), len(self.diaSources)) - for idx, alert in enumerate(data): - for key, value in alert["diaSource"].items(): - if isinstance(value, float): - if np.isnan(self.diaSources.iloc[idx][key]): - self.assertTrue(np.isnan(value)) - else: - self.assertAlmostEqual( - 1 - value / self.diaSources.iloc[idx][key], - 0.) - else: - self.assertEqual(value, self.diaSources.iloc[idx][key]) - sphPoint = geom.SpherePoint(alert["diaSource"]["ra"], - alert["diaSource"]["dec"], - geom.degrees) - pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"]) - cutout = self.exposure.getCutout(sphPoint, - geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - ccdCutout = packageAlerts.createCcdDataCutout( - cutout, - sphPoint, - pixelPoint, - geom.Extent2I(self.cutoutSize, self.cutoutSize), - cutout.getPhotoCalib(), - 1234) - self.assertEqual(alert["cutoutDifference"], - packageAlerts.streamCcdDataToBytes(ccdCutout)) - - @patch.object(PackageAlertsTask, '_server_check') - def testRun_without_produce_use_averagePsf(self, mock_server_check): - """Test the run method of package alerts with produce set to False and - doWriteAlerts set to true. - """ - packConfig = PackageAlertsConfig(doWriteAlerts=True) - with tempfile.TemporaryDirectory(prefix='alerts') as tempdir: - packConfig.alertWriteLocation = tempdir - packConfig.useAveragePsf = True - packageAlerts = PackageAlertsTask(config=packConfig) - - packageAlerts.run(self.diaSources, - self.diaObjects, - self.diaSourceHistory, - self.diaForcedSources, - self.exposure, - self.exposure, - self.exposure) - - self.assertEqual(mock_server_check.call_count, 0) - - with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f: - writer_schema, data_stream = \ - packageAlerts.alertSchema.retrieve_alerts(f) - data = list(data_stream) - - self.assertEqual(len(data), len(self.diaSources)) - for idx, alert in enumerate(data): - for key, value in alert["diaSource"].items(): - if isinstance(value, float): - if np.isnan(self.diaSources.iloc[idx][key]): - self.assertTrue(np.isnan(value)) - else: - self.assertAlmostEqual( - 1 - value / self.diaSources.iloc[idx][key], - 0.) 
- else: - self.assertEqual(value, self.diaSources.iloc[idx][key]) - sphPoint = geom.SpherePoint(alert["diaSource"]["ra"], - alert["diaSource"]["dec"], - geom.degrees) - pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"]) - cutout = self.exposure.getCutout(sphPoint, - geom.Extent2I(self.cutoutSize, - self.cutoutSize)) - ccdCutout = packageAlerts.createCcdDataCutout( - cutout, - sphPoint, - pixelPoint, - geom.Extent2I(self.cutoutSize, self.cutoutSize), - cutout.getPhotoCalib(), - 1234) - self.assertEqual(alert["cutoutDifference"], - packageAlerts.streamCcdDataToBytes(ccdCutout)) - - @patch.object(PackageAlertsTask, 'produceAlerts') - @patch('confluent_kafka.Producer') - @patch.object(PackageAlertsTask, '_server_check') - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def testRun_with_produce(self, mock_produceAlerts, mock_server_check, mock_producer): - """Test that packageAlerts calls produceAlerts when doProduceAlerts - is set to True. - """ - packConfig = PackageAlertsConfig(doProduceAlerts=True) - packageAlerts = PackageAlertsTask(config=packConfig) - - packageAlerts.run(self.diaSources, - self.diaObjects, - self.diaSourceHistory, - self.diaForcedSources, - self.exposure, - self.exposure, - self.exposure) - self.assertEqual(mock_server_check.call_count, 1) - self.assertEqual(mock_produceAlerts.call_count, 1) - - def test_serialize_alert_round_trip(self): - """Test that values in the alert packet exactly round trip. - """ - packClass = PackageAlertsConfig() - packageAlerts = PackageAlertsTask(config=packClass) - - alert = mock_alert(1) - serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert) - deserialized = _deserialize_alert(serialized) - - for field in alert['diaSource']: - self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field]) - self.assertEqual(1, deserialized["alertId"]) - - @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled') - def test_server_check(self): - - with self.assertRaisesRegex(KafkaException, "_TRANSPORT"): - packConfig = PackageAlertsConfig(doProduceAlerts=True) - PackageAlertsTask(config=packConfig) - - -class MemoryTester(lsst.utils.tests.MemoryTestCase): - pass - - -def setup_module(module): - lsst.utils.tests.init() - - -if __name__ == "__main__": - lsst.utils.tests.init() - unittest.main()
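Reviewer's note: the projection and survey-geometry math in testApdb.py can be sanity-checked outside the stack. The following is a minimal standalone sketch (plain numpy, no LSST dependencies); the two helper functions are copied from the patch above, while the example values and assertions are illustrative additions, not part of the patch.

import numpy as np


def stereographicRaDec2XY(ra, dec):
    # Pole-centered stereographic projection (Calabretta & Greisen 2002, Eq. 56).
    r = 2*np.cos(dec)/(1 + np.sin(dec))
    return r*np.cos(ra), r*np.sin(ra)


def stereographicXY2RaDec(x, y):
    # Inverse of the projection above.
    r = np.sqrt(x**2 + y**2)
    return np.arctan2(y, x), np.pi/2 - 2*np.arctan(r/2)


# Round trip: projecting and deprojecting recovers the input coordinates.
ra, dec = np.radians(40.0), np.radians(35.0)
assert np.allclose(stereographicXY2RaDec(*stereographicRaDec2XY(ra, dec)), (ra, dec))

# Declination bound used by prepareSurvey for the small-survey branch:
# a polar cap of area A = 2*pi*(1 - sin(dec)) steradians should match
# the configured survey area.
survey_area = 20000*np.deg2rad(1)**2  # default 20000 deg**2, in steradians
declinationMax = np.arcsin(1 - survey_area/(2*np.pi))
assert np.isclose(2*np.pi*(1 - np.sin(declinationMax)), survey_area)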