22
22
23
23
__all__ = ("AssociatedSourcesTractAnalysisConfig" , "AssociatedSourcesTractAnalysisTask" )
24
24
25
+ import astropy .time
26
+ import astropy .units as u
27
+ import lsst .pex .config as pexConfig
25
28
import numpy as np
26
- from astropy .table import join , vstack
29
+ import pandas as pd
30
+ from astropy .table import Table , join , vstack
31
+ from lsst .drp .tasks .gbdesAstrometricFit import calculate_apparent_motion
27
32
from lsst .geom import Box2D
28
33
from lsst .pipe .base import NoWorkFound
29
34
from lsst .pipe .base import connectionTypes as ct
30
35
from lsst .skymap import BaseSkyMap
36
+ from smatch import Matcher
31
37
32
38
from ..interfaces import AnalysisBaseConfig , AnalysisBaseConnections , AnalysisPipelineTask
33
39
@@ -70,11 +76,59 @@ class AssociatedSourcesTractAnalysisConnections(
70
76
dimensions = ("instrument" ,),
71
77
isCalibration = True ,
72
78
)
79
+ astrometricCorrectionCatalogs = ct .Input (
80
+ doc = "Catalog containing proper motion and parallax." ,
81
+ name = "gbdesAstrometricFit_starCatalog" ,
82
+ storageClass = "ArrowNumpyDict" ,
83
+ dimensions = ("instrument" , "skymap" , "tract" , "physical_filter" ),
84
+ multiple = True ,
85
+ deferLoad = True ,
86
+ )
87
+ visitTable = ct .Input (
88
+ doc = "Catalog containing visit information." ,
89
+ name = "visitTable" ,
90
+ storageClass = "DataFrame" ,
91
+ dimensions = ("instrument" ,),
92
+ )
93
+
94
+ def __init__ (self , * , config = None ):
95
+ super ().__init__ (config = config )
96
+
97
+ if not config .applyAstrometricCorrections :
98
+ self .inputs .remove ("astrometricCorrectionCatalogs" )
99
+ self .inputs .remove ("visitTable" )
73
100
74
101
75
102
class AssociatedSourcesTractAnalysisConfig (
76
103
AnalysisBaseConfig , pipelineConnections = AssociatedSourcesTractAnalysisConnections
77
104
):
105
+ applyAstrometricCorrections = pexConfig .Field (
106
+ dtype = bool ,
107
+ default = True ,
108
+ doc = "Apply proper motion and parallax corrections to source positions." ,
109
+ )
110
+ matchingRadius = pexConfig .Field (
111
+ dtype = float ,
112
+ default = 0.2 ,
113
+ doc = (
114
+ "Radius in mas with which to match the mean positions of the sources with the positions in the"
115
+ " astrometricCorrectionCatalogs."
116
+ ),
117
+ )
118
+ astrometricCorrectionParameters = pexConfig .DictField (
119
+ keytype = str ,
120
+ itemtype = str ,
121
+ # TODO: DM-45845 Update default names when catalog gets updated.
122
+ default = {
123
+ "ra" : "starX" ,
124
+ "dec" : "starY" ,
125
+ "pmRA" : "starPMx" ,
126
+ "pmDec" : "starPMy" ,
127
+ "parallax" : "starParallax" ,
128
+ },
129
+ doc = "Column names for position and motion parameters in the astrometric correction catalogs." ,
130
+ )
131
+
78
132
def setDefaults (self ):
79
133
super ().setDefaults ()
80
134
@@ -91,28 +145,39 @@ def getBoxWcs(skymap, tract):
91
145
tractBox = tractInfo .getBBox ()
92
146
return tractBox , wcs
93
147
94
- @classmethod
95
- def callback (cls , inputs , dataId ):
148
+ def callback (self , inputs , dataId ):
96
149
"""Callback function to be used with reconstructor."""
97
- return cls .prepareAssociatedSources (
150
+ return self .prepareAssociatedSources (
98
151
inputs ["skyMap" ],
99
152
dataId ["tract" ],
100
153
inputs ["sourceCatalogs" ],
101
154
inputs ["associatedSources" ],
155
+ inputs ["astrometricCorrectionCatalogs" ],
156
+ inputs ["visitTable" ],
102
157
)
103
158
104
- @classmethod
105
- def prepareAssociatedSources (cls , skymap , tract , sourceCatalogs , associatedSources ):
159
+ def prepareAssociatedSources (
160
+ self ,
161
+ skymap ,
162
+ tract ,
163
+ sourceCatalogs ,
164
+ associatedSources ,
165
+ astrometricCorrectionCatalogs = None ,
166
+ visitTable = None ,
167
+ ):
106
168
"""Concatenate source catalogs and join on associated object index."""
107
169
108
170
# Keep only sources with associations
109
171
sourceCatalogStack = vstack (sourceCatalogs )
110
172
dataJoined = join (sourceCatalogStack , associatedSources , keys = "sourceId" , join_type = "inner" )
111
173
174
+ if astrometricCorrectionCatalogs is not None :
175
+ self .applyAstrometricCorrections (dataJoined , astrometricCorrectionCatalogs , visitTable )
176
+
112
177
# Determine which sources are contained in tract
113
178
ra = np .radians (dataJoined ["coord_ra" ])
114
179
dec = np .radians (dataJoined ["coord_dec" ])
115
- box , wcs = cls .getBoxWcs (skymap , tract )
180
+ box , wcs = self .getBoxWcs (skymap , tract )
116
181
box = Box2D (box )
117
182
x , y = wcs .skyToPixelArray (ra , dec )
118
183
boxSelection = box .contains (x , y )
@@ -123,6 +188,77 @@ def prepareAssociatedSources(cls, skymap, tract, sourceCatalogs, associatedSourc
123
188
124
189
return dataFiltered
125
190
191
+ def applyAstrometricCorrections (self , dataJoined , astrometricCorrectionCatalogs , visitTable ):
192
+ """Use proper motion/parallax catalogs to shift positions to median
193
+ epoch of the visits.
194
+
195
+ Parameters
196
+ ----------
197
+ dataJoined : `astropy.table.Table`
198
+ Table containing source positions, which will be modified in place.
199
+ astrometricCorrectionCatalogs: `dict` [`pd.DataFrame`]
200
+ Dictionary keyed by band with proper motion and parallax catalogs.
201
+ visitTable : `pd.DataFrame`
202
+ Table containing the MJDs of the visits.
203
+ """
204
+ for band in np .unique (dataJoined ["band" ]):
205
+ bandInd = dataJoined ["band" ] == band
206
+ bandSources = dataJoined [bandInd ]
207
+ # Add key for sorting below.
208
+ bandSources ["__index__" ] = np .arange (len (bandSources ))
209
+ bandSourcesDf = bandSources .to_pandas ()
210
+ meanRAs = bandSourcesDf .groupby ("obj_index" )["coord_ra" ].aggregate ("mean" )
211
+ meanDecs = bandSourcesDf .groupby ("obj_index" )["coord_dec" ].aggregate ("mean" )
212
+
213
+ bandPMs = astrometricCorrectionCatalogs [band ]
214
+ with Matcher (meanRAs , meanDecs ) as m :
215
+ idx , i1 , i2 , d = m .query_radius (
216
+ bandPMs [self .config .astrometricCorrectionParameters ["ra" ]],
217
+ bandPMs [self .config .astrometricCorrectionParameters ["dec" ]],
218
+ (self .config .matchingRadius * u .mas ).to (u .degree ),
219
+ return_indices = True ,
220
+ )
221
+
222
+ catRAs = np .zeros_like (meanRAs )
223
+ catDecs = np .zeros_like (meanRAs )
224
+ pmRAs = np .zeros_like (meanRAs )
225
+ pmDecs = np .zeros_like (meanRAs )
226
+ parallaxes = np .zeros (len (meanRAs ))
227
+ catRAs [i1 ] = bandPMs [self .config .astrometricCorrectionParameters ["ra" ]][i2 ]
228
+ catDecs [i1 ] = bandPMs [self .config .astrometricCorrectionParameters ["dec" ]][i2 ]
229
+ pmRAs [i1 ] = bandPMs [self .config .astrometricCorrectionParameters ["pmRA" ]][i2 ]
230
+ pmDecs [i1 ] = bandPMs [self .config .astrometricCorrectionParameters ["pmDec" ]][i2 ]
231
+ parallaxes [i1 ] = bandPMs [self .config .astrometricCorrectionParameters ["parallax" ]][i2 ]
232
+
233
+ pmDf = Table (
234
+ {
235
+ "ra" : catRAs * u .degree ,
236
+ "dec" : catDecs * u .degree ,
237
+ "pmRA" : pmRAs * u .mas / u .yr ,
238
+ "pmDec" : pmDecs * u .mas / u .yr ,
239
+ "parallax" : parallaxes * u .mas ,
240
+ "obj_index" : meanRAs .index ,
241
+ }
242
+ )
243
+
244
+ dataWithPM = join (bandSources , pmDf , keys = "obj_index" , join_type = "left" )
245
+
246
+ visits = bandSourcesDf ["visit" ].unique ()
247
+ mjds = [visitTable .loc [visit ]["expMidptMJD" ] for visit in visits ]
248
+ mjdTable = Table (
249
+ [astropy .time .Time (mjds , format = "mjd" , scale = "tai" ), visits ], names = ["MJD" , "visit" ]
250
+ )
251
+ dataWithMJD = join (dataWithPM , mjdTable , keys = "visit" , join_type = "left" )
252
+ # After astropy 7.0, it should be possible to use "keep_order=True"
253
+ # in the join and avoid sorting.
254
+ dataWithMJD .sort ("__index__" )
255
+ medianMJD = astropy .time .Time (np .median (mjds ), format = "mjd" , scale = "tai" )
256
+
257
+ raCorrection , decCorrection = calculate_apparent_motion (dataWithMJD , medianMJD )
258
+
259
+ dataJoined ["coord_ra" ][bandInd ] = dataWithMJD ["coord_ra" ] - raCorrection .value
260
+ dataJoined ["coord_dec" ][bandInd ] = dataWithMJD ["coord_dec" ] - decCorrection .value
261
+
126
262
def runQuantum (self , butlerQC , inputRefs , outputRefs ):
127
263
inputs = butlerQC .get (inputRefs )
128
264
@@ -135,6 +271,18 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
135
271
sourceCatalogs .append (self .loadData (handle , names ))
136
272
inputs ["sourceCatalogs" ] = sourceCatalogs
137
273
274
+ if self .config .applyAstrometricCorrections :
275
+ astrometricCorrections = {}
276
+ for pmCatRef in inputs ["astrometricCorrectionCatalogs" ]:
277
+ pmCat = pmCatRef .get (
278
+ parameters = {"columns" : self .config .astrometricCorrectionParameters .values ()}
279
+ )
280
+ astrometricCorrections [pmCatRef .dataId ["band" ]] = pd .DataFrame (pmCat )
281
+ inputs ["astrometricCorrectionCatalogs" ] = astrometricCorrections
282
+ else :
283
+ inputs ["astrometricCorrectionCatalogs" ] = None
284
+ inputs ["visitTable" ] = None
285
+
138
286
dataId = butlerQC .quantum .dataId
139
287
plotInfo = self .parsePlotInfo (inputs , dataId , connectionName = "associatedSources" )
140
288
0 commit comments