-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstl.py
308 lines (260 loc) · 10.8 KB
/
stl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""
@file stl.py
@brief STL objects (primitives & formulas)
@author Tzu-Yi Chiu <[email protected]>
"""
from __future__ import annotations
import numpy as np
import itertools
from dataclasses import dataclass
from typing import List, FrozenSet, Set, Tuple
class Primitive:
    """Common interface shared by the STL primitives.

    Subclasses are expected to provide `robust`, `is_child_of`, `__hash__`
    and `__repr__`; satisfaction and equality are derived from them here.
    """

    def __repr__(self):
        raise NotImplementedError("`__repr__` not implemented")

    def __hash__(self):
        raise NotImplementedError("`__hash__` not implemented")

    def __eq__(self, other):
        """Two primitives compare equal when their hashes are equal."""
        if not isinstance(other, Primitive):
            return False
        return hash(self) == hash(other)

    def robust(self, s: np.ndarray) -> float:
        """Compute the robustness degree relative to signal `s`."""
        raise NotImplementedError("`robust` not implemented")

    def is_child_of(self, parent: Primitive) -> bool:
        """Tell whether `parent` is a parent of this primitive."""
        raise NotImplementedError("`is_child_of` not implemented")

    def satisfied(self, s: np.ndarray) -> bool:
        """Verify if satisfied by signal `s` (strictly positive robustness)."""
        margin = self.robust(s)
        return margin > 0
@dataclass
class Eventually(Primitive):
    """Bounded "eventually" primitive over a single signal dimension.

    Ex: Eventually((0, 5), (0, '>', 20), 1) <=> F[0,5](s1>20)
    """
    _interval: Tuple[int, int]      # (a, b): inclusive time bounds
    _phi: Tuple[int, str, float]    # (d, comp, mu): s_d > mu or s_d < mu
    _normalize: float = 1.          # robustness divisor (max - min of mu)

    def __post_init__(self):
        """Validate the parameters; raise ValueError on the first problem.

        Accepted intervals satisfy a <= b and are either entirely
        non-negative or entirely strictly negative.
        """
        try:
            a, b = self._interval
            d, comp, mu = self._phi
        except ValueError:
            raise ValueError('Invalid primitive parameters')

        repr_ = f'F[{a},{b}](s{d+1}{comp}{mu:.2f})'

        bad_interval = a * b < 0 or a > b or (a < 0 and b == 0)
        if bad_interval:
            raise ValueError(f'Invalid interval: {repr_}')

        if d < 0:
            raise ValueError(f'Invalid dimension: {repr_}')

        if comp not in ('<', '>'):
            raise ValueError(f'Invalid comparison: {repr_}')

    def __repr__(self):
        a, b = self._interval
        d, comp, mu = self._phi
        return f'F[{a},{b}](s{d+1}{comp}{mu:.2f})'

    def __hash__(self):
        # `_normalize` is deliberately not part of the key.
        key = ('F', self._interval, self._phi)
        return hash(key)

    def robust(self, s: np.ndarray) -> float:
        """Compute the robustness degree relative to signal `s`.

        `s` is indexed as s[dimension, time]. For '<' the margin is
        mu - min over the window, otherwise max over the window - mu;
        the margin is divided by `_normalize`. Returns -1 if the index
        window is empty.
        """
        start, end = self._interval
        window = np.arange(start, end + 1)
        if window.size == 0:
            return -1

        dim, op, threshold = self._phi
        values = s[dim, window]
        if op == '<':
            margin = threshold - np.min(values)
        else:
            margin = np.max(values) - threshold
        return margin / self._normalize

    def is_child_of(self, parent: Primitive) -> bool:
        """Tell whether `parent` is a parent of this primitive.

        Never true for an equal primitive or a `Globally` parent. Otherwise
        the parent must share dimension and comparison, its interval must
        contain this one, and its threshold must be no stricter.
        """
        if self == parent or isinstance(parent, Globally):
            return False

        start, end = self._interval
        dim, op, threshold = self._phi

        # From here on the parent is treated as an Eventually.
        p_start, p_end = parent._interval
        p_dim, p_op, p_threshold = parent._phi

        if op != p_op or dim != p_dim:
            return False
        # This interval must not stick out of the parent's interval.
        if start < p_start or end > p_end:
            return False

        if op == '<':
            return bool(threshold <= p_threshold)
        if op == '>':
            return bool(threshold >= p_threshold)
        return False
@dataclass
class Globally(Primitive):
    """Bounded "globally" primitive over a single signal dimension.

    Ex: Globally((0, 5), (0, '>', 20), 1) <=> G[0,5](s1>20)
    """
    _interval: Tuple[int, int]      # (a, b): inclusive time bounds, a != b
    _phi: Tuple[int, str, float]    # (d, comp, mu): s_d > mu or s_d < mu
    _normalize: float = 1.          # robustness divisor (max - min of mu)

    def __post_init__(self):
        """Validate the parameters; raise ValueError on the first problem.

        A single-instant interval (a == b) is rejected. Accepted intervals
        satisfy a < b and are either entirely non-negative or entirely
        strictly negative.
        """
        try:
            a, b = self._interval
            d, comp, mu = self._phi
        except ValueError:
            raise ValueError('Invalid primitive parameters')

        # Checked before the label is built, as a dedicated message exists.
        if a == b:
            raise ValueError('In case a = b, use Eventually (F).')

        repr_ = f'G[{a},{b}](s{d+1}{comp}{mu:.2f})'

        bad_interval = a * b < 0 or a > b or (a < 0 and b == 0)
        if bad_interval:
            raise ValueError(f'Invalid interval: {repr_}')

        if d < 0:
            raise ValueError(f'Invalid dimension: {repr_}')

        if comp not in ('<', '>'):
            raise ValueError(f'Invalid comparison: {repr_}')

    def __repr__(self):
        a, b = self._interval
        d, comp, mu = self._phi
        return f'G[{a},{b}](s{d+1}{comp}{mu:.2f})'

    def __hash__(self):
        # `_normalize` is deliberately not part of the key.
        key = ('G', self._interval, self._phi)
        return hash(key)

    def robust(self, s: np.ndarray) -> float:
        """Compute the robustness degree relative to signal `s`.

        `s` is indexed as s[dimension, time]. For '<' the margin is
        mu - max over the window, otherwise min over the window - mu;
        the margin is divided by `_normalize`. Returns -1 if the index
        window is empty.
        """
        start, end = self._interval
        window = np.arange(start, end + 1)
        if window.size == 0:
            return -1

        dim, op, threshold = self._phi
        values = s[dim, window]
        if op == '<':
            margin = threshold - np.max(values)
        else:
            margin = np.min(values) - threshold
        return margin / self._normalize

    def is_child_of(self, parent: Primitive) -> bool:
        """Tell whether `parent` is a parent of this primitive.

        Never true for an equal primitive. The parent must share dimension
        and comparison and have a threshold that is no stricter. An
        `Eventually` parent only needs an interval overlapping this one;
        any other parent is treated as a `Globally` whose interval must lie
        inside this one.
        """
        if parent == self:
            return False

        start, end = self._interval
        dim, op, threshold = self._phi

        parent_is_eventually = isinstance(parent, Eventually)
        p_start, p_end = parent._interval
        p_dim, p_op, p_threshold = parent._phi

        if op != p_op or dim != p_dim:
            return False

        if parent_is_eventually:
            # Empty intersection of the two intervals.
            incompatible = start > p_end or end < p_start
        else:
            # Parent interval not fully covered by this one.
            incompatible = start > p_start or end < p_end
        if incompatible:
            return False

        if op == '<':
            return bool(threshold <= p_threshold)
        if op == '>':
            return bool(threshold >= p_threshold)
        return False
@dataclass
class PrimitiveGenerator:
    """(static) generator of primitives

    Enumerates `Eventually` and `Globally` primitives over every interval,
    dimension, comparison and threshold of the grid, and keeps those whose
    robustness on the signal reaches `_rho`.
    """
    # (instance attributes)
    _s: np.ndarray       # signal being explained; shape unpacks as (dims, length)
    _srange: list        # list of (min, max, stepsize) for each dimension
    _rho: float          # robustness degree (~coverage) threshold
    _past: bool = False  # true if PtSTL, false if STL

    def generate(self) -> List[Primitive]:
        """Generate STL primitives whose robustness is at least `rho`.

        For each dimension the thresholds are the interior points of
        `np.linspace(min, max, num=stepsize, endpoint=False)` (the first
        point is dropped). A dimension whose grid yields no threshold
        contributes no primitive. With `_past` set, interval bounds are
        negative indices counted from the end of the signal.

        Returns:
            The accepted primitives, ordered by dimension, interval start,
            kind (Eventually before Globally), interval end, comparison
            ('>' before '<') and threshold index.
        """
        result = []
        sdim, slen = self._s.shape
        starts = range(-slen, 0) if self._past else range(slen)
        for d in range(sdim):
            smin, smax, stepsize = self._srange[d]
            mus = np.linspace(smin, smax, num=stepsize, endpoint=False)[1:]
            if not len(mus):
                # No threshold to try; indexing mus[0] would raise IndexError.
                continue
            n = smax - smin
            for a in starts:
                # Class objects are used directly instead of eval() on names.
                for cls in (Eventually, Globally):
                    # Globally rejects a == b, so its interval ends start at a + 1.
                    first_b = a + int(cls is Globally)
                    ends = (range(first_b, 0) if self._past
                            else range(first_b, slen))
                    for b, comp in itertools.product(ends, ('>', '<')):
                        kept = self._feasible_indices(
                            cls, (a, b), d, comp, mus, n)
                        result.extend(cls((a, b), (d, comp, mus[q]), n)
                                      for q in kept)
        return result

    def _feasible_indices(self, cls, interval, d, comp, mus, n) -> range:
        "Indices `q` of `mus` whose primitive reaches `rho`, found by bisection"
        def meets(q):
            phi = cls(interval, (d, comp, mus[q]), n)
            return phi.robust(self._s) >= self._rho

        last = len(mus) - 1
        if meets(0):
            if meets(last):
                return range(last + 1)
            # Accepted prefix: `ok` always meets rho, `bad` never does.
            ok, bad = 0, last
            while bad - ok > 1:
                mid = (ok + bad) // 2
                if meets(mid):
                    ok = mid
                else:
                    bad = mid
            return range(ok + 1)
        if not meets(last):
            return range(0)
        # Accepted suffix: `bad` never meets rho, `ok` always does.
        bad, ok = 0, last
        while ok - bad > 1:
            mid = (ok + bad) // 2
            if meets(mid):
                ok = mid
            else:
                bad = mid
        return range(ok, last + 1)
"""
Any newly instantiated STL formula remains the same instance.
Since the primitives are fixed from the beginning, an STL formula
consisting of a conjunction of some of these primitives is just
represented as a frozenset of their indices.
"""
class STL(object):
    """Conjunction of primitives, stored as a frozenset of their indices.

    Instances are interned: building a formula whose reduced index set was
    already seen returns the existing object. Reduction removes, for every
    requested index, the indices registered as its parents.
    """
    __cache = {}        # reduced index set -> the instance built for it

    # (class attributes) to be set during init
    __primitives = []   # list of generated primitives
    __parents = {}      # dict {child: parents} among primitives

    def __new__(cls, indices: FrozenSet[int]=frozenset()):
        remaining = indices
        for child in indices.copy():
            remaining -= STL.__parents[child]

        formula = STL.__cache.get(remaining)
        if formula is None:
            formula = object.__new__(cls)
            STL.__cache[remaining] = formula
        return formula

    def __init__(self, indices: FrozenSet[int]=frozenset()):
        # Same reduction as in __new__, so `_indices` equals the cache key.
        pruned = indices
        for child in indices:
            pruned -= STL.__parents[child]
        self._indices = pruned

    def init(self, primitives: List[Primitive]) -> int:
        """Register the primitives and their parent relation.

        Returns the number of primitives.
        """
        STL.__primitives = primitives
        nb = len(primitives)
        relation = {}
        for child in range(nb):
            found = set()
            for parent in range(nb):
                if primitives[child].is_child_of(primitives[parent]):
                    found.add(parent)
            relation[child] = found
        STL.__parents = relation
        return nb

    def satisfied(self, s: np.ndarray) -> bool:
        """Verify if STL is satisfied by signal `s` (never for `None`)."""
        if s is None:
            return False
        for i in self._indices:
            if not STL.__primitives[i].satisfied(s):
                return False
        return True

    def get_children(self) -> Set[STL]:
        """Formulas obtained by adding one primitive to this conjunction.

        Primitives registered as parents of a current member are not
        added, and this formula itself is left out of the result.
        """
        excluded = set()
        for i in self._indices:
            excluded.update(STL.__parents[i])

        candidates = set(range(len(STL.__primitives))) - excluded
        children = {STL(self._indices.union([i])) for i in candidates}
        children.discard(self)
        return children

    def __len__(self):
        return len(self._indices)

    def __hash__(self):
        return hash(self._indices)

    def __repr__(self):
        if not self._indices:
            return 'T'
        parts = [repr(STL.__primitives[i]) for i in self._indices]
        return '^'.join(parts)