-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplottingFuncs.py
More file actions
428 lines (399 loc) · 16 KB
/
plottingFuncs.py
File metadata and controls
428 lines (399 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
#!/usr/bin/env python3
"""
.. module:: plottingFuncs
:synopsis: Main methods for dealing with the plotting of a validation plot
.. moduleauthor:: Andre Lessa <[email protected]>
"""
import logging,os,sys,numpy,random,copy
from typing import Union, Optional, Set, List
#sys.path.append('../')
from array import array
import math, ctypes
logger = logging.getLogger(__name__)
from smodels.base.physicsUnits import fb, GeV, pb
from smodels.experiment.txnameObj import TxNameData
from smodels_utils.dataPreparation.massPlaneObjects import MassPlane
from smodels_utils.helper.prettyDescriptions import prettyTxname
from validationHelpers import getAxisType, prettyAxes
import numpy as np
try:
from smodels.theory.auxiliaryFunctions import unscaleWidth,rescaleWidth
except:
pass
try:
from smodels.theory.auxiliaryFunctions import removeUnits
except:
from backwardCompatibility import removeUnits
import time
rt0 = [ time.time() ]


def timeStamp ( comment, t = None ):
    """ print <comment>, prefixed with the seconds elapsed since the
    reference time stored in rt0[0]

    :param comment: text that is printed after the elapsed time
    :param t: "start" resets the reference time to now, None stands for
              "now", any other value is taken as the current time stamp
    """
    if t == "start":
        # restart the clock; the elapsed time printed below is then zero
        rt0[0] = time.time()
        now = rt0[0]
    elif t is None:
        now = time.time()
    else:
        now = t
    elapsed = now - rt0[0]
    print ( f"{elapsed:.2f}: {comment}" )
def getColormap():
    """ our matplotlib colormap for pretty plots

    :returns: a LinearSegmentedColormap named 'rg' with 256 levels, built
              from the (position, color) anchors listed below
    """
    from matplotlib.colors import LinearSegmentedColormap
    # anchors run from dark green at 0 to dark red at 1
    anchors = [
        ( 0,   "darkgreen" ),
        ( .11, "green" ),
        ( .22, "palegreen" ),
        ( .33, "lightgoldenrodyellow" ),
        ( 0.52, "lightcoral" ),
        ( .7,  "red" ),
        ( .85, (.9,0,0) ),
        ( 1.,  (.7,0,0) ),
    ]
    return LinearSegmentedColormap.from_list ( 'rg', anchors, N=256 )
errMsgIssued = { "axis": False }


def convertNewAxes ( newa ):
    """ convert new types of axes (dictionary) to old (lists)

    :param newa: a list, or a dictionary with the key "x" and, optionally,
                 the keys "y" and "z"
    :returns: for a list, a reversed deep copy; for a dictionary, the list
              of its x, y, z entries in reversed order ([] if the
              dictionary is empty); None for any other type, in which case
              a message is printed once per process
    """
    # copied up front, so the list branch never hands out the caller's
    # nested objects
    snapshot = copy.deepcopy ( newa )
    if type(newa) is list:
        return snapshot[::-1]
    if type(newa) is not dict:
        if not errMsgIssued["axis"]:
            print ( f"[plotRatio] cannot convert axis '{newa}'" )
            errMsgIssued["axis"]=True
        return None
    if not newa:
        return []
    # "x" is mandatory, "y" and "z" are optional
    collected = [ newa["x"] ]
    collected.extend ( newa[key] for key in ( "y", "z" ) if key in newa )
    collected.reverse()
    return collected
def isWithinRange ( xyrange : Optional[list], xy : float ):
    """ check if xy is within xyrange

    :param xyrange: [lower, upper] bounds (both inclusive), or None if
                    there is no restriction
    :param xy: the value to check
    :returns: True if xyrange is None or lower <= xy <= upper
    """
    # identity check on purpose: "== None" is evaluated element-wise on
    # numpy arrays, so the old test raised for array-like ranges
    if xyrange is None:
        return True
    return xyrange[0] <= xy <= xyrange[1]
def filterWithinRanges ( points : dict, xrange : Optional[list],
        yrange : Optional[list], defRetZeroes : bool = False ):
    """ filter from points all that is not within xrange or yrange

    :param points: dictionary with the x values under "x" and, optionally,
                   the y values under "y". A list is taken to be in the v2
                   format and is handed over to filterWithinRangesV2.
    :param xrange: range for the x values, None means no restriction
    :param yrange: range for the y values, None means no restriction
    :param defRetZeroes: if true, then return list of zeroes if no y coordinates
    :returns: tuple (px, py) with the coordinates that passed. Without y
              coordinates, py is a list of zeroes or None, depending on
              defRetZeroes.
    """
    if type(points) is list:
        ## we have v2
        return filterWithinRangesV2 ( points, xrange, yrange, defRetZeroes )
    if "y" in points:
        # two-dimensional case: a pair survives only if both coordinates do
        kept = [ ( x, y ) for x, y in zip ( points["x"], points["y"] )
                 if isWithinRange ( xrange, x ) and isWithinRange ( yrange, y ) ]
        return [ x for x, _ in kept ], [ y for _, y in kept ]
    # one-dimensional case: only x can be checked
    px = [ x for x in points["x"] if isWithinRange ( xrange, x ) ]
    py = [0.] * len(px) if defRetZeroes else None
    return px, py
def filterWithinRangesV2 ( points : list, xrange : Optional[list],
        yrange : Optional[list], defRetZeroes : bool = False ):
    """ filter from points all that is not within xrange or yrange

    :param points: list of lines, each line being a list of point
                   dictionaries with the key "x" and, optionally, "y"
    :param xrange: range for the x values, None means no restriction
    :param yrange: range for the y values, None means no restriction;
                   only applied to points that come with a "y"
    :param defRetZeroes: if true, then points without y coordinate get
                   y = 0. (this is written into the point dictionary
                   itself, as before)
    :returns: tuple (px, py) with the coordinates of the points that
              passed. An accepted point without "y" contributes None to
              py; if none of the accepted points has a y coordinate, py is
              None, like in filterWithinRanges.
    """
    # v2
    ret = [] ## the filtered lines; not returned yet, see below
    px, py = [], [] ## backwards compatibility
    for line in points:
        newline = []
        for point in line:
            hasY = "y" in point
            pointIsGood = isWithinRange ( xrange, point["x"] )
            if hasY and not isWithinRange ( yrange, point["y"] ):
                pointIsGood = False
            if not hasY and defRetZeroes:
                # NOTE(review): modifies the caller's dictionary; kept,
                # since callers may rely on it -- confirm
                point["y"]=0.
            if pointIsGood:
                newline.append ( point )
                px.append ( point["x"] )
                # .get instead of [..]: a point without "y" used to raise
                # a KeyError here when defRetZeroes was False
                py.append ( point.get ( "y" ) )
        if len(newline)>0:
            ret.append ( newline )
    if py and all ( y is None for y in py ):
        # no y coordinates at all: same convention as filterWithinRanges
        py = None
    # return ret ## eventually we will return this list of points
    return px, py ## fixme eventually drop this
def getAxisRange ( options : dict, label : str = "xaxis" ):
    """ given an options dictionary, obtain a range for the axis named
    <label>

    options["style"] is split at ";". The first entry that contains
    <label> and whose bracketed part (e.g. "xaxis[0,1000]" or
    "xaxis:[0,1000]") can be evaluated determines the result. Entries that
    fail to evaluate are reported via logger.error and skipped.

    :param options: dictionary, the range is looked for under "style"
    :param label: name of the axis, e.g. "xaxis" or "yaxis"
    :returns: range list, e.g. [0,1000], or None
    """
    if "style" not in options:
        return None
    for style in options["style"].split(";"):
        plabel = style.find ( label )
        if plabel < 0:
            # this entry is about something else
            continue
        # if a colon follows the label, start the search for "[" from there
        pcolon = style.find ( ":", plabel )
        if pcolon > 0:
            plabel = pcolon
        pstart = style.find ( "[", plabel )
        pend = style.find ( "]", pstart )
        # NOTE(review): eval executes whatever stands between the brackets;
        # this is only acceptable as long as the style strings are trusted
        try:
            return eval ( style[pstart:pend+1] )
        except Exception as e:
            logger.error ( f"when evaluating {label} range: {e}" )
            logger.error ( f" ´-- style {options['style']}->{style[pstart:pend+1]}" )
    return None
def getClosestValue ( x : float, y : float , graph : list , dmax : float = 1. ):
    """ from the graph, return the value of the point closest to x,y

    :param x: x coordinate of the reference point
    :param y: y coordinate of the reference point
    :param graph: iterable of dictionaries with the keys "x", "y" and "r"
    :param dmax: upper bound for the closest point. NOTE: as before, it is
                 compared with the *squared* euclidean distance.
    :returns: the "r" value of the closest point, as long as its squared
              distance is smaller than dmax. else return nan
    """
    dmin, v = float("inf"), None
    for t in graph:
        d = (t["x"]-x)**2 + (t["y"]-y)**2
        if d < dmin:
            dmin, v = d, t["r"]
    # the cut on dmax has to come after the scan: applying it inside the
    # loop returned the value of a previously seen point (or None) as soon
    # as any point was within dmax
    if dmin < dmax:
        return v
    return float("nan")
def getExclusionCurvesFor(expResult,txname=None,axes=None, get_all=False,
                          expected=False ):
    """
    Retrieves the exclusion curves of an experimental result.
    Looks for exclusions.json and then for exclusion_lines.json in
    expResult.path, and passes the first file found on to
    smodels_utils.helper.various.getExclusionCurvesFor. If neither exists
    but sms.root does, rootPlottingFuncs.getExclusionCurvesForFromSmsRoot
    is used instead.

    :param expResult: an ExpResult object; of a list, only the first
                      element is used
    :param txname: the TxName in string format (i.e. T1tttt)
    :param axes: the axes definition in string format (e.g. [x, y, 60.0], [x, y, 60.0]])
    :param get_all: Get also the +-1 sigma curves?
    :param expected: if true, get expected, not observed
    :return: the return value of the function delegated to, or None if
             there is neither a json file nor sms.root
    """
    if type(expResult) is list:
        expResult = expResult[0]
    for candidate in ( 'exclusions.json', 'exclusion_lines.json' ):
        jsonfile = os.path.join ( expResult.path, candidate )
        if os.path.isfile ( jsonfile ):
            from smodels_utils.helper import various
            return various.getExclusionCurvesFor ( jsonfile, txname, axes,
                                                   get_all, expected )
    # none of the json files exists; jsonfile names the last one tried
    logger.error( f"json file {jsonfile} not found" )
    if not os.path.exists ( os.path.join ( expResult.path, "sms.root" ) ):
        # no exclusion_lines as well as no sms.root
        return None
    logger.warning ( f"trying with sms.root, but please switch!" )
    from rootPlottingFuncs import getExclusionCurvesForFromSmsRoot
    return getExclusionCurvesForFromSmsRoot ( expResult, txname, axes,
                                              get_all, expected )
def getDatasetDescription ( validationPlot, maxLength : int = 100 ) -> str:
    """ get the description of the dataset that appears as a subtitle
    in e.g. the ugly plots

    :param validationPlot: the validation plot; read are expRes, txName,
                   validationType and, where needed, combine and meta
    :param maxLength: longer subtitles are cut at the last comma before
                   this position (or at this position, if there is no such
                   comma), and ", ..." is appended
    :returns: the subtitle; an empty string (plus possibly the spey tag) if
              the only dataset has dataId None
    """
    expRes = validationPlot.expRes
    nDatasets = len ( expRes.datasets )

    def combinesVia ( attribute ):
        # validationPlot.combine is looked at only if globalInfo has the attribute
        return hasattr ( expRes.globalInfo, attribute ) and \
               validationPlot.combine == True

    ## the prefix: later matches override earlier ones
    subtitle = f"best of {nDatasets} SRs: "
    if nDatasets == 1:
        subtitle = "single SR: "
    if validationPlot.validationType == "tpredcomb":
        subtitle = f"{nDatasets} tpreds: "
    if combinesVia ( "jsonFiles" ): ## pyhf combination
        subtitle = f"pyhf combining {nDatasets} SRs: "
    if combinesVia ( "mlModels" ):
        subtitle = f"NN combining {nDatasets} SRs: "

    ## the ids of all datasets that contain our txname, shortened to 8 characters
    shortIds = []
    for dataset in expRes.datasets:
        if not any ( validationPlot.txName == str(txn) for txn in dataset.txnameList ):
            continue
        dataId = str ( dataset.dataInfo.dataId )
        shortIds.append ( dataId if len(dataId) <= 8 else f"{dataId[:7]}*" )
    subtitle += "".join ( f"{dataId}, " for dataId in shortIds )
    # drops the trailing ", " (or the ": " of the prefix, if nothing was added)
    subtitle = subtitle[:-2]

    if combinesVia ( "covariance" ):
        dI = expRes.datasets[0].dataInfo
        hasThirdMoment = hasattr ( dI, "thirdMoment" ) and dI.thirdMoment != None
        ver = " (SLv2)" if hasThirdMoment else ""
        subtitle = f"combination{ver} of {nDatasets} signal regions"

    if len(subtitle) > maxLength:
        commas = [ i for i, char in enumerate ( subtitle )
                   if char == "," and i < maxLength ]
        pos = commas[-1] if commas else maxLength
        subtitle = f"{subtitle[:pos]}, ..."
    if nDatasets == 1 and expRes.datasets[0].dataInfo.dataId is None:
        subtitle = ""
    if hasattr ( validationPlot, "meta" ) and "spey" in validationPlot.meta:
        subtitle += " [spey]"
    return subtitle
def getFigureUrl( validationPlot ):
    """ get the URL of the figure, as a string

    The first txname of the first dataset is consulted. If its figureUrl
    is a string, that string is returned. Otherwise figureUrl is expected
    to run parallel to the txname's axes (or axesMap, if there is no axes
    attribute), and the entry at the position of validationPlot.axes is
    returned.

    :returns: the url, or None if there is none, or if figureUrl and axes
              do not match in type or length. Exits if validationPlot.axes
              is not found exactly once.
    """
    txname = validationPlot.expRes.datasets[0].txnameList[0]
    txurl = txname.figureUrl
    txaxes = txname.axes if hasattr ( txname, "axes" ) else txname.axesMap
    if isinstance(txurl,str):
        return txname.figureUrl
    if not txurl:
        return None
    if type(txurl) != type(txaxes):
        logger.error( f"figureUrl ({txurl}) and axes ({txaxes}) are not of the same type" )
        return None
    if isinstance(txurl,list) and len(txurl) != len(txaxes):
        logger.warning( f"for {txname} -- figureUrl ({len(txurl)}) and axes ({len(txaxes)}) are not of the same length:" )
        # (the entries of txurl and txaxes used to be printed here, for debugging)
        return None
    if validationPlot.axes not in txaxes:
        return None
    pos = [ i for i, candidate in enumerate(txaxes)
            if candidate == validationPlot.axes ]
    if len(pos) != 1:
        logger.error(f"found axes {len(pos)} times. Did you declare several maps for the same analysis/dataset/topology combo? Will exit, please fix!")
        sys.exit()
    return txurl[pos[0]]
def convertOrigData ( txnameData : TxNameData ):
    """ convert the original data in txnameobj to lists

    :param txnameData: object whose origdata attribute is iterated over
    :returns: list with one list per entry of txnameData.origdata
    """
    return [ list ( entry ) for entry in txnameData.origdata ]
def getGridPointsV2 ( validationPlot ):
    """ retrieve the grid points of the upper limit / efficiency map,
    for axes of SModelS v2 type

    :param validationPlot: the validation plot; read are txName, axes and
                           expRes.datasets
    :returns: list of the distinct coordinates that the mass plane returns
              for the original data points of all datasets; an empty list
              as soon as one dataset lacks the txname, has _keep_values
              switched off, or has no origdata
    """
    massPlane = MassPlane.fromString( validationPlot.txName, validationPlot.axes )
    gridPoints = []
    for dataset in validationPlot.expRes.datasets:
        # first txname of this dataset that carries our name
        txNameObj = next ( ( txn for txn in dataset.txnameList
                             if txn.txName == validationPlot.txName ), None )
        if txNameObj is None:
            logger.info ( "no grid points: did not find txName" )
            return []
        txnameData = txNameObj.txnameData
        if not txnameData._keep_values:
            logger.info ( "no grid points: _keep_values is set to False" )
            return []
        if not hasattr ( txnameData, "origdata" ):
            logger.info ( "no grid points: cannot find origdata (maybe try a forced rebuild of the database via runValidation.py -f)" )
            return []
        origdata = convertOrigData ( txnameData )
        if getAxisType ( validationPlot.axes ) == "v2":
            from sympy import var
            # the axes string is evaluated right below, the symbols have
            # to exist in this very scope for that
            x,y,z,w = var ( "x y z w" )
            axes = eval ( validationPlot.axes )
            for pt in origdata:
                # cut the flat point into one chunk per branch of the axes
                masses, offset = [], 0
                for ax in axes:
                    masses.append ( pt[offset:offset+len(ax)] )
                    offset += len(ax)
                coords = massPlane.getXYValues(masses)
                if coords == None or coords in gridPoints:
                    continue
                gridPoints.append ( coords )
        else:
            for masses in origdata:
                print ( "masses", masses)
                coords = massPlane.getXYValues(masses)
                if coords == None or coords in gridPoints:
                    continue
                gridPoints.append ( coords )
    logger.info ( f"found {len(gridPoints)} gridpoints" )
    ## we will need this for .dataToCoordinates
    return gridPoints
def getGridPoints ( validationPlot ) -> List:
    """ retrieve the grid points of the upper limit / efficiency map.

    Axes of type "v2" are handed over to getGridPointsV2. Otherwise every
    entry of the original data points is paired with its 1-based position
    before it is passed to the mass plane.

    :param validationPlot: the validation plot; read are txName, axes and
                           expRes.datasets
    :returns: list of the distinct coordinates; an empty list as soon as
              one dataset lacks the txname, has _keep_values switched off,
              or has no origdata
    """
    if getAxisType ( validationPlot.axes ) == "v2":
        return getGridPointsV2 ( validationPlot )
    massPlane = MassPlane.fromString( validationPlot.txName, validationPlot.axes )
    gridPoints = []
    massesToCoords = {} ## cache the massesToCoords mapping
    for dataset in validationPlot.expRes.datasets:
        # first txname of this dataset that carries our name
        txNameObj = next ( ( txn for txn in dataset.txnameList
                             if txn.txName == validationPlot.txName ), None )
        if txNameObj is None:
            logger.info ( "no grid points: did not find txName" )
            return []
        txnameData = txNameObj.txnameData
        if not txnameData._keep_values:
            logger.info ( "no grid points: _keep_values is set to False" )
            return []
        if not hasattr ( txnameData, "origdata" ):
            logger.info ( "no grid points: cannot find origdata (maybe try a forced rebuild of the database via runValidation.py -f)" )
            return []
        for cmasses in convertOrigData ( txnameData ):
            key = tuple ( cmasses )
            if key in massesToCoords:
                # this point has already been dealt with
                continue
            ## FIXME not sure if this works for widths
            # every entry becomes (position, value), positions start at 1
            masses = [ ( i, mass ) for i, mass in
                       enumerate ( copy.deepcopy ( cmasses ), start=1 ) ]
            coords = massPlane.getXYValues(masses)
            massesToCoords[key] = coords
            if coords == None or coords in gridPoints:
                continue
            gridPoints.append ( coords )
    logger.info ( f"found {len(gridPoints)} gridpoints" )
    ## we will need this for .dataToCoordinates
    return gridPoints
def yIsLog ( validationPlot ):
    """ determine if to use log for y axis

    For axes without "{" (v2) the answer is yes if "y" -- or "w", if there
    is no "y" -- stands between the first "(" and the first ")", directly
    after a comma. Otherwise (v3) the "y" entries under "axes" in
    validationPlot.data are looked at: the answer is yes if the largest of
    them lies strictly between 1e-40 and 1e-1.

    :returns: True or False
    """
    if "{" not in validationPlot.axes: ## axis v2
        compact = validationPlot.axes.replace(" ","")
        pOpen, pClose = compact.find("("), compact.find(")")
        py = compact.find("y")
        if py == -1:
            py = compact.find("w")
        if pOpen < py < pClose and compact[py-1] == ",":
            return True
        return False
    # for v3 we look at the axis["y"] values
    yvalues = { d["axes"]["y"] for d in validationPlot.data
                if "axes" in d and "y" in d["axes"] }
    if yvalues and 1e-40 < max(yvalues) < 1e-1:
        return True
    return False