Skip to content

Commit

Permalink
reformat
Browse files (browse the repository at this point in the history)
  • Loading branch information
gggxxl committed Nov 8, 2020
1 parent f461988 commit ec409d4
Show file tree
Hide file tree
Showing 7 changed files with 4 additions and 32 deletions.
23 changes: 0 additions & 23 deletions .vscode/settings.json

This file was deleted.

1 change: 0 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ def main():

if __name__ == '__main__':
main()

5 changes: 1 addition & 4 deletions utils_drl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from utils_memory import Experience
from utils_model import DQN

import pdb


class Agent(object):

Expand Down Expand Up @@ -84,12 +82,11 @@ def learn(self, memory: Experience, step: int) -> float:
values = self.__policy(state_batch.float()).gather(1, action_batch)
# values_next = self.__target(next_batch.float()).max(1).values.detach() # 这里还是nature dqn 没有用ddqn 虽都是双网络
values_next = self.__target(next_batch.float()).gather(
1, self.__policy(next_batch.float()).max(1).indices.unsqueeze(1)).detach() # 改成dqn
1, self.__policy(next_batch.float()).max(1).indices.unsqueeze(1)).detach() # 改成dqn
expected = (self.__gamma * values_next) * \
(1. - done_batch) + reward_batch # 如果done则是r(考虑t时刻done,没有t+1时刻),否则是r + gamma * max Q

td_error = (expected - values).detach()
pdb.set_trace()
memory.update_priority(rank_e_id, td_error.cpu().numpy())

values = values.mul(w)
Expand Down
2 changes: 1 addition & 1 deletion utils_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def reset(
init_reward = 0.
observations = []
frames = []
for _ in range(5): # no-op
for _ in range(5): # no-op
obs, reward, done = self.step(0) # 运行一步
observations.append(obs)
init_reward += reward
Expand Down
1 change: 0 additions & 1 deletion utils_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,3 @@ def sample(self, global_step):

def __len__(self):
return self.record_size

2 changes: 1 addition & 1 deletion utils_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, action_dim, device):
self.__conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4, bias=False)
self.__conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False)
self.__conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False)
self.__fc1 = nn.Linear(64*7*7, 512)
self.__fc1 = nn.Linear(64 * 7 * 7, 512)
self.__fc2 = nn.Linear(512, action_dim)
self.__device = device

Expand Down
2 changes: 1 addition & 1 deletion utils_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _insert(self, priority, e_id):
:return: bool
"""
self.size += 1

if self.check_full() and not self.replace:
sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
return False
Expand Down

0 comments on commit ec409d4

Please sign in to comment.