planner.py
import torch
from torch import jit


# Model-predictive control planner with cross-entropy method and learned transition model
class MPCPlanner(jit.ScriptModule):
    __constants__ = ['action_size', 'planning_horizon', 'optimisation_iters', 'candidates', 'top_candidates']

    def __init__(
        self,
        action_size,
        planning_horizon,
        optimisation_iters,
        candidates,
        top_candidates,
        transition_model,
        reward_model,
    ):
        super().__init__()
        self.transition_model, self.reward_model = transition_model, reward_model
        self.action_size = action_size
        self.planning_horizon = planning_horizon
        self.optimisation_iters = optimisation_iters
        self.candidates, self.top_candidates = candidates, top_candidates

    @jit.script_method
    def forward(self, belief, state):
        # B: batch size, H: belief (deterministic state) size, Z: (stochastic) state size
        B, H, Z = belief.size(0), belief.size(1), state.size(1)
        # Replicate belief and state once per candidate: (B * candidates) x H and (B * candidates) x Z
        belief = belief.unsqueeze(dim=1).expand(B, self.candidates, H).reshape(-1, H)
        state = state.unsqueeze(dim=1).expand(B, self.candidates, Z).reshape(-1, Z)
        # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I)
        action_mean = torch.zeros(self.planning_horizon, B, 1, self.action_size, device=belief.device)
        action_std_dev = torch.ones(self.planning_horizon, B, 1, self.action_size, device=belief.device)
        for _ in range(self.optimisation_iters):
            # Evaluate J action sequences from the current belief (over entire sequence at once, batched over particles)
            actions = (
                action_mean
                + action_std_dev
                * torch.randn(self.planning_horizon, B, self.candidates, self.action_size, device=action_mean.device)
            ).view(
                self.planning_horizon, B * self.candidates, self.action_size
            )  # Sample actions (time x (batch x candidates) x actions)
            # Sample next states
            beliefs, states, _, _ = self.transition_model(
                state, actions, belief
            )  # e.g. [12, 1000, 200] and [12, 1000, 30]: 12 horizon steps, 1000 candidates
            # Calculate expected returns (technically sum of rewards over planning horizon)
            returns = (
                self.reward_model(beliefs.view(-1, H), states.view(-1, Z)).view(self.planning_horizon, -1).sum(dim=0)
            )  # reward model output [12000] -> view [12, 1000] -> sum [1000]
            # Re-fit belief to the K best action sequences
            _, topk = returns.reshape(B, self.candidates).topk(self.top_candidates, dim=1, largest=True, sorted=False)
            topk += self.candidates * torch.arange(0, B, dtype=torch.int64, device=topk.device).unsqueeze(
                dim=1
            )  # Fix indices for unrolled actions
            best_actions = actions[:, topk.view(-1)].reshape(
                self.planning_horizon, B, self.top_candidates, self.action_size
            )
            # Update belief with new means and standard deviations
            action_mean, action_std_dev = best_actions.mean(dim=2, keepdim=True), best_actions.std(
                dim=2, unbiased=False, keepdim=True
            )
        # Return first action mean µ_t
        return action_mean[0].squeeze(dim=1)
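

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# the interface the planner expects. _StubTransitionModel and _StubRewardModel
# are hypothetical stand-ins for the repository's learned models, and the
# sizes below (belief 200, state 30, horizon 12, 1000 candidates) simply echo
# the shape comments in forward(); swap in the real models and sizes in use.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch import nn

    class _StubTransitionModel(jit.ScriptModule):
        # Rolls (belief, state) forward once per planned action; purely illustrative.
        def __init__(self, belief_size: int, state_size: int, action_size: int):
            super().__init__()
            self.fc_belief = nn.Linear(belief_size + action_size, belief_size)
            self.fc_state = nn.Linear(state_size + action_size, state_size)

        @jit.script_method
        def forward(self, state, actions, belief):
            beliefs = []
            states = []
            for t in range(actions.size(0)):
                belief = torch.tanh(self.fc_belief(torch.cat([belief, actions[t]], dim=1)))
                state = torch.tanh(self.fc_state(torch.cat([state, actions[t]], dim=1)))
                beliefs.append(belief)
                states.append(state)
            # Match the 4-tuple the planner unpacks: (beliefs, states, _, _)
            return torch.stack(beliefs), torch.stack(states), belief, state

    class _StubRewardModel(jit.ScriptModule):
        # Maps each (belief, state) row to a scalar reward.
        def __init__(self, belief_size: int, state_size: int):
            super().__init__()
            self.fc = nn.Linear(belief_size + state_size, 1)

        @jit.script_method
        def forward(self, belief, state):
            return self.fc(torch.cat([belief, state], dim=1)).squeeze(dim=1)

    belief_size, state_size, action_size = 200, 30, 6
    planner = MPCPlanner(
        action_size=action_size,
        planning_horizon=12,
        optimisation_iters=10,
        candidates=1000,
        top_candidates=100,
        transition_model=_StubTransitionModel(belief_size, state_size, action_size),
        reward_model=_StubRewardModel(belief_size, state_size),
    )
    action = planner(torch.zeros(1, belief_size), torch.zeros(1, state_size))
    print(action.shape)  # expected: torch.Size([1, action_size])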