-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmcts.py
246 lines (211 loc) · 7.99 KB
/
mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
@file mcts.py
@brief tree object implementing MCTS steps
@author Tzu-Yi Chiu <[email protected]>
Inspired by:
A minimal implementation of Monte Carlo tree search (MCTS) in Python 3
Luke Harold Miles, July 2019, Public Domain Dedication
https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1
"""
import math
import os
from collections import defaultdict
from typing import List, Tuple, Union

import numpy as np

from simulator import Simulator
from stl import STL
class MCTS:
    """Monte Carlo tree searcher. Roll-out the tree then choose a move.

    The tree is stored as adjacency maps between formula nodes:
    `children[n]` holds the expanded, non-pruned children of `n` and
    `ancestors[n]` holds the direct parents through which `n` was reached.
    Nodes must be hashable and provide `__len__`, `satisfied(sample)` and
    `get_children()`; the simulator must provide `simulate()` returning a
    `(sample, score)` pair.
    """

    def __init__(self, simulator: "Simulator",
                 epsilon: float,
                 tau: float,
                 batch_size: int,
                 max_depth: int,
                 max_iter: int) -> None:
        """
        :param simulator: source of `(sample, score)` pairs
        :param epsilon: `train` stops once the error estimate is <= epsilon
        :param tau: precision threshold used by `choose` to accept a node
        :param batch_size: number of simulations drawn per rollout
        :param max_depth: nodes with `len(node) >= max_depth` are leaves
        :param max_iter: maximum number of rollouts per `train` call
        """
        self.simulator = simulator
        self.epsilon = epsilon
        self.tau = tau
        self.batch_size = batch_size
        self.max_depth = max_depth
        self.max_iter = max_iter

        self.children = defaultdict(set)
        self.ancestors = defaultdict(set)
        self.pruned = set()

        # Monte-Carlo (precision = Q/N)
        self.Q = defaultdict(int)
        self.N = defaultdict(int)

        # Hyperparameter for ucb1tuned
        self.ce = 2

        # Set when an anchor is found (`choose`) or when `train` hits
        # `max_iter`
        self.finished = False

    def _confidence_radius(self, parent: "STL", child: "STL") -> float:
        """
        UCB1-tuned confidence half-width of `child` seen from `parent`.

        The caller must ensure `N[child] > 0`.
        """
        p = self.precision(child)
        tmp = math.sqrt(
            self.ce * math.log(self.N.get(parent, 0)) / self.N[child])
        return tmp * math.sqrt(min(0.25, p * (1 - p) + tmp))

    def choose(self, node: "STL") -> "Union[STL, List[STL]]":
        """
        Choose the best successor of node (choose a move in the game).

        Children are ranked by the UCB1-tuned lower bound; unvisited
        children rank last.

        :param node: node whose children are compared
        :returns: the best child alone when its precision is below `tau`;
                  otherwise `finished` is set and a list starting with the
                  best child followed by ancestors that also reach `tau`
        :raises ValueError: if `node` has no children
        """
        candidates = self.children.get(node)
        if not candidates:
            raise ValueError('cannot choose: node has no children')

        def lower_bound(child):
            if not self.N.get(child, 0):
                return float('-inf')
            return self.precision(child) - self._confidence_radius(node,
                                                                   child)

        best = max(candidates, key=lower_bound)
        if self.precision(best) < self.tau:
            return best
        self.finished = True
        return self._best_anchors(best)

    def _best_anchors(self, node: "STL") -> List["STL"]:
        """
        Collect `node` followed by a chain of ancestors whose precision
        is at least `tau`, stopping at the first node without such a parent.
        """
        anchors = [node]
        while True:
            parent = next((n for n in self.ancestors.get(node, ())
                           if self.precision(n) >= self.tau), None)
            if parent is None:
                return anchors
            anchors.append(parent)
            node = parent

    def train(self, node: "STL") -> Tuple[int, float]:
        """
        Rollout the tree from `node` until error is smaller than `epsilon`

        Stops early and sets `finished` when `max_iter` rollouts were made.

        :param node: STL formula from which rollouts are performed
        :returns: number of rollouts and error
        """
        err = 1.0
        i = 0
        best = None
        while err > self.epsilon:
            if i >= self.max_iter:
                self.finished = True
                break
            i += 1

            # `.get` keeps the status line from inserting `best` (None on
            # the first iteration) into the statistics tables.
            to_print = f'\033[1;93m Iter {i} Err {err:5.2%} Best {best} '
            to_print += f'({self.precision(best):5.2%}'
            to_print += f'={self.Q.get(best, 0)}/{self.N.get(best, 0)})'
            to_print += '\033[1;m'
            print(f'{to_print:<80}', end='\r')

            self._rollout(node)
            if self.children.get(node):
                best = self._select(node)
                if self.N.get(best, 0):
                    err = min(self._confidence_radius(node, best), 1.0)
        return i, err

    def precision(self, node: "STL") -> float:
        "Empirical precision of `node` (0.0 when it has no visits)"
        visits = self.N.get(node, 0)
        if not visits:
            return 0.0
        return self.Q.get(node, 0) / visits

    def clean(self, parent: "STL", child: "STL") -> None:
        """
        Clean up useless memory in the tree.

        Drops the statistics and child sets of `parent` and of every child
        of `parent` that is neither `child` nor one of `child`'s children.
        """
        siblings = (self.children.get(parent, set())
                    - self.children.get(child, set())
                    - {child})
        self.Q.pop(parent, None)
        self.N.pop(parent, None)
        for node in siblings:
            self.Q.pop(node, None)
            self.N.pop(node, None)
            self.children.pop(node, None)
        self.children.pop(parent, None)

    def _rollout(self, node: "STL") -> List["STL"]:
        """
        Make the tree one layer better (train for one iteration).

        :param node: STL formula from which rollouts are performed
        :returns: the selected path
        """
        path = self._select_path(node)

        # Sample in mini-batch mode
        batch = [self.simulator.simulate() for _ in range(self.batch_size)]

        # If no sample of the batch satisfies the leaf then prune the leaf
        # else backpropagate sample and score to relevant ancestors
        leaf = path[-1]
        if any(leaf.satisfied(sample) for sample, _ in batch):
            for sample, score in batch:
                self._backpropagate(path, sample, score)
        else:
            self._prune(leaf)
        return path

    def _select_path(self, node: "STL") -> List["STL"]:
        "Find a path leading to an unexplored descendent of `node`"
        path = []
        while True:
            path.append(node)
            if len(node) >= self.max_depth:
                return path
            if not self.children.get(node):  # not yet expanded
                if not self.N.get(node, 0):  # never explored
                    return path
                self._expand(node)
                if not self.children[node]:
                    # Every child is pruned (or there is none): treat the
                    # node as a leaf instead of selecting from nothing.
                    return path
            node = self._select(node)

    def _expand(self, node: "STL") -> None:
        "Update the `children` dict with the children of `node`"
        self.children[node] = node.get_children() - self.pruned
        for child in self.children[node]:
            self.ancestors[child].add(node)

    def _prune(self, node: "STL") -> None:
        "Prune `node` from the tree"
        self.pruned.add(node)
        self.Q.pop(node, None)
        self.N.pop(node, None)
        for parent in self.ancestors.pop(node, ()):
            kids = self.children.get(parent)
            if kids is not None:
                kids.discard(node)

    def _backpropagate(self, path: List["STL"],
                       sample: np.ndarray,
                       score: int) -> None:
        """
        Credit `sample` to the deepest satisfied node of `path` and to all
        of its recorded ancestors.

        path: root =: phi_0 -> ... -> phi_m := leaf
        A binary search finds i* such that `sample` satisfies phi_l for
        all 0 <= l <= i* and no phi_l with l > i*. This relies on
        satisfaction being monotone along the path.
        Nothing is updated when even phi_0 is not satisfied.

        :param path: path from the rollout root to the leaf
        :param sample: simulated sample
        :param score: value added to `Q` of every credited node
        """
        lo, hi = 0, len(path)
        while lo < hi:
            mid = (lo + hi) // 2
            if path[mid].satisfied(sample):
                lo = mid + 1
            else:
                hi = mid
        if lo == 0:
            return

        # Find all ancestors of the critical node (each visited once)
        critical = path[lo - 1]
        reached = {critical}
        stack = [critical]
        while stack:
            for parent in self.ancestors.get(stack.pop(), ()):
                if parent not in reached:
                    reached.add(parent)
                    stack.append(parent)

        for node in reached:
            self.Q[node] += score
            self.N[node] += 1

    def _select(self, node: "STL") -> "STL":
        """
        Select a child of `node`, balancing exploration & exploitation

        Unvisited children are preferred; otherwise the child with the
        largest UCB1-tuned upper bound is returned.
        """
        def upper_bound(child):
            if not self.N.get(child, 0):
                return float('inf')
            return self.precision(child) + self._confidence_radius(node,
                                                                   child)

        return max(self.children.get(node, ()), key=upper_bound)

    def log(self, stl: "STL") -> str:
        "Format `stl` with its Q/N statistics"
        q, n = self.Q.get(stl, 0), self.N.get(stl, 0)
        if not n:
            return f'{stl} ({q}/{n})'
        return f'{stl} ({q}/{n}={q/n:5.2%})'