
Commit

Merge remote-tracking branch 'origin/main' into main
gggxxl committed Nov 6, 2020
2 parents 0b19009 + 7282edb commit 7241445
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 2 additions & 1 deletion main.py
@@ -24,7 +24,8 @@
BATCH_SIZE = 32
POLICY_UPDATE = 4
TARGET_UPDATE = 10_000
- WARM_STEPS = 50_000
+ # WARM_STEPS = 50_000
+ WARM_STEPS = 1
MAX_STEPS = 50_000_000
EVALUATE_FREQ = 100_000

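For context on the WARM_STEPS change above: a minimal sketch of how a warm-up constant typically gates learning in a DQN training loop. The loop in main.py is not part of this diff, so agent, memory, env and their methods below are illustrative assumptions rather than the project's API; the point is that WARM_STEPS = 1 effectively disables the replay warm-up, so learning starts almost immediately instead of after 50,000 exploratory steps.

# Hedged sketch only; names other than the constants are assumptions.
WARM_STEPS = 1        # value after this commit (previously 50_000)
POLICY_UPDATE = 4
BATCH_SIZE = 32

def train(agent, memory, env, max_steps=100_000):
    state = env.reset()
    for step in range(max_steps):
        action = agent.run(state)                  # pick an action
        state, reward, done = env.step(action)     # interact with the environment
        memory.push(state, action, reward, done)   # store the transition
        # Learning is skipped until the buffer holds at least WARM_STEPS
        # transitions, so WARM_STEPS = 1 means updates begin right away.
        if step > WARM_STEPS and step % POLICY_UPDATE == 0:
            agent.learn(memory, BATCH_SIZE)
        if done:
            state = env.reset()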
8 changes: 3 additions & 5 deletions utils_drl.py
@@ -82,15 +82,13 @@ def learn(self, memory: ReplayMemory, batch_size: int) -> float:
values = self.__policy(state_batch.float()).gather(1, action_batch)
values_next = self.__target(next_batch.float()).max(1).values.detach()
expected = (self.__gamma * values_next.unsqueeze(1)) * \
-     (1. - done_batch) + reward_batch
+     (1. - done_batch) + reward_batch  # target is r if done, else r + gamma * max Q
loss = F.smooth_l1_loss(values, expected)  # smooth L1 loss

self.__optimizer.zero_grad()  # reset the model's parameter gradients to zero
- print(loss)
- loss.backward()  # compute gradients
- print(loss)
+ loss.backward()  # compute gradients, stored in each __policy parameter's .grad
for param in self.__policy.parameters():
-     param.grad.data.clamp_(-1, 1)  # clamp all parameters to [-1, 1]
+     param.grad.data.clamp_(-1, 1)  # clamp all gradients to [-1, 1]
self.__optimizer.step()  # take one optimization step

return loss.item()
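To make the target computation commented in the hunk above concrete, here is a small self-contained PyTorch sketch with made-up numbers (not the project's data). It shows how the done mask switches the target between r and r + gamma * max Q, and what the in-place clamp to [-1, 1] does to a gradient tensor.

import torch

# Toy batch of 3 transitions; values are illustrative only.
reward_batch = torch.tensor([[1.0], [0.5], [2.0]])
done_batch   = torch.tensor([[0.0], [0.0], [1.0]])   # last transition is terminal
values_next  = torch.tensor([0.8, 0.6, 0.9])         # max_a Q_target(s', a)
gamma = 0.99

# Same formula as in learn(): r + gamma * max Q for non-terminal states, r otherwise.
expected = (gamma * values_next.unsqueeze(1)) * (1. - done_batch) + reward_batch
print(expected)   # roughly [[1.792], [1.094], [2.000]] -- the terminal row keeps only r

# clamp_(-1, 1) caps each element to [-1, 1], as done to the gradients in learn().
grad = torch.tensor([-3.0, 0.2, 5.0])
print(grad.clamp_(-1, 1))   # tensor([-1.0000, 0.2000, 1.0000])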
