-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelper.py
More file actions
289 lines (239 loc) · 10.6 KB
/
helper.py
File metadata and controls
289 lines (239 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
def __REDUCE__(b):
    """Translate a boolean ``reduce`` flag into a torch ``reduction`` keyword.

    A truthy flag selects ``'mean'`` (scalar loss); a falsy flag selects
    ``'none'`` so the element-wise loss is returned unreduced.
    """
    if b:
        return 'mean'
    return 'none'
def l1(pred, target, reduce=False):
    """Return the L1 (absolute-error) loss between ``pred`` and ``target``.

    Args:
        pred: predicted tensor.
        target: target tensor, broadcast-compatible with ``pred``.
        reduce: when true, average into a scalar; otherwise return the
            element-wise loss.
    """
    reduction = 'mean' if reduce else 'none'
    return F.l1_loss(pred, target, reduction=reduction)
def mse(pred, target, reduce=False):
    """Return the squared-error loss between ``pred`` and ``target``.

    Args:
        pred: predicted tensor.
        target: target tensor, broadcast-compatible with ``pred``.
        reduce: when true, average into a scalar (MSE); otherwise return the
            element-wise squared error.
    """
    reduction = 'mean' if reduce else 'none'
    return F.mse_loss(pred, target, reduction=reduction)
def _get_out_shape(in_shape, layers):
    """Return the output shape of ``layers`` for one un-batched input.

    A random dummy input of shape ``in_shape`` gets a batch dimension of one
    and is pushed through the network. The batch dimension is then dropped
    from the result's shape.

    Args:
        in_shape: input shape without the batch dimension, e.g. ``(C, H, W)``.
        layers: either a list of modules (wrapped in ``nn.Sequential``) or a
            single module.

    Returns:
        ``torch.Size`` of the output, without the batch dimension.
    """
    net = nn.Sequential(*layers) if isinstance(layers, list) else layers
    x = torch.randn(*in_shape).unsqueeze(0)
    # The pass only probes shapes; do not build an autograd graph for it.
    with torch.no_grad():
        return net(x).squeeze(0).shape
def orthogonal_init(m):
    """Orthogonally initialise ``nn.Linear`` / ``nn.Conv2d`` layers in place.

    - Linear weights use unit gain.
    - Conv2d weights use the ReLU gain.
    - Biases, when present, are zeroed.
    - Any other module type is left untouched, so the function can be passed
      to ``module.apply``.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.orthogonal_(m.weight.data, nn.init.calculate_gain('relu'))
    elif isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    else:
        return
    if m.bias is not None:
        nn.init.zeros_(m.bias)
def ema(m, m_target, tau):
    """Blend the parameters of ``m`` into ``m_target`` in place.

    Each target parameter becomes ``target + tau * (online - target)``.
    Parameters are paired positionally, in ``.parameters()`` order.
    """
    with torch.no_grad():
        pairs = zip(m.parameters(), m_target.parameters())
        for online, target in pairs:
            target.data.lerp_(online.data, tau)
def set_requires_grad(net, value):
    """Switch ``requires_grad`` to ``value`` on every parameter of ``net``.

    Parameters are modified in place; nothing is returned.
    """
    params = net.parameters()
    for tensor in params:
        tensor.requires_grad_(value)
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped into ``[low, high]``.

    ``sample`` draws ``loc + scale * noise``, optionally clipping the scaled
    noise, and then clamps the result to ``[low + eps, high - eps]``. The
    clamp is straight-through: the forward value is clamped, but the gradient
    with respect to the unclamped sample is the identity. Only ``sample`` is
    overridden; other methods are inherited from ``Normal`` unchanged.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low, self.high, self.eps = low, high, eps

    def _clamp(self, x):
        """Clamp ``x`` into ``[low + eps, high - eps]`` with identity gradient."""
        lower = self.low + self.eps
        upper = self.high - self.eps
        clamped = torch.clamp(x, lower, upper)
        # Forward value is ``clamped``; backward treats the op as identity.
        return x - x.detach() + clamped.detach()

    def sample(self, clip=None, sample_shape=torch.Size()):
        """Draw a sample, optionally clip its noise, then clamp it.

        Args:
            clip: if not ``None``, the scaled noise is clamped to
                ``[-clip, clip]`` before ``loc`` is added.
            sample_shape: leading sample dimensions, as in ``Normal.sample``.
        """
        shape = self._extended_shape(sample_shape)
        noise = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        noise = noise * self.scale
        if clip is not None:
            noise = torch.clamp(noise, -clip, clip)
        return self._clamp(self.loc + noise)
class NormalizeImg(nn.Module):
    """Divide pixel values by 255, mapping 0..255 onto the closed range [0, 1]."""

    def forward(self, x):
        return x / 255.0
class Flatten(nn.Module):
    """Collapse every non-batch dimension: ``(N, d1, d2, ...) -> (N, d1*d2*...)``.

    Uses ``view``, so the input must be viewable with that shape (e.g.
    contiguous).
    """

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
def enc(cfg):
    """Build the TOLD encoder as an ``nn.Sequential``.

    When ``cfg.modality == 'pixels'``:
        - pixel normalisation;
        - four stride-2 ``Conv2d`` + ReLU layers, with kernel sizes 7, 5, 3, 3
          and ``cfg.num_channels`` filters;
        - a flatten, then a linear projection to ``cfg.latent_dim``.
        The input has ``3 * cfg.frame_stack`` channels. The flattened width
        is found by a dummy forward pass at ``cfg.img_size``.

    Otherwise, a two-layer MLP ``cfg.obs_shape[0] -> cfg.enc_dim ->
    cfg.latent_dim`` with an ELU in between.
    """
    if cfg.modality != 'pixels':
        return nn.Sequential(
            nn.Linear(cfg.obs_shape[0], cfg.enc_dim), nn.ELU(),
            nn.Linear(cfg.enc_dim, cfg.latent_dim))

    in_ch = int(3 * cfg.frame_stack)
    width = cfg.num_channels
    layers = [NormalizeImg()]
    for c_in, kernel in ((in_ch, 7), (width, 5), (width, 3), (width, 3)):
        layers += [nn.Conv2d(c_in, width, kernel, stride=2), nn.ReLU()]
    out_shape = _get_out_shape((in_ch, cfg.img_size, cfg.img_size), layers)
    layers += [Flatten(), nn.Linear(np.prod(out_shape), cfg.latent_dim)]
    return nn.Sequential(*layers)
def mlp(in_dim, mlp_dim, out_dim, act_fn=None):
    """Build an MLP ``in_dim -> hidden... -> out_dim`` as an ``nn.Sequential``.

    Args:
        in_dim: input feature size.
        mlp_dim: hidden width. An ``int`` gives two hidden layers of that
            width. A sequence gives one hidden layer per entry.
        out_dim: output feature size.
        act_fn: activation module placed after every hidden layer. Defaults
            to a fresh ``nn.ELU()`` per call. The same instance is reused
            between the layers of one network, so an activation with
            parameters would be shared across them.

    Returns:
        ``nn.Sequential`` that alternates ``Linear`` and ``act_fn``, ending in
        a ``Linear`` with no activation. For the default two hidden layers,
        the module indices match the previous implementation.
    """
    if act_fn is None:
        # Create the default per call instead of one import-time instance
        # shared by every MLP built with the default argument.
        act_fn = nn.ELU()
    if isinstance(mlp_dim, int):
        mlp_dim = [mlp_dim, mlp_dim]
    dims = [in_dim, *mlp_dim]
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), act_fn]
    layers.append(nn.Linear(dims[-1], out_dim))
    return nn.Sequential(*layers)
def q(cfg, act_fn=None):
    """Build a Q-function network that uses Layer Normalization.

    The input has ``cfg.latent_dim + cfg.action_dim`` features; the output
    has one. The layer order is:
    Linear -> LayerNorm -> Tanh -> Linear -> act_fn -> Linear.

    Args:
        cfg: config providing ``latent_dim``, ``action_dim`` and ``mlp_dim``.
        act_fn: activation after the second linear layer. Previously this
            argument was ignored and ``nn.ELU()`` was always used. It now
            defaults to a fresh ``nn.ELU()`` per call, which keeps the
            default network unchanged.
    """
    if act_fn is None:
        act_fn = nn.ELU()
    hidden = cfg.mlp_dim
    return nn.Sequential(
        nn.Linear(cfg.latent_dim + cfg.action_dim, hidden), nn.LayerNorm(hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), act_fn,
        nn.Linear(hidden, 1))
class RandomShiftsAug(nn.Module):
    """
    Random shift image augmentation.
    Adapted from https://github.com/facebookresearch/drqv2

    Each image in the batch is:
    1. edge-replicate padded by ``pad`` pixels on every side;
    2. translated by a random integer offset in ``[0, 2*pad]`` along each axis;
    3. resampled back to the original size via ``grid_sample``.

    The padding is ``int(cfg.img_size / 21)`` for pixel inputs. For other
    modalities (or a padding of 0) the input is returned unchanged.
    """

    def __init__(self, cfg):
        super().__init__()
        self.pad = int(cfg.img_size/21) if cfg.modality == 'pixels' else None

    def forward(self, x):
        pad = self.pad
        if not pad:
            return x
        n, _, h, w = x.size()
        assert h == w
        x = F.pad(x, (pad, pad, pad, pad), 'replicate')
        padded = h + 2 * pad
        # Normalised [-1, 1] sampling coordinates of the first h padded pixels.
        half = 1.0 / padded
        coords = torch.linspace(-1.0 + half, 1.0 - half, padded, device=x.device, dtype=x.dtype)[:h]
        coords = coords.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([coords, coords.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
        # Per-sample integer pixel shift, converted to normalised units.
        shift = torch.randint(0, 2 * pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype)
        shift *= 2.0 / padded
        return F.grid_sample(x, base_grid + shift, padding_mode='zeros', align_corners=False)
class Episode:
    """Preallocated storage for the transitions of a single episode.

    Tensors live on CUDA when available, otherwise on CPU. Space is reserved
    for ``cfg.episode_length`` actions and rewards, and one more observation
    (the initial one). Observations are ``float32`` when
    ``cfg.modality == 'state'`` and ``uint8`` otherwise.
    """

    def __init__(self, cfg, init_obs):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        obs_dtype = torch.float32 if cfg.modality == 'state' else torch.uint8
        steps = cfg.episode_length
        self.obs = torch.empty((steps + 1, *init_obs.shape), dtype=obs_dtype, device=self.device)
        self.obs[0] = torch.tensor(init_obs, dtype=obs_dtype, device=self.device)
        self.action = torch.empty((steps, cfg.action_dim), dtype=torch.float32, device=self.device)
        self.reward = torch.empty((steps,), dtype=torch.float32, device=self.device)
        self.cumulative_reward = 0
        self.done = False
        self._idx = 0

    def __len__(self):
        """Number of transitions recorded so far."""
        return self._idx

    @property
    def first(self):
        """``True`` while no transition has been recorded."""
        return not len(self)

    def __add__(self, transition):
        """Record ``transition`` = ``(obs, action, reward, done)``; return ``self``."""
        self.add(*transition)
        return self

    def add(self, obs, action, reward, done):
        """Record step ``i``: action/reward at slot ``i``, ``obs`` at slot ``i + 1``."""
        i = self._idx
        self.obs[i + 1] = torch.tensor(obs, dtype=self.obs.dtype, device=self.obs.device)
        self.action[i] = action
        self.reward[i] = reward
        self.cumulative_reward += reward
        self.done = done
        self._idx = i + 1
class ReplayBuffer():
    """
    Storage and sampling functionality for training TD-MPC / TOLD.

    Sampling uses prioritized experience replay (PER). All storage tensors
    are allocated on CUDA when it is available (otherwise on CPU), for both
    the state and the pixel modality.

    Whole episodes of exactly ``cfg.episode_length`` steps are written
    back-to-back into a circular buffer of ``capacity`` transitions. For
    pixels, only the newest 3 channels of each observation are stored; frame
    stacks are rebuilt on demand in ``_get_obs``.
    """
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Capacity in transitions. add() writes episode_length-sized slices,
        # so this presumably must be a multiple of cfg.episode_length — TODO confirm.
        self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
        dtype = torch.float32 if cfg.modality == 'state' else torch.uint8
        # Pixels: keep a single 3-channel frame per step; stacks are rebuilt when sampling.
        obs_shape = cfg.obs_shape if cfg.modality == 'state' else (3, *cfg.obs_shape[-2:])
        # One spare row so next-observation reads at index idx+t+1 in sample() stay in bounds.
        self._obs = torch.empty((self.capacity+1, *obs_shape), dtype=dtype, device=self.device)
        # Final observation of each stored episode, kept at the full cfg.obs_shape.
        self._last_obs = torch.empty((self.capacity//cfg.episode_length, *cfg.obs_shape), dtype=dtype, device=self.device)
        self._action = torch.empty((self.capacity, cfg.action_dim), dtype=torch.float32, device=self.device)
        self._reward = torch.empty((self.capacity,), dtype=torch.float32, device=self.device)
        self._priorities = torch.ones((self.capacity,), dtype=torch.float32, device=self.device)
        # Added in update_priorities(); keeps non-negative priorities strictly positive.
        self._eps = 1e-6
        self._full = False
        # Next write position (always a multiple of cfg.episode_length).
        self.idx = 0
    def __add__(self, episode: Episode):
        """Support ``buffer += episode``: store the episode and return ``self``."""
        self.add(episode)
        return self
    def add(self, episode: Episode):
        """Write one complete episode at ``self.idx``, then advance ``idx`` cyclically.

        Stores observations, actions, rewards and the episode's final
        observation, and assigns initial priorities to the new steps.
        """
        self._obs[self.idx:self.idx+self.cfg.episode_length] = episode.obs[:-1] if self.cfg.modality == 'state' else episode.obs[:-1, -3:]
        self._last_obs[self.idx//self.cfg.episode_length] = episode.obs[-1]
        self._action[self.idx:self.idx+self.cfg.episode_length] = episode.action
        self._reward[self.idx:self.idx+self.cfg.episode_length] = episode.reward
        # New steps get the largest priority currently stored (1.0 for an empty buffer).
        if self._full:
            max_priority = self._priorities.max().to(self.device).item()
        else:
            max_priority = 1. if self.idx == 0 else self._priorities[:self.idx].max().to(self.device).item()
        # The last `horizon` steps of the episode get priority 0 and are never
        # drawn as start indices. This keeps the horizon+1 steps read by
        # sample() within a single episode.
        mask = torch.arange(self.cfg.episode_length) >= self.cfg.episode_length-self.cfg.horizon
        new_priorities = torch.full((self.cfg.episode_length,), max_priority, device=self.device)
        new_priorities[mask] = 0
        self._priorities[self.idx:self.idx+self.cfg.episode_length] = new_priorities
        self.idx = (self.idx + self.cfg.episode_length) % self.capacity
        self._full = self._full or self.idx == 0
    def update_priorities(self, idxs, priorities):
        """Overwrite the priorities at ``idxs`` with ``priorities + eps``.

        ``priorities`` is squeezed on dim 1, so it presumably has shape
        ``(batch, 1)`` — confirm with the caller.
        """
        self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
    def _get_obs(self, arr, idxs):
        """Gather observations at ``idxs`` from ``arr``.

        State modality: plain indexing.

        Pixel modality: builds a ``3*frame_stack``-channel stack whose newest
        frame (the last 3 channels) is ``arr[idxs]``. Older frames step back
        one index at a time, but an index that reaches an episode's first
        step (``idx % episode_length == 0``) stays there. That first frame is
        then repeated rather than reading from the previous episode. The
        pixel result is converted to float.
        """
        if self.cfg.modality == 'state':
            return arr[idxs]
        obs = torch.empty((self.cfg.batch_size, 3*self.cfg.frame_stack, *arr.shape[-2:]), dtype=arr.dtype, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        obs[:, -3:] = arr[idxs].to(self.device)
        _idxs = idxs.clone()
        mask = torch.ones_like(_idxs, dtype=torch.bool)
        for i in range(1, self.cfg.frame_stack):
            # Freeze indices that have reached an episode start.
            mask[_idxs % self.cfg.episode_length == 0] = False
            _idxs[mask] -= 1
            obs[:, -(i+1)*3:-i*3] = arr[_idxs].to(self.device)
        return obs.float()
    def sample(self):
        """Draw a prioritized batch of sub-trajectories, each ``horizon+1`` steps long.

        Returns:
            obs: observations at the sampled start indices.
            next_obs: ``(horizon+1, batch_size, ...)`` next observations. Where
                the last step ends an episode, the stored final observation
                is used.
            action: ``(horizon+1, batch_size, action_dim)``.
            reward: ``(horizon+1, batch_size, 1)``.
            idxs: the sampled start indices (for ``update_priorities``).
            weights: importance-sampling weights, divided by their maximum.
        """
        probs = (self._priorities if self._full else self._priorities[:self.idx]) ** self.cfg.per_alpha
        probs /= probs.sum()
        total = len(probs)
        # NOTE(review): samples with replacement until the buffer is full, and
        # without replacement afterwards.
        idxs = torch.from_numpy(np.random.choice(total, self.cfg.batch_size, p=probs.cpu().numpy(), replace=not self._full)).to(self.device)
        # Importance-sampling correction (N * P(i)) ** -beta, scaled so the max is 1.
        weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
        weights /= weights.max()
        obs = self._get_obs(self._obs, idxs)
        next_obs_shape = self._last_obs.shape[1:] if self.cfg.modality == 'state' else (3*self.cfg.frame_stack, *self._last_obs.shape[-2:])
        next_obs = torch.empty((self.cfg.horizon+1, self.cfg.batch_size, *next_obs_shape), dtype=obs.dtype, device=obs.device)
        action = torch.empty((self.cfg.horizon+1, self.cfg.batch_size, *self._action.shape[1:]), dtype=torch.float32, device=self.device)
        reward = torch.empty((self.cfg.horizon+1, self.cfg.batch_size), dtype=torch.float32, device=self.device)
        for t in range(self.cfg.horizon+1):
            _idxs = idxs + t
            next_obs[t] = self._get_obs(self._obs, _idxs+1)
            action[t] = self._action[_idxs]
            reward[t] = self._reward[_idxs]
        # `_idxs` now holds the last step's indices. Where that step ends an
        # episode, `_obs[_idxs+1]` belongs to the next episode (or is the
        # never-written spare row), so substitute the stored final observation.
        mask = (_idxs+1) % self.cfg.episode_length == 0
        next_obs[-1, mask] = self._last_obs[_idxs[mask]//self.cfg.episode_length].to(self.device).float()
        if not action.is_cuda:
            action, reward, idxs, weights = \
                action.to(self.device), reward.to(self.device), idxs.to(self.device), weights.to(self.device)
        return obs, next_obs, action, reward.unsqueeze(2), idxs, weights
def linear_schedule(schdl, step):
    """
    Outputs values following a linear decay schedule.
    Adapted from https://github.com/facebookresearch/drqv2

    Args:
        schdl: either a constant (anything ``float()`` accepts, e.g. ``0.1``
            or ``'0.1'``), or a string ``'linear(init,final,duration)'``.
        step: the current step.

    Returns:
        For a constant, the constant. Otherwise ``init`` interpolated
        linearly towards ``final`` over ``duration`` steps, then held at
        ``final``. A zero ``duration`` returns ``final`` immediately instead
        of dividing by zero.

    Raises:
        NotImplementedError: if ``schdl`` is a string in neither form.
    """
    try:
        return float(schdl)
    except ValueError:
        pass
    match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
    if match:
        init, final, duration = [float(g) for g in match.groups()]
        if duration == 0:
            # A zero-length schedule has already finished; avoid ZeroDivisionError.
            return final
        mix = np.clip(step / duration, 0.0, 1.0)
        return (1.0 - mix) * init + mix * final
    raise NotImplementedError(schdl)