a2cagent.py
#######################################################################
# Copyright (C) 2017 Shangtong Zhang([email protected]) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
# Advantage Actor-Critic (A2C): synchronous rollouts over parallel workers,
# trained with an n-step return (or GAE) advantage estimate.
# The wildcard imports below are relied on for helpers such as np, nn,
# tensor, to_np, Storage, and BaseAgent.
from ..network import *
from ..component import *
from .BaseAgent import *


class A2CAgent(BaseAgent):
    def __init__(self, config):
        BaseAgent.__init__(self, config)
        self.config = config
        self.task = config.task_fn()
        self.network = config.network_fn()
        self.optimizer = config.optimizer_fn(self.network.parameters())
        self.total_steps = 0
        self.states = self.task.reset()
        self.episode_rewards = []
        # Running (undiscounted) return of the current episode, per worker.
        self.online_rewards = np.zeros(config.num_workers)

    def step(self):
        config = self.config
        storage = Storage(config.rollout_length)
        states = self.states
        # Roll out the current policy for rollout_length steps across all workers.
        for _ in range(config.rollout_length):
            prediction = self.network(config.state_normalizer(states))
            next_states, rewards, terminals, _ = self.task.step(to_np(prediction['a']))
            self.online_rewards += rewards
            rewards = config.reward_normalizer(rewards)
            for i, terminal in enumerate(terminals):
                if terminal:
                    self.episode_rewards.append(self.online_rewards[i])
                    self.online_rewards[i] = 0
            storage.add(prediction)
            # 'm' (1 - terminal) masks out bootstrapping past episode ends.
            storage.add({'r': tensor(rewards).unsqueeze(-1),
                         'm': tensor(1 - terminals).unsqueeze(-1)})
            states = next_states

        self.states = states
        # One extra forward pass to bootstrap from the value of the final state.
        prediction = self.network(config.state_normalizer(states))
        storage.add(prediction)
        storage.placeholder()

        # Compute returns and advantages backwards through the rollout.
        advantages = tensor(np.zeros((config.num_workers, 1)))
        returns = prediction['v'].detach()
        for i in reversed(range(config.rollout_length)):
            returns = storage.r[i] + config.discount * storage.m[i] * returns
            if not config.use_gae:
                # Plain n-step advantage: bootstrapped return minus the value baseline.
                advantages = returns - storage.v[i].detach()
            else:
                # GAE: discounted sum of TD errors,
                # A_t = delta_t + (discount * gae_tau) * m_t * A_{t+1}.
                td_error = storage.r[i] + config.discount * storage.m[i] * storage.v[i + 1] - storage.v[i]
                advantages = advantages * config.gae_tau * config.discount * storage.m[i] + td_error
            storage.adv[i] = advantages.detach()
            storage.ret[i] = returns.detach()

        log_prob, value, returns, advantages, entropy = storage.cat(['log_pi_a', 'v', 'ret', 'adv', 'ent'])
        policy_loss = -(log_prob * advantages).mean()
        value_loss = 0.5 * (returns - value).pow(2).mean()
        entropy_loss = entropy.mean()

        self.optimizer.zero_grad()
        # Combined actor-critic loss with an entropy bonus to encourage exploration.
        (policy_loss - config.entropy_weight * entropy_loss +
         config.value_loss_weight * value_loss).backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip)
        self.optimizer.step()

        steps = config.rollout_length * config.num_workers
        self.total_steps += steps
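
# A minimal driver sketch, assuming a config object that exposes exactly the
# attributes read above (task_fn, network_fn, optimizer_fn, num_workers,
# rollout_length, discount, use_gae, gae_tau, entropy_weight,
# value_loss_weight, gradient_clip, state_normalizer, reward_normalizer).
# The Config container and the make_task / make_actor_critic_net factories are
# hypothetical placeholders; the network must return a dict with the keys
# 'a', 'log_pi_a', 'v', and 'ent' that step() consumes via storage.cat.
#
#     config = Config()
#     config.task_fn = lambda: make_task()                 # vectorized env factory
#     config.network_fn = lambda: make_actor_critic_net()  # actor-critic model factory
#     config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=7e-4)
#     config.num_workers = 16
#     config.rollout_length = 5
#     config.discount = 0.99
#     config.use_gae = False
#     # ... plus the remaining hyper-parameters listed above
#     agent = A2CAgent(config)
#     while agent.total_steps < int(1e7):
#         agent.step()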