diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py
index f337d225..b2306b55 100644
--- a/python/lsst/ap/association/__init__.py
+++ b/python/lsst/ap/association/__init__.py
@@ -30,3 +30,4 @@
from .version import *
from .mpSkyEphemerisQuery import *
from .ssSingleFrameAssociation import *
+from .testApdb import *
diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py
index 5856f77a..c62ea19c 100644
--- a/python/lsst/ap/association/diaPipe.py
+++ b/python/lsst/ap/association/diaPipe.py
@@ -789,7 +789,8 @@ def writeToApdb(self, updatedDiaObjects, associatedDiaSources, diaForcedSources)
diaForcedSourceStore)
self.log.info("APDB updated.")
- def testDataFrameIndex(self, df):
+ @staticmethod
+ def testDataFrameIndex(df):
"""Test the sorted DataFrame index for duplicates.
Wrapped as a separate function to allow for mocking of the this task
@@ -922,7 +923,8 @@ def mergeCatalogs(self, originalCatalog, newCatalog, catalogName):
mergedCatalog = pd.concat([originalCatalog], sort=True)
return mergedCatalog.loc[:, originalCatalog.columns]
- def updateObjectTable(self, diaObjects, diaSources):
+ @staticmethod
+ def updateObjectTable(diaObjects, diaSources):
"""Update the diaObject table with the new diaSource records.
Parameters
diff --git a/python/lsst/ap/association/testApdb.py b/python/lsst/ap/association/testApdb.py
new file mode 100644
index 00000000..feccc3f2
--- /dev/null
+++ b/python/lsst/ap/association/testApdb.py
@@ -0,0 +1,778 @@
+#
+# LSST Data Management System
+# Copyright 2008-2016 AURA/LSST.
+#
+# This product includes software developed by the
+# LSST Project (http://www.lsst.org/).
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the LSST License Statement and
+# the GNU General Public License along with this program. If not,
+# see <http://www.lsstcorp.org/LegalNotices/>.
+#
+
+"""Standalone pipelinetask to populate an APDB with simulated data.
+"""
+
+__all__ = ("TestApdbConfig",
+ "TestApdbTask",
+ )
+
+
+import numpy as np
+import pandas as pd
+import time
+
+from lsst.daf.base import DateTime
+import lsst.dax.apdb as daxApdb
+import lsst.geom
+from lsst.meas.base import DetectorVisitIdGeneratorConfig, IdGenerator
+import lsst.pex.config as pexConfig
+import lsst.pipe.base as pipeBase
+from lsst.ap.association.loadDiaCatalogs import LoadDiaCatalogsTask
+from lsst.ap.association.diaPipe import DiaPipelineTask
+from lsst.ap.association.utils import (
+ convertTableToSdmSchema,
+ readSchemaFromApdb,
+ column_dtype,
+ make_empty_catalog,
+ dropEmptyColumns,
+)
+import lsst.sphgeom
+from lsst.utils.timer import timeMethod
+
+
+class TestApdbConnections(
+ pipeBase.PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector")):
+ """Butler connections for TestApdbTask.
+ """
+
+ apdbTestMarker = pipeBase.connectionTypes.Output(
+ doc="Marker dataset storing the configuration of the Apdb for each "
+ "visit/detector. Used to signal the completion of the pipeline.",
+ name="apdbTest_marker",
+ storageClass="Config",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+
+class TestApdbConfig(pipeBase.PipelineTaskConfig,
+ pipelineConnections=TestApdbConnections):
+ """Config for TestApdbTask.
+ """
+ apdb_config_url = pexConfig.Field(
+ dtype=str,
+ default=None,
+ optional=False,
+ doc="A config file specifying the APDB and its connection parameters, "
+ "typically written by the apdb-cli command-line utility. "
+ "The database must already be initialized.",
+ )
+ survey_area = pexConfig.Field(
+ dtype=float,
+ default=20000,
+ doc="Area (in degrees) of the simulated survey",
+ )
+ fov = pexConfig.Field(
+ dtype=float,
+ default=9.6,
+ doc="Field of view of the camera, in square degrees.",
+ )
+ stellar_density = pexConfig.Field(
+ dtype=float,
+ default=1750,
+ doc="Average number of real transient and variable objects"
+ " detected per square degree. For this simulation, these will"
+ " always be detected, and detected in the same location."
+ "The default is chosen such that:"
+ " Density x Rubin Fov (9.6) x # of visits per night (~600) ~ 10M",
+ )
+ false_positive_ratio = pexConfig.Field(
+ dtype=float,
+ default=2,
+ doc="Average ratio of false detections to real sources."
+ "These will be detected in random locations.",
+ )
+ false_positive_variability = pexConfig.Field(
+ dtype=float,
+ default=100,
+ doc="Parameter characterizing the variability in the rate of false"
+ " positives.",
+ )
+ sky_seed = pexConfig.Field(
+ dtype=int,
+ default=37,
+ doc="Seed used to simulate the real sources.",
+ )
+ historyThreshold = pexConfig.Field(
+ dtype=int,
+ doc="Minimum number of detections of a diaObject required "
+ "to run forced photometry. Set to 1 to include all diaObjects.",
+ default=2,
+ )
+ objId_start = pexConfig.Field(
+ dtype=int,
+ default=1,
+ doc="Starting diaObject ID number of real objects.",
+ )
+ maximum_table_length = pexConfig.Field(
+ dtype=int,
+ default=65535,
+ doc="Maximum length of tables allowed to be written in one operation"
+ " to the Cassandra APDB",
+ )
+ raise_on_associated_fakes = pexConfig.Field(
+ dtype=bool,
+ default=True,
+ doc="Raise a RuntimeError if any fake sources are associated"
+ " with existing DiaObjects. This is likely due to restarting"
+ " the simulation after partial writes to the APDB",
+ )
+
+ idGenerator = DetectorVisitIdGeneratorConfig.make_field()
+ idGeneratorFakes = DetectorVisitIdGeneratorConfig.make_field()
+ idGeneratorForced = DetectorVisitIdGeneratorConfig.make_field()
+
+ def setDefaults(self):
+ self.idGenerator = DetectorVisitIdGeneratorConfig(release_id=0, n_releases=3)
+ self.idGeneratorFakes = DetectorVisitIdGeneratorConfig(release_id=1, n_releases=3)
+ self.idGeneratorForced = DetectorVisitIdGeneratorConfig(release_id=2, n_releases=3)
+
+
+class TestApdbTask(LoadDiaCatalogsTask):
+ """Task for loading, associating and storing Difference Image Analysis
+ (DIA) Objects and Sources.
+ """
+ ConfigClass = TestApdbConfig
+ _DefaultName = "apdbTest"
+
+ def __init__(self, initInputs=None, **kwargs):
+ super().__init__(**kwargs)
+ self.apdb = daxApdb.Apdb.from_uri(self.config.apdb_config_url)
+ self.schema = readSchemaFromApdb(self.apdb)
+ self.prepareSurvey()
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ inputs["visit"] = butlerQC.quantum.dataId["visit"]
+ inputs["detector"] = butlerQC.quantum.dataId["detector"]
+ inputs["idGenerator"] = self.config.idGenerator.apply(butlerQC.quantum.dataId)
+ inputs["idGeneratorFakes"] = self.config.idGeneratorFakes.apply(butlerQC.quantum.dataId)
+ inputs["idGeneratorForced"] = self.config.idGeneratorForced.apply(butlerQC.quantum.dataId)
+ self.run(**inputs)
+        # Intentionally put an empty Struct: we don't want this task to
+        # write anything to the Butler.
+ butlerQC.put(pipeBase.Struct(), outputRefs)
+
+ def prepareSurvey(self):
+ """Prepare survey parameters common to all observations
+ """
+ deg2ToSteradian = np.deg2rad(1)**2 # Conversion factor from square degrees to steradians
+ sphereDegrees = 4*np.pi/deg2ToSteradian # area of a unit sphere in square degrees
+ sphereFraction = self.config.survey_area/sphereDegrees
+ survey_area = self.config.survey_area*deg2ToSteradian
+        # Calculate the angle from the pole needed to cover `survey_area`
+        # square degrees of the sphere, splitting on hemisphere to avoid
+        # the branch cut of the arcsin function at the equator.
+ if sphereFraction < 0.5:
+ declinationMax = np.arcsin(1 - survey_area/(2*np.pi))
+ else:
+ declinationMax = np.pi - np.arcsin(1 - (4*np.pi - survey_area)/(2*np.pi))
+ self.radiusMax = stereographicRaDec2XY(0, declinationMax)[0]
+ self.nReal = int(self.config.survey_area*self.config.stellar_density)
+ # Radius of the focal plane, in radians
+ self.fpRadius = np.radians(np.sqrt(self.config.fov/np.pi))
+
+ def run(self, visit, detector,
+ idGenerator=IdGenerator(),
+ idGeneratorFakes=IdGenerator(),
+ idGeneratorForced=IdGenerator()):
+ """Generate a full focal plane simulation of real and fake sources,
+ run association and write to the APDB
+
+ Parameters
+ ----------
+ visit : `int`
+ Visit ID.
+ detector : `int`
+ Detector number, used for the ID generator.
+ Note that each instance simulates an entire focal plane.
+ The detector dimension is used as a convenience to allow re-using
+ code and infrastructure.
+        idGenerator : `lsst.meas.base.IdGenerator`, optional
+            ID generator used for "real" sources.
+        idGeneratorFakes : `lsst.meas.base.IdGenerator`, optional
+            ID generator for random "fake" sources.
+            Separate ID generators are needed to avoid overflows between
+            visits, since an entire focal plane is simulated per instance.
+        idGeneratorForced : `lsst.meas.base.IdGenerator`, optional
+            ID generator for forced sources at existing diaObject locations.
+
+        Returns
+        -------
+        result : `lsst.pipe.base.Struct`
+            Results struct with component ``apdbTestMarker``, an empty config
+            used to mark completion.
+ """
+ t_sim0 = time.time()
+ idGen = idGenerator.make_table_id_factory()
+ idGenFakes = idGeneratorFakes.make_table_id_factory()
+ idGenForced = idGeneratorForced.make_table_id_factory()
+ seed = int(visit*1000 + detector)
+ x, y = randomCircleXY(self.radiusMax, seed)
+ ra, dec = stereographicXY2RaDec(x, y)
+ if ra < 0:
+ ra += 2*np.pi
+ delta0, _ = stereographicRaDec2XY(0, dec - self.fpRadius)
+ delta1, _ = stereographicRaDec2XY(0, dec + self.fpRadius)
+ delta = abs(delta1 - delta0)/2.
+ x0 = x - delta
+ x1 = x + delta
+ y0 = y - delta
+ y1 = y + delta
+ ra0, dec0 = stereographicXY2RaDec(x0, y0)
+ ra1, dec1 = stereographicXY2RaDec(x1, y1)
+ region = lsst.sphgeom.ConvexPolygon([lsst.geom.SpherePoint(ra0, dec0, lsst.geom.radians).getVector(),
+ lsst.geom.SpherePoint(ra0, dec1, lsst.geom.radians).getVector(),
+ lsst.geom.SpherePoint(ra1, dec1, lsst.geom.radians).getVector(),
+ lsst.geom.SpherePoint(ra1, dec0, lsst.geom.radians).getVector()]
+ )
+ self.log.info(f"Simulating {self.nReal} sources around RA={np.degrees(ra)}, Dec={np.degrees(dec)}")
+ xS, yS = randomCircleXY(self.radiusMax, self.config.sky_seed, n=self.nReal)
+ diaObjIds = np.arange(self.config.objId_start, self.nReal + self.config.objId_start)
+ inds = (xS > x0) & (xS < x1) & (yS > y0) & (yS < y1)
+ nSim = np.sum(inds)
+ diaObjIds = diaObjIds[inds]
+ self.log.info(f"{nSim} sources made the spatial cut")
+ diaSourcesReal = self.createDiaSources(*stereographicXY2RaDec(xS[inds], yS[inds]),
+ idGenerator=idGen,
+ visit=visit,
+ detector=detector,
+ diaObjectIds=diaObjIds)
+
+ rng = np.random.RandomState(seed)
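+        # Draw the number of false detections from a gamma distribution with
+        # mean nSim*false_positive_ratio; false_positive_variability sets how
+        # strongly the draw fluctuates about that mean.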
+ scale = nSim*self.config.false_positive_ratio/self.config.false_positive_variability
+ nBogus = int(rng.standard_gamma(scale)*self.config.false_positive_variability)
+
+ diaSourcesBogus = self.createDiaSources(*self.generateFalseDetections(x0, x1, y0, y1, nBogus, seed),
+ visit=visit,
+ detector=detector,
+ idGenerator=idGenFakes)
+ diaSourcesRaw = pd.concat([diaSourcesReal, diaSourcesBogus])
+
+ diaSources = convertTableToSdmSchema(self.schema, diaSourcesRaw, tableName="DiaSource")
+ t_sim1 = time.time()
+ self.log.info("Simulated source timing: %.2fs", t_sim1 - t_sim0)
+
+ t_load0 = time.time()
+ diaObjects = self.loadDiaObjects(region.getBoundingCircle(), self.schema)
+ t_load1 = time.time()
+ self.log.info("diaObject load timing: %.2fs", t_load1 - t_load0)
+
+ if diaObjects.empty:
+ self.log.info("diaObjects contain 0 diaSources (empty)")
+ else:
+ nDiaSources = diaObjects.nDiaSources
+ self.log.info(f"diaObjects contain {np.min(nDiaSources)} to {np.max(nDiaSources)} diaSources")
+
+ # Associate DiaSources with DiaObjects
+ associatedDiaSources, newDiaObjects, associatedDiaObjects = self.associateDiaSources(diaSources,
+ diaObjects,
+ diaSourcesReal,
+ diaSourcesBogus,
+ )
+ # Merge new and preloaded diaObjects
+ mergedDiaObjects = self.mergeAssociatedCatalogs(associatedDiaObjects, newDiaObjects)
+
+ nObj = len(mergedDiaObjects)
+ nSrc = len(associatedDiaSources)
+ dateTime = DateTime.now().toAstropy()
+ ind = 0
+ t_write0 = time.time()
+ # Note that nObj must always be equal to or greater than nSrc
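+        # (in this simulation each associated diaSource corresponds to
+        # exactly one merged diaObject). E.g., with
+        # maximum_table_length=65535 and nObj=150000, this loop writes
+        # diaObject chunks of 65535, 65535, and 18930 rows.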
+ for start in range(0, nObj, self.config.maximum_table_length):
+ end = min(start + self.config.maximum_table_length, nObj)
+ diaObjectsChunk = mergedDiaObjects.iloc[start:end]
+ nObjChunk = len(diaObjectsChunk)
+ self.log.info(f"Writing diaObject chunk {ind} of length {nObjChunk} (of {nObj}) to the APDB")
+ srcEnd = min(start + self.config.maximum_table_length, nSrc)
+ if srcEnd <= start:
+ finalDiaSources = None
+ else:
+ diaSourcesChunk = associatedDiaSources.iloc[start:srcEnd]
+ finalDiaSources = convertTableToSdmSchema(self.schema, diaSourcesChunk, tableName="DiaSource")
+ nSrcChunk = len(diaSourcesChunk)
+ self.log.info(f"Writing diaSource chunk {ind} of length {nSrcChunk} (of {nSrc}) to the APDB")
+ diaForcedSources = self.runForcedMeasurement(diaObjectsChunk, idGenForced, visit, detector)
+
+ finalDiaObjects = convertTableToSdmSchema(self.schema, diaObjectsChunk, tableName="DiaObject")
+ finalDiaForcedSources = convertTableToSdmSchema(self.schema, diaForcedSources,
+ tableName="DiaForcedSource")
+ self.log.info(f"Writing forced source chunk {ind} of length"
+ f" {len(finalDiaForcedSources)} to the APDB")
+ self.writeToApdb(finalDiaObjects, finalDiaSources, finalDiaForcedSources, dateTime)
+ ind += 1
+ t_write1 = time.time()
+ self.log.info("APDB write timing: %.2fs", t_write1 - t_write0)
+ marker = pexConfig.Config()
+ return pipeBase.Struct(apdbTestMarker=marker)
+
+ def createDiaSources(self, raVals, decVals, idGenerator, visit, detector, diaObjectIds=None):
+ """Create diaSources with the supplied coordinates.
+
+ Parameters
+ ----------
+ raVals : `numpy.ndarray`
+ Right Ascension (RA) values of the desired sources.
+ decVals : `numpy.ndarray`
+ Declination (Dec) values of the desired sources.
+        idGenerator : `lsst.meas.base.IdGenerator`
+            ID generator used for the source IDs.
+        visit : `int`
+            Visit ID.
+        detector : `int`
+            Detector number.
+        diaObjectIds : `numpy.ndarray`, optional
+            diaObjectIds to assign to the sources. If not supplied, the object
+            ID is set to the diaSourceId.
+
+ Returns
+ -------
+ diaSources : `pandas.DataFrame`
+ Table of sources with the supplied coordinates and IDs assigned.
+ """
+ n = len(raVals)
+ raVals[raVals < 0] += 2*np.pi
+ ra = pd.Series(np.degrees(raVals), name='ra')
+ dec = pd.Series(np.degrees(decVals), name='dec')
+ diaSourceId = pd.Series([idGenerator() for i in range(n)], name='diaSourceId')
+ if diaObjectIds is None:
+ diaObjectId = pd.Series(diaSourceId, name='diaObjectId')
+ else:
+ diaObjectId = pd.Series(diaObjectIds, name='diaObjectId')
+ baseSources = pd.concat([diaSourceId, diaObjectId, ra, dec], axis=1)
+ baseSources["visit"] = visit
+ baseSources["detector"] = detector
+
+ baseSources["band"] = 'A'
+ preserveColumns = ["diaSourceId", "diaObjectId", "ra", "dec", "visit", "detector", "band"]
+
+ diaSources = fillRandomTable(self.schema, baseSources,
+ tableName="DiaSource",
+ preserveColumns=preserveColumns)
+        # Do *not* set the index to the diaSourceId. It will need to be
+        # diaObjectId for matching later.
+ # diaSources.set_index("diaSourceId", inplace=True)
+ return diaSources
+
+ def generateFalseDetections(self, x0, x1, y0, y1, nBogus, seed):
+ """Generate random coordinates within the ranges provided.
+
+ Parameters
+ ----------
+        x0 : `float`
+            Minimum projected x coordinate.
+        x1 : `float`
+            Maximum projected x coordinate.
+        y0 : `float`
+            Minimum projected y coordinate.
+        y1 : `float`
+            Maximum projected y coordinate.
+        nBogus : `int`
+            Number of "fake" sources within the region.
+        seed : `int`
+            Seed value for the random number generator to provide repeatable
+            results.
+
+ Returns
+ -------
+ ra, dec : `numpy.ndarray`
+ Coordinates matching the randomly generated locations.
+ """
+ self.log.info(f"Simulating {nBogus} false detections within region.")
+ rng = np.random.RandomState(seed)
+ x = rng.random_sample(nBogus)*(x1 - x0) + x0
+ y = rng.random_sample(nBogus)*(y1 - y0) + y0
+ return stereographicXY2RaDec(x, y)
+
+ def simpleMatch(self, diaSourceTable, diaObjects, updateNdiaSources=True):
+ """Match by pre-defined ID.
+
+ Parameters
+ ----------
+ diaSourceTable : `pandas.DataFrame`
+ New DIASources to be associated with existing DIAObjects.
+        diaObjects : `pandas.DataFrame`
+            Existing diaObjects from the Apdb.
+        updateNdiaSources : `bool`, optional
+            If `True`, increment ``nDiaSources`` for each matched diaObject.
+
+ Returns
+ -------
+ result : `lsst.pipe.base.Struct`
+ Results struct with components.
+
+ - ``matchedDiaSources`` : DiaSources that were matched. Matched
+ Sources have their diaObjectId updated and set to the id of the
+ diaObject they were matched to. (`pandas.DataFrame`)
+ - ``unAssocDiaSources`` : DiaSources that were not matched.
+ Unassociated sources have their diaObject set to 0 as they
+ were not associated with any existing DiaObjects.
+ (`pandas.DataFrame`)
+ - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
+ matched to new DiaSources. (`int`)
+            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
+              not matched to a new DiaSource. (`int`)
+            - ``matchedDiaObjects`` : DiaObjects that were matched to new
+              DiaSources, with ``nDiaSources`` incremented if requested.
+              (`pandas.DataFrame`)
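+
+        Notes
+        -----
+        Matching is done purely by ``diaObjectId``: the source table is
+        indexed by object ID and intersected with the diaObject table index,
+        so no spatial matching is performed.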
+ """
+        # Index the source table by diaObjectId for matching; the returned
+        # catalogs have their index reset so diaObjectId remains a column.
+ diaSourceTable.set_index("diaObjectId", inplace=True)
+ matchedDiaSources = diaSourceTable.loc[diaSourceTable.index.intersection(diaObjects.index)]
+ unAssocDiaSources = diaSourceTable.loc[diaSourceTable.index.difference(diaObjects.index)]
+
+ matchedDiaObjectInds = diaObjects.index.intersection(diaSourceTable.index)
+ if not diaObjects.empty and updateNdiaSources:
+ nDiaSources = diaObjects.loc[matchedDiaObjectInds, "nDiaSources"] + 1
+ diaObjects.loc[matchedDiaObjectInds, "nDiaSources"] = nDiaSources
+
+ # Reset the index of the diaSource dataframes, so diaObjectId is still
+ # a valid column
+ return pipeBase.Struct(unAssocDiaSources=unAssocDiaSources.reset_index(),
+ matchedDiaSources=matchedDiaSources.reset_index(),
+ nUpdatedDiaObjects=len(matchedDiaSources),
+ nUnassociatedDiaObjects=len(unAssocDiaSources),
+ matchedDiaObjects=diaObjects.loc[matchedDiaObjectInds],
+ )
+
+ def associateDiaSources(self, diaSourceTable, diaObjects, diaSourcesReal, diaSourcesBogus):
+ """Associate DiaSources with DiaObjects.
+
+ Parameters
+ ----------
+ diaSourceTable : `pandas.DataFrame`
+ Newly detected DiaSources.
+        diaObjects : `pandas.DataFrame`
+            Table of DiaObjects from preloaded DiaObjects.
+        diaSourcesReal : `pandas.DataFrame`
+            The simulated "real" DiaSources, used to verify the association.
+        diaSourcesBogus : `pandas.DataFrame`
+            The simulated false detections, used to verify the association.
+
+ Returns
+ -------
+        associatedDiaSources : `pandas.DataFrame`
+            DiaSources associated with DiaObjects.
+        newDiaObjects : `pandas.DataFrame`
+            Table of new DiaObjects after association.
+        associatedDiaObjects : `pandas.DataFrame`
+            Table of preloaded DiaObjects matched to new DiaSources.
+ """
+ # Associate new DiaSources with existing DiaObjects.
+ assocResults = self.simpleMatch(diaSourceTable, diaObjects, updateNdiaSources=True)
+ assocReal = self.simpleMatch(diaSourcesReal, diaObjects, updateNdiaSources=False)
+ assocBogus = self.simpleMatch(diaSourcesBogus, diaObjects, updateNdiaSources=False)
+
+ toAssociate = []
+
+ # Create new DiaObjects from unassociated diaSources.
+ createResults = self.createNewDiaObjects(assocResults.unAssocDiaSources)
+ if len(assocResults.matchedDiaSources) > 0:
+ toAssociate.append(assocResults.matchedDiaSources)
+ toAssociate.append(createResults.diaSources)
+ associatedDiaSources = pd.concat(toAssociate)
+
+ self.log.info("%i updated and %i unassociated diaSources. Creating %i new diaObjects",
+ assocResults.nUpdatedDiaObjects,
+ assocResults.nUnassociatedDiaObjects,
+ createResults.nNewDiaObjects,
+ )
+ self.log.info(f"{assocReal.nUpdatedDiaObjects} real sources associated"
+ f" and {assocReal.nUnassociatedDiaObjects} not associated")
+ self.log.info(f"{assocBogus.nUpdatedDiaObjects} fake sources associated"
+ f" and {assocBogus.nUnassociatedDiaObjects} not associated")
+
+        if self.config.raise_on_associated_fakes and assocBogus.nUpdatedDiaObjects > 0:
+ raise RuntimeError("Fake sources were associated with real DiaObjects."
+ " This is likely due to restarting the simulation"
+ " after partial writes to the APDB")
+ # Index the DiaSource catalog for this visit after all associations
+ # have been made.
+ associatedDiaSources.set_index(["diaObjectId",
+ "band",
+ "diaSourceId"],
+ drop=False,
+ inplace=True)
+ return (associatedDiaSources, createResults.newDiaObjects, assocResults.matchedDiaObjects)
+
+ def createNewDiaObjects(self, unAssocDiaSources):
+ """Loop through the set of DiaSources and create new DiaObjects
+ for unassociated DiaSources.
+
+ Parameters
+ ----------
+ unAssocDiaSources : `pandas.DataFrame`
+ Set of DiaSources to create new DiaObjects from.
+
+ Returns
+ -------
+ results : `lsst.pipe.base.Struct`
+ Results struct containing:
+
+ - diaSources : `pandas.DataFrame`
+ DiaSource catalog with updated DiaObject ids.
+ - newDiaObjects : `pandas.DataFrame`
+ Newly created DiaObjects from the unassociated DiaSources.
+ - nNewDiaObjects : `int`
+ Number of newly created diaObjects.
+ """
+ if len(unAssocDiaSources) == 0:
+ newDiaObjects = make_empty_catalog(self.schema, tableName="DiaObject")
+ else:
+ # Do *not* set the diaObjectId to the diaSourceId.
+ # For this simulation we are using custom diaObjectIds,
+ # and need to preserve them.
+ # unAssocDiaSources["diaObjectId"] = unAssocDiaSources["diaSourceId"]
+ preserveColumns = ["diaObjectId", "nearbyLowzGal"]
+ # Needs to be set for correct database formatting
+ unAssocDiaSources["nearbyLowzGal"] = None
+
+            # Fill the remaining columns of each new diaObject with random
+            # data
+ newDiaObjects = fillRandomTable(self.schema, unAssocDiaSources,
+ tableName="DiaObject",
+ preserveColumns=preserveColumns)
+        newDiaObjects["nDiaSources"] = 1
+ return pipeBase.Struct(diaSources=unAssocDiaSources,
+ newDiaObjects=newDiaObjects,
+ nNewDiaObjects=len(newDiaObjects))
+
+ def mergeAssociatedCatalogs(self, diaObjects, newDiaObjects):
+ """Merge the associated diaObjects to their previous history.
+ Also update the index of associatedDiaSources in place.
+
+ Parameters
+ ----------
+ diaObjects : `pandas.DataFrame`
+ Table of DiaObjects from preloaded DiaObjects.
+ newDiaObjects : `pandas.DataFrame`
+ Table of new DiaObjects after association.
+
+ Returns
+ -------
+ mergedDiaObjects : `pandas.DataFrame`
+ Table of new DiaObjects merged with their history.
+
+ Raises
+ ------
+ RuntimeError
+ Raised if duplicate DiaObjects are found.
+ """
+
+ # Append new DiaObjects to their previous history.
+ # Do not modify diaSources
+ newDiaObjects.set_index("diaObjectId", drop=False, inplace=True)
+ if diaObjects.empty:
+ mergedDiaObjects = newDiaObjects
+ elif not newDiaObjects.empty:
+ mergedDiaObjects = pd.concat([diaObjects, newDiaObjects], sort=True)
+ else:
+ mergedDiaObjects = diaObjects
+ if DiaPipelineTask.testDataFrameIndex(mergedDiaObjects):
+ raise RuntimeError("Duplicate DiaObjects created after association.")
+ return mergedDiaObjects
+
+ def runForcedMeasurement(self, diaObjects, idGenerator, visit, detector):
+ """Forced Source Measurement
+
+ Forced photometry on the difference and calibrated exposures using the
+ new and updated DiaObject locations.
+
+ Parameters
+ ----------
+ diaObjects : `pandas.DataFrame`
+ Catalog of DiaObjects.
+ idGenerator : `lsst.meas.base.IdGenerator`
+ Object that generates source IDs and random number generator seeds.
+ visit : `int`
+ Visit ID.
+ detector : `int`
+ Detector number, used for the ID generator.
+ Note that each instance simulates an entire focal plane.
+ The detector dimension is used as a convenience to allow re-using
+ code and infrastructure.
+
+ Returns
+ -------
+        diaForcedSources : `pandas.DataFrame`
+            Catalog of simulated forced photometry fluxes at the DiaObject
+            locations.
+ """
+ # Restrict forced source measurement to objects with sufficient history to be reliable.
+ objectTable = diaObjects.query(f'nDiaSources >= {self.config.historyThreshold}')
+ preserveColumns = ["diaObjectId", "ra", "dec"]
+ baseForcedSources = objectTable[preserveColumns].copy()
+ baseForcedSources["visit"] = visit
+ preserveColumns.append("visit")
+ baseForcedSources["detector"] = detector
+ preserveColumns.append("detector")
+ baseForcedSources["band"] = 'A'
+ preserveColumns.append("band")
+ baseForcedSources["diaForcedSourceId"] = [idGenerator() for i in range(len(baseForcedSources))]
+ preserveColumns.append("diaForcedSourceId")
+ # Fill the forced sources for each diaObject with random data
+ diaForcedSources = fillRandomTable(self.schema, baseForcedSources,
+ tableName="DiaForcedSource",
+ preserveColumns=preserveColumns)
+ self.log.info(f"Updating {len(diaForcedSources)} diaForcedSources in the APDB")
+ return diaForcedSources
+
+ @timeMethod
+ def writeToApdb(self, updatedDiaObjects, associatedDiaSources, diaForcedSources, dateTime):
+ """Write to the Alert Production Database (Apdb).
+
+ Store DiaSources, updated DiaObjects, and DiaForcedSources in the
+ Alert Production Database (Apdb).
+
+ Parameters
+ ----------
+ updatedDiaObjects : `pandas.DataFrame`
+ Catalog of updated DiaObjects.
+ associatedDiaSources : `pandas.DataFrame`
+ Associated DiaSources with DiaObjects.
+        diaForcedSources : `pandas.DataFrame`
+            Catalog of simulated forced photometry fluxes at the DiaObject
+            locations.
+        dateTime : `astropy.time.Time`
+            Time of the observation, passed to `Apdb.store` as the visit
+            time.
+ """
+ # Store DiaSources, updated DiaObjects, and DiaForcedSources in the
+ # Apdb.
+ # Drop empty columns that are nullable in the APDB.
+ diaObjectStore = dropEmptyColumns(self.schema, updatedDiaObjects, tableName="DiaObject")
+ if associatedDiaSources is None:
+ diaSourceStore = None
+ else:
+ diaSourceStore = dropEmptyColumns(self.schema, associatedDiaSources, tableName="DiaSource")
+ diaForcedSourceStore = dropEmptyColumns(self.schema, diaForcedSources, tableName="DiaForcedSource")
+ self.apdb.store(
+ dateTime,
+ diaObjectStore,
+ diaSourceStore,
+ diaForcedSourceStore)
+ self.log.info("APDB updated.")
+
+
+def randomCircleXY(radius, seed, n=None):
+ """Generate random x, y coordinates within a circular region
+
+ Parameters
+ ----------
+ radius : `float`
+ Radius of the circular region.
+ seed : `int`
+ Seed for the random number generator.
+ n : `int`, optional
+ Number of points to generate.
+
+ Returns
+ -------
+ x, y : `float`, or `numpy.ndarray`
+ Random coordinates inside the region.
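+
+    Examples
+    --------
+    A minimal sketch: all draws land within the requested radius (the seed
+    value here is arbitrary).
+
+    >>> import numpy as np
+    >>> x, y = randomCircleXY(2.0, seed=42, n=100)
+    >>> bool(np.all(np.hypot(x, y) <= 2.0))
+    True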
+ """
+    # Draw locations uniformly over the disk in x, y by sampling r**2 and
+    # the angle phi uniformly in polar coordinates.
+ rng = np.random.RandomState(seed)
+ r2 = rng.random_sample(n)*radius**2
+ phi = rng.random_sample(n)*2*np.pi
+ x = np.sqrt(r2)*np.cos(phi)
+ y = np.sqrt(r2)*np.sin(phi)
+ return (x, y)
+
+
+def stereographicXY2RaDec(x, y):
+ """Convert from a grid-like stereographic projection to RA and Dec
+
+ Notes
+ -----
+ A stereographic projection centered on the pole
+ From Eq 56 of Calabretta and Greisen (2002)
+ "Representations of celestial coordinates in FITS"
+
+ Parameters
+ ----------
+    x : `float`
+        Column coordinate of the stereographic grid.
+    y : `float`
+        Row coordinate of the stereographic grid.
+
+ Returns
+ -------
+    ra, dec : `float`
+        Right Ascension and Declination, in radians.
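+
+    Examples
+    --------
+    The origin of the projection maps to the pole (a minimal check):
+
+    >>> ra, dec = stereographicXY2RaDec(0.0, 0.0)
+    >>> (float(ra), round(float(dec), 6))
+    (0.0, 1.570796)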
+ """
+ r = np.sqrt(x**2 + y**2)
+ dec = np.pi/2 - 2*np.arctan(r/2)
+ ra = np.arctan2(y, x)
+ return (ra, dec)
+
+
+def stereographicRaDec2XY(ra, dec):
+ """Convert from Ra, Dec to a grid-like stereographic projection
+
+ Notes
+ -----
+ A stereographic projection centered on the pole
+ From Eq 56 of Calabretta and Greisen (2002)
+ "Representations of celestial coordinates in FITS"
+
+ Parameters
+ ----------
+ ra : `float`
+ Right Ascension (RA) in radians.
+ dec : `float`
+ Declination (Dec) in radians.
+
+ Returns
+ -------
+ x, y : `float`
+        Grid coordinates of the stereographic projection.
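+
+    Examples
+    --------
+    A round trip through the projection recovers the input coordinates
+    (a minimal sketch):
+
+    >>> x, y = stereographicRaDec2XY(1.0, 0.5)
+    >>> ra, dec = stereographicXY2RaDec(x, y)
+    >>> (round(float(ra), 10), round(float(dec), 10))
+    (1.0, 0.5)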
+ """
+ r = 2*np.cos(dec)/(1 + np.sin(dec))
+ x = r*np.cos(ra)
+ y = r*np.sin(ra)
+ return (x, y)
+
+
+def fillRandomTable(apdbSchema, sourceTable, tableName, preserveColumns=None):
+ """Force a table to conform to the schema defined by the APDB.
+
+ This method uses the table definitions in ``sdm_schemas`` to
+ load the schema of the APDB, and does not actually connect to the APDB.
+
+ Parameters
+ ----------
+ apdbSchema : `dict` [`str`, `lsst.dax.apdb.schema_model.Table`]
+ Schema from ``sdm_schemas`` containing the table definition to use.
+ sourceTable : `pandas.DataFrame`
+ The input table to convert.
+ tableName : `str`
+ Name of the table in the schema to use.
+ preserveColumns : `list` of `str`, optional
+ List of columns to copy from the input sourceTable.
+
+ Returns
+ -------
+ `pandas.DataFrame`
+ A table with the correct schema for the APDB and data copied from
+ the input ``sourceTable``.
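+
+    Examples
+    --------
+    A minimal sketch, assuming an initialized APDB described by a
+    (hypothetical) config file ``apdb_config.py``::
+
+        apdb = daxApdb.Apdb.from_uri("apdb_config.py")
+        schema = readSchemaFromApdb(apdb)
+        base = pd.DataFrame({"diaObjectId": [1, 2],
+                             "ra": [10.0, 11.0],
+                             "dec": [-5.0, -5.5]})
+        objects = fillRandomTable(schema, base, tableName="DiaObject",
+                                  preserveColumns=["diaObjectId", "ra", "dec"])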
+ """
+ table = apdbSchema[tableName]
+
+ data = {}
+ nSrc = len(sourceTable)
+ rng = np.random.default_rng()
+ if preserveColumns is None:
+ preserveColumns = []
+
+    for columnDef in table.columns:
+        if columnDef.name in preserveColumns:
+            # Preserved columns must be present in the input table.
+            data[columnDef.name] = sourceTable[columnDef.name]
+        dtype = column_dtype(columnDef.datatype)
+        if columnDef.name in sourceTable.columns:
+            # Columns present in the input table (including the preserved
+            # ones) are copied and cast to the schema dtype.
+            data[columnDef.name] = pd.Series(sourceTable[columnDef.name], dtype=dtype)
+        else:
+            # Fill the remaining schema columns with random values, falling
+            # back to zeros for dtypes that cannot be cast from a float.
+            try:
+                dataInit = rng.random(nSrc).astype(dtype)
+            except (TypeError, ValueError):
+                dataInit = np.zeros(nSrc, dtype=dtype)
+            data[columnDef.name] = pd.Series(dataInit, index=sourceTable.index)
+
+ df = convertTableToSdmSchema(apdbSchema, pd.DataFrame(data), tableName=tableName)
+ return df
diff --git a/tests/test_packageAlerts.py b/tests/test_packageAlerts.py
deleted file mode 100644
index a77f5187..00000000
--- a/tests/test_packageAlerts.py
+++ /dev/null
@@ -1,654 +0,0 @@
-# This file is part of ap_association.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (https://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-import io
-import os
-
-import numpy as np
-import pandas as pd
-import tempfile
-import unittest
-from unittest.mock import patch, Mock
-from astropy import wcs
-from astropy.nddata import CCDData
-import fastavro
-try:
- import confluent_kafka
- from confluent_kafka import KafkaException
-except ImportError:
- confluent_kafka = None
-
-import lsst.alert.packet as alertPack
-from lsst.ap.association import PackageAlertsConfig, PackageAlertsTask
-from lsst.afw.cameraGeom.testUtils import DetectorWrapper
-import lsst.afw.image as afwImage
-import lsst.daf.base as dafBase
-from lsst.dax.apdb import Apdb, ApdbSql
-import lsst.geom as geom
-import lsst.meas.base.tests
-from lsst.sphgeom import Box
-import lsst.utils.tests
-import utils_tests
-
-
-def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
- """Run object and source catalogs through the Apdb to get the correct
- table schemas.
-
- Parameters
- ----------
- objects : `pandas.DataFrame`
- Set of test DiaObjects to round trip.
- sources : `pandas.DataFrame`
- Set of test DiaSources to round trip.
- forcedSources : `pandas.DataFrame`
- Set of test DiaForcedSources to round trip.
- dateTime : `astropy.time.Time`
- Time for the Apdb.
-
- Returns
- -------
- objects : `pandas.DataFrame`
- Round tripped objects.
- sources : `pandas.DataFrame`
- Round tripped sources.
- """
- with tempfile.NamedTemporaryFile() as tmpFile:
- apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + tmpFile.name)
- apdb = Apdb.from_config(apdbConfig)
-
- wholeSky = Box.full()
- loadedObjects = apdb.getDiaObjects(wholeSky)
- if loadedObjects.empty:
- diaObjects = objects
- else:
- diaObjects = pd.concat([loadedObjects, objects])
- loadedDiaSources = apdb.getDiaSources(wholeSky, [], dateTime)
- if loadedDiaSources.empty:
- diaSources = sources
- else:
- diaSources = pd.concat([loadedDiaSources, sources])
- loadedDiaForcedSources = apdb.getDiaForcedSources(wholeSky, [], dateTime)
- if loadedDiaForcedSources.empty:
- diaForcedSources = forcedSources
- else:
- diaForcedSources = pd.concat([loadedDiaForcedSources, forcedSources])
-
- apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)
-
- diaObjects = apdb.getDiaObjects(wholeSky)
- diaSources = apdb.getDiaSources(wholeSky,
- np.unique(diaObjects["diaObjectId"]),
- dateTime)
- diaForcedSources = apdb.getDiaForcedSources(
- wholeSky, np.unique(diaObjects["diaObjectId"]), dateTime)
-
- diaObjects.set_index("diaObjectId", drop=False, inplace=True)
- diaSources.set_index(["diaObjectId", "band", "diaSourceId"],
- drop=False,
- inplace=True)
- diaForcedSources.set_index(["diaObjectId"], drop=False, inplace=True)
-
- return (diaObjects, diaSources, diaForcedSources)
-
-
-VISIT = 2
-DETECTOR = 42
-
-
-def mock_alert(alert_id):
- """Generate a minimal mock alert.
- """
- return {
- "alertId": alert_id,
- "diaSource": {
- "midpointMjdTai": 5,
- "diaSourceId": 1234,
- "visit": VISIT,
- "detector": DETECTOR,
- "band": 'g',
- "ra": 12.5,
- "dec": -16.9,
- # These types are 32-bit floats in the avro schema, so we have to
- # make them that type here, so that they round trip appropriately.
- "x": np.float32(15.7),
- "y": np.float32(89.8),
- "apFlux": np.float32(54.85),
- "apFluxErr": np.float32(70.0),
- "snr": np.float32(6.7),
- "psfFlux": np.float32(700.0),
- "psfFluxErr": np.float32(90.0),
- }
- }
-
-
-def _deserialize_alert(alert_bytes):
- """Deserialize an alert message from Kafka.
-
- Parameters
- ----------
- alert_bytes : `bytes`
- Binary-encoding serialized Avro alert, including Confluent Wire
- Format prefix.
-
- Returns
- -------
- alert : `dict`
- An alert payload.
- """
- schema = alertPack.Schema.from_uri(str(alertPack.get_uri_to_latest_schema()))
- content_bytes = io.BytesIO(alert_bytes[5:])
-
- return fastavro.schemaless_reader(content_bytes, schema.definition)
-
-
-class TestPackageAlerts(lsst.utils.tests.TestCase):
- def setUp(self):
- # Create an instance of random generator with fixed seed.
- rng = np.random.default_rng(1234)
-
- patcher = patch.dict(os.environ, {"AP_KAFKA_PRODUCER_PASSWORD": "fake_password",
- "AP_KAFKA_PRODUCER_USERNAME": "fake_username",
- "AP_KAFKA_SERVER": "fake_server",
- "AP_KAFKA_TOPIC": "fake_topic"})
- self.environ = patcher.start()
- self.addCleanup(patcher.stop)
- self.cutoutSize = 35
- self.center = lsst.geom.Point2D(50.1, 49.8)
- self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(-20, -30),
- lsst.geom.Extent2I(140, 160))
- self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
- self.dataset.addSource(100000.0, self.center)
- exposure, catalog = self.dataset.realize(
- 10.0,
- self.dataset.makeMinimalSchema(),
- randomSeed=1234)
- self.exposure = exposure
- detector = DetectorWrapper(id=DETECTOR, bbox=exposure.getBBox()).detector
- self.exposure.setDetector(detector)
-
- visit = afwImage.VisitInfo(
- id=VISIT,
- exposureTime=200.,
- date=dafBase.DateTime("2014-05-13T17:00:00.000000000",
- dafBase.DateTime.Timescale.TAI))
- self.exposure.info.id = 1234
- self.exposure.info.setVisitInfo(visit)
-
- self.exposure.setFilter(
- afwImage.FilterLabel(band='g', physical="g.MP9401"))
-
- diaObjects = utils_tests.makeDiaObjects(2, self.exposure, rng)
- diaSourceHistory = utils_tests.makeDiaSources(
- 10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
- diaForcedSources = utils_tests.makeDiaForcedSources(
- 10, diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
- self.diaObjects, diaSourceHistory, self.diaForcedSources = _roundTripThroughApdb(
- diaObjects,
- diaSourceHistory,
- diaForcedSources,
- self.exposure.visitInfo.date.toAstropy())
- diaSourceHistory["programId"] = 0
-
- self.diaSources = diaSourceHistory.loc[[(1, "g", 9), (2, "g", 10)], :]
- self.diaSources["bboxSize"] = self.cutoutSize
- self.diaSourceHistory = diaSourceHistory.drop(labels=[(1, "g", 9),
- (2, "g", 10)])
-
- self.cutoutWcs = wcs.WCS(naxis=2)
- self.cutoutWcs.wcs.crpix = [self.center[0], self.center[1]]
- self.cutoutWcs.wcs.crval = [
- self.exposure.getWcs().getSkyOrigin().getRa().asDegrees(),
- self.exposure.getWcs().getSkyOrigin().getDec().asDegrees()]
- self.cutoutWcs.wcs.cd = self.exposure.getWcs().getCdMatrix()
- self.cutoutWcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
-
- def testCreateExtentMinimum(self):
- """Test the extent creation for the cutout bbox returns a cutout with
- the minimum cutouut size.
- """
- packConfig = PackageAlertsConfig()
- # Just create a minimum less than the default cutout size.
- packConfig.minCutoutSize = self.cutoutSize - 5
- packageAlerts = PackageAlertsTask(config=packConfig)
- extent = packageAlerts.createDiaSourceExtent(
- packConfig.minCutoutSize - 5)
- self.assertTrue(extent == geom.Extent2I(packConfig.minCutoutSize,
- packConfig.minCutoutSize))
- # Test that the cutout size is correctly increased.
- extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
- self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
-
- def testCreateExtentMaximum(self):
- """Test the extent creation for the cutout bbox returns a cutout with
- the maximum cutout size.
- """
- packConfig = PackageAlertsConfig()
- # Just create a maximum more than the default cutout size.
- packConfig.maxCutoutSize = self.cutoutSize + 5
- packageAlerts = PackageAlertsTask(config=packConfig)
- extent = packageAlerts.createDiaSourceExtent(
- packConfig.maxCutoutSize + 5)
- self.assertTrue(extent == geom.Extent2I(packConfig.maxCutoutSize,
- packConfig.maxCutoutSize))
- # Test that the cutout size is correctly reduced.
- extent = packageAlerts.createDiaSourceExtent(self.cutoutSize)
- self.assertTrue(extent == geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
-
- def testCreateCcdDataCutout(self):
- """Test that the data is being extracted into the CCDData cutout
- correctly.
- """
- packageAlerts = PackageAlertsTask()
-
- diaSrcId = 1234
- ccdData = packageAlerts.createCcdDataCutout(
- self.exposure,
- self.exposure.getWcs().getSkyOrigin(),
- self.exposure.getWcs().getPixelOrigin(),
- self.exposure.getBBox().getDimensions(),
- self.exposure.getPhotoCalib(),
- diaSrcId)
- calibExposure = self.exposure.getPhotoCalib().calibrateImage(
- self.exposure.getMaskedImage())
-
- self.assertFloatsAlmostEqual(ccdData.wcs.wcs.cd,
- self.cutoutWcs.wcs.cd)
- self.assertFloatsAlmostEqual(ccdData.data,
- calibExposure.getImage().array)
- self.assertFloatsAlmostEqual(ccdData.psf,
- self.exposure.psf.computeKernelImage(self.center).array)
-
- ccdData = packageAlerts.createCcdDataCutout(
- self.exposure,
- geom.SpherePoint(0, 0, geom.degrees),
- geom.Point2D(-100, -100),
- self.exposure.getBBox().getDimensions(),
- self.exposure.getPhotoCalib(),
- diaSrcId)
- self.assertTrue(ccdData is None)
-
- def testMakeLocalTransformMatrix(self):
- """Test that the local WCS approximation is correct.
- """
- packageAlerts = PackageAlertsTask()
-
- sphPoint = self.exposure.getWcs().pixelToSky(self.center)
- cutout = self.exposure.getCutout(sphPoint,
- geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
- cd = packageAlerts.makeLocalTransformMatrix(
- cutout.getWcs(), self.center, sphPoint)
- self.assertFloatsAlmostEqual(
- cd,
- cutout.getWcs().getCdMatrix(),
- rtol=1e-11,
- atol=1e-11)
-
- def testStreamCcdDataToBytes(self):
- """Test round tripping an CCDData cutout to bytes and back.
- """
- packageAlerts = PackageAlertsTask()
-
- sphPoint = self.exposure.getWcs().pixelToSky(self.center)
- cutout = self.exposure.getCutout(sphPoint,
- geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
- cutoutCcdData = CCDData(
- data=cutout.getImage().array,
- wcs=self.cutoutWcs,
- unit="adu")
-
- cutoutBytes = packageAlerts.streamCcdDataToBytes(cutoutCcdData)
- with io.BytesIO(cutoutBytes) as bytesIO:
- cutoutFromBytes = CCDData.read(bytesIO, format="fits")
- self.assertFloatsAlmostEqual(cutoutCcdData.data, cutoutFromBytes.data)
-
- def testMakeAlertDict(self):
- """Test stripping data from the various data products and into a
- dictionary "alert".
- """
- packageAlerts = PackageAlertsTask()
- alertId = 1234
-
- for srcIdx, diaSource in self.diaSources.iterrows():
- sphPoint = geom.SpherePoint(diaSource["ra"],
- diaSource["dec"],
- geom.degrees)
- pixelPoint = geom.Point2D(diaSource["x"], diaSource["y"])
- cutout = self.exposure.getCutout(sphPoint,
- geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
- ccdCutout = packageAlerts.createCcdDataCutout(
- cutout,
- sphPoint,
- pixelPoint,
- geom.Extent2I(self.cutoutSize, self.cutoutSize),
- cutout.getPhotoCalib(),
- 1234)
- cutoutBytes = packageAlerts.streamCcdDataToBytes(
- ccdCutout)
- objSources = self.diaSourceHistory.loc[srcIdx[0]]
- objForcedSources = self.diaForcedSources.loc[srcIdx[0]]
- alert = packageAlerts.makeAlertDict(
- alertId,
- diaSource,
- self.diaObjects.loc[srcIdx[0]],
- objSources,
- objForcedSources,
- ccdCutout,
- ccdCutout,
- ccdCutout)
- self.assertEqual(len(alert), 10)
-
- self.assertEqual(alert["alertId"], alertId)
- self.assertEqual(alert["diaSource"], diaSource.to_dict())
- self.assertEqual(alert["cutoutDifference"],
- cutoutBytes)
- self.assertEqual(alert["cutoutScience"],
- cutoutBytes)
- self.assertEqual(alert["cutoutTemplate"],
- cutoutBytes)
-
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_empty_password(self):
- """ Test that produceAlerts raises if the password is empty or missing.
- """
- self.environ['AP_KAFKA_PRODUCER_PASSWORD'] = ""
- with self.assertRaisesRegex(ValueError, "Kafka password"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- del self.environ['AP_KAFKA_PRODUCER_PASSWORD']
- with self.assertRaisesRegex(ValueError, "Kafka password"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_empty_username(self):
- """ Test that produceAlerts raises if the username is empty or missing.
- """
- self.environ['AP_KAFKA_PRODUCER_USERNAME'] = ""
- with self.assertRaisesRegex(ValueError, "Kafka username"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- del self.environ['AP_KAFKA_PRODUCER_USERNAME']
- with self.assertRaisesRegex(ValueError, "Kafka username"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_empty_server(self):
- """ Test that produceAlerts raises if the server is empty or missing.
- """
- self.environ['AP_KAFKA_SERVER'] = ""
- with self.assertRaisesRegex(ValueError, "Kafka server"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- del self.environ['AP_KAFKA_SERVER']
- with self.assertRaisesRegex(ValueError, "Kafka server"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_empty_topic(self):
- """ Test that produceAlerts raises if the topic is empty or missing.
- """
- self.environ['AP_KAFKA_TOPIC'] = ""
- with self.assertRaisesRegex(ValueError, "Kafka topic"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- del self.environ['AP_KAFKA_TOPIC']
- with self.assertRaisesRegex(ValueError, "Kafka topic"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
- @patch('confluent_kafka.Producer')
- @patch.object(PackageAlertsTask, '_server_check')
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_success(self, mock_server_check, mock_producer):
- """ Test that produceAlerts calls the producer on all provided alerts
- when the alerts are all under the batch size limit.
- """
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- packageAlerts = PackageAlertsTask(config=packConfig)
- alerts = [mock_alert(1), mock_alert(2)]
-
- # Create a variable and assign it an instance of the patched kafka producer
- producer_instance = mock_producer.return_value
- producer_instance.produce = Mock()
- producer_instance.flush = Mock()
- unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
- exposure_time = self.exposure.visitInfo.exposureTime
- packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)
-
- self.assertEqual(mock_server_check.call_count, 1)
- self.assertEqual(producer_instance.produce.call_count, len(alerts))
- self.assertEqual(producer_instance.flush.call_count, len(alerts)+1)
-
- @patch('confluent_kafka.Producer')
- @patch.object(PackageAlertsTask, '_server_check')
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_produceAlerts_one_failure(self, mock_server_check, mock_producer):
- """ Test that produceAlerts correctly fails on one alert
- and is writing the failure to disk.
- """
- counter = 0
-
- def mock_produce(*args, **kwargs):
- nonlocal counter
- counter += 1
- if counter == 2:
- raise KafkaException
- else:
- return
-
- packConfig = PackageAlertsConfig(doProduceAlerts=True, doWriteFailedAlerts=True)
- packageAlerts = PackageAlertsTask(config=packConfig)
-
- patcher = patch("builtins.open")
- patch_open = patcher.start()
- alerts = [mock_alert(1), mock_alert(2), mock_alert(3)]
- unix_midpoint = self.exposure.visitInfo.date.toAstropy().tai.unix
- exposure_time = self.exposure.visitInfo.exposureTime
-
- producer_instance = mock_producer.return_value
- producer_instance.produce = Mock(side_effect=mock_produce)
- producer_instance.flush = Mock()
- packageAlerts.produceAlerts(alerts, VISIT, DETECTOR, unix_midpoint, exposure_time)
-
- self.assertEqual(mock_server_check.call_count, 1)
- self.assertEqual(producer_instance.produce.call_count, len(alerts))
- self.assertEqual(patch_open.call_count, 1)
- self.assertIn(f"{VISIT}_{DETECTOR}_2.avro", patch_open.call_args.args[0])
- # Because one produce raises, we call flush one fewer times than in the success
- # test above.
- self.assertEqual(producer_instance.flush.call_count, len(alerts))
- patcher.stop()
-
- @patch.object(PackageAlertsTask, '_server_check')
- def testRun_without_produce(self, mock_server_check):
- """Test the run method of package alerts with produce set to False and
- doWriteAlerts set to true.
- """
- packConfig = PackageAlertsConfig(doWriteAlerts=True)
- with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
- packConfig.alertWriteLocation = tempdir
- packageAlerts = PackageAlertsTask(config=packConfig)
-
- packageAlerts.run(self.diaSources,
- self.diaObjects,
- self.diaSourceHistory,
- self.diaForcedSources,
- self.exposure,
- self.exposure,
- self.exposure)
-
- self.assertEqual(mock_server_check.call_count, 0)
-
- with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
- writer_schema, data_stream = \
- packageAlerts.alertSchema.retrieve_alerts(f)
- data = list(data_stream)
-
- self.assertEqual(len(data), len(self.diaSources))
- for idx, alert in enumerate(data):
- for key, value in alert["diaSource"].items():
- if isinstance(value, float):
- if np.isnan(self.diaSources.iloc[idx][key]):
- self.assertTrue(np.isnan(value))
- else:
- self.assertAlmostEqual(
- 1 - value / self.diaSources.iloc[idx][key],
- 0.)
- else:
- self.assertEqual(value, self.diaSources.iloc[idx][key])
- sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
- alert["diaSource"]["dec"],
- geom.degrees)
- pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
- cutout = self.exposure.getCutout(sphPoint,
- geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
- ccdCutout = packageAlerts.createCcdDataCutout(
- cutout,
- sphPoint,
- pixelPoint,
- geom.Extent2I(self.cutoutSize, self.cutoutSize),
- cutout.getPhotoCalib(),
- 1234)
- self.assertEqual(alert["cutoutDifference"],
- packageAlerts.streamCcdDataToBytes(ccdCutout))
-
- @patch.object(PackageAlertsTask, '_server_check')
- def testRun_without_produce_use_averagePsf(self, mock_server_check):
- """Test the run method of package alerts with produce set to False and
- doWriteAlerts set to true.
- """
- packConfig = PackageAlertsConfig(doWriteAlerts=True)
- with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
- packConfig.alertWriteLocation = tempdir
- packConfig.useAveragePsf = True
- packageAlerts = PackageAlertsTask(config=packConfig)
-
- packageAlerts.run(self.diaSources,
- self.diaObjects,
- self.diaSourceHistory,
- self.diaForcedSources,
- self.exposure,
- self.exposure,
- self.exposure)
-
- self.assertEqual(mock_server_check.call_count, 0)
-
- with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
- writer_schema, data_stream = \
- packageAlerts.alertSchema.retrieve_alerts(f)
- data = list(data_stream)
-
- self.assertEqual(len(data), len(self.diaSources))
- for idx, alert in enumerate(data):
- for key, value in alert["diaSource"].items():
- if isinstance(value, float):
- if np.isnan(self.diaSources.iloc[idx][key]):
- self.assertTrue(np.isnan(value))
- else:
- self.assertAlmostEqual(
- 1 - value / self.diaSources.iloc[idx][key],
- 0.)
- else:
- self.assertEqual(value, self.diaSources.iloc[idx][key])
- sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
- alert["diaSource"]["dec"],
- geom.degrees)
- pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
- cutout = self.exposure.getCutout(sphPoint,
- geom.Extent2I(self.cutoutSize,
- self.cutoutSize))
- ccdCutout = packageAlerts.createCcdDataCutout(
- cutout,
- sphPoint,
- pixelPoint,
- geom.Extent2I(self.cutoutSize, self.cutoutSize),
- cutout.getPhotoCalib(),
- 1234)
- self.assertEqual(alert["cutoutDifference"],
- packageAlerts.streamCcdDataToBytes(ccdCutout))
-
- @patch.object(PackageAlertsTask, 'produceAlerts')
- @patch('confluent_kafka.Producer')
- @patch.object(PackageAlertsTask, '_server_check')
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def testRun_with_produce(self, mock_produceAlerts, mock_server_check, mock_producer):
- """Test that packageAlerts calls produceAlerts when doProduceAlerts
- is set to True.
- """
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- packageAlerts = PackageAlertsTask(config=packConfig)
-
- packageAlerts.run(self.diaSources,
- self.diaObjects,
- self.diaSourceHistory,
- self.diaForcedSources,
- self.exposure,
- self.exposure,
- self.exposure)
- self.assertEqual(mock_server_check.call_count, 1)
- self.assertEqual(mock_produceAlerts.call_count, 1)
-
- def test_serialize_alert_round_trip(self):
- """Test that values in the alert packet exactly round trip.
- """
- packClass = PackageAlertsConfig()
- packageAlerts = PackageAlertsTask(config=packClass)
-
- alert = mock_alert(1)
- serialized = PackageAlertsTask._serializeAlert(packageAlerts, alert)
- deserialized = _deserialize_alert(serialized)
-
- for field in alert['diaSource']:
- self.assertEqual(alert['diaSource'][field], deserialized['diaSource'][field])
- self.assertEqual(1, deserialized["alertId"])
-
- @unittest.skipIf(confluent_kafka is None, 'Kafka is not enabled')
- def test_server_check(self):
-
- with self.assertRaisesRegex(KafkaException, "_TRANSPORT"):
- packConfig = PackageAlertsConfig(doProduceAlerts=True)
- PackageAlertsTask(config=packConfig)
-
-
-class MemoryTester(lsst.utils.tests.MemoryTestCase):
- pass
-
-
-def setup_module(module):
- lsst.utils.tests.init()
-
-
-if __name__ == "__main__":
- lsst.utils.tests.init()
- unittest.main()