-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReplayMemory.py
31 lines (22 loc) · 1008 Bytes
/
ReplayMemory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import random
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions for experience replay.

    Each stored transition is the tuple
    ``(state, action, next_state, reward, done)``. Once ``mem_size``
    transitions have been stored, each new one overwrites the oldest.
    """

    def __init__(self, mem_size: int) -> None:
        """Create an empty buffer holding at most ``mem_size`` transitions.

        Args:
            mem_size: Maximum number of transitions kept; must be >= 1.

        Raises:
            ValueError: If ``mem_size`` is less than 1. Such a buffer could
                never store anything (``store`` would fail on first use).
        """
        if mem_size < 1:
            raise ValueError(f"mem_size must be >= 1, got {mem_size!r}")
        self.replay_buffer = []
        self.mem_size = mem_size
        # Index of the slot the next transition is written to.
        self.mem_pointer = 0

    def store(self, state, action, next_state, reward, done) -> None:
        """Record one transition, overwriting the oldest one when full."""
        transition = (state, action, next_state, reward, done)
        if len(self.replay_buffer) < self.mem_size:
            # Still filling up: mem_pointer == len(replay_buffer) here, so
            # appending writes exactly the slot mem_pointer refers to.
            self.replay_buffer.append(transition)
        else:
            self.replay_buffer[self.mem_pointer] = transition
        self.mem_pointer = (self.mem_pointer + 1) % self.mem_size

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw (without replacement).

        Returns:
            A tuple ``(states, actions, next_states, rewards, dones)`` of five
            lists, each of length ``batch_size``. Position ``i`` in every list
            belongs to the same sampled transition.

        Raises:
            ValueError: If ``batch_size`` is negative or larger than the number
                of stored transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.replay_buffer, batch_size)
        if not batch:
            # zip(*[]) yields nothing, so it cannot be unpacked into 5 names.
            return [], [], [], [], []
        states, actions, next_states, rewards, dones = (
            list(column) for column in zip(*batch)
        )
        return states, actions, next_states, rewards, dones