#!/usr/bin/python
# -*- coding: utf-8 -*-
# author: Ian
# e-mail: [email protected]
# description: rank-based prioritized experience replay memory
import sys
import math
import random

import numpy as np
import torch

import utils_prior


class Experience(object):

    def __init__(self, conf):
        self.size = conf['size']
        self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
        self.total_steps = conf['steps'] if 'steps' in conf else 100000
        self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size

        self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
        self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
        # partition number N: the total capacity is split into N parts, and one
        # sampling distribution is precomputed per fill level
        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100

        self.index = 0
        self.record_size = 0

        self.__device = conf['device']
        self.__m_states = torch.zeros(
            (self.size + 1, conf['channels'], 84, 84), dtype=torch.uint8)
        self.__m_actions = torch.zeros((self.size + 1, 1), dtype=torch.long)
        self.__m_rewards = torch.zeros((self.size + 1, 1), dtype=torch.int8)
        self.__m_dones = torch.zeros((self.size + 1, 1), dtype=torch.bool)

        self.priority_queue = utils_prior.BinaryHeap(self.priority_size)
        self.distributions = self.build_distributions()

        self.beta_grad = (1 - self.beta_zero) / float(self.total_steps - self.learn_start)

    def build_distributions(self):
        """
        precompute the rank-based sampling distribution for each fill level:
        P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
        :return: distributions, dict
        """
        res = {}
        n_partitions = self.partition_num
        partition_num = 1
        # size of each partition
        partition_size = int(math.floor(self.size / n_partitions))

        for n in range(partition_size, self.size + 1, partition_size):
            if self.learn_start <= n <= self.priority_size:
                distribution = {}
                # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
                pdf = list(
                    map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
                )
                pdf_sum = math.fsum(pdf)
                distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
                # split the ranks into k = batch_size segments, each holding
                # roughly 1 / batch_size of the total probability, then sample
                # uniformly inside each segment; strata_ends stores the start
                # and end position of every segment
                cdf = np.cumsum(distribution['pdf'])
                strata_ends = {1: 0, self.batch_size + 1: n}
                step = 1 / float(self.batch_size)
                index = 1
                for s in range(2, self.batch_size + 1):
                    while cdf[index] < step:
                        index += 1
                    strata_ends[s] = index
                    step += 1 / float(self.batch_size)

                distribution['strata_ends'] = strata_ends
                res[partition_num] = distribution
                partition_num += 1

        return res
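    # Illustrative note (added; not part of the original module): with
    # alpha = 0.7, batch_size = 2 and a partition of n = 8 ranks, P(i) is
    # proportional to 1, 2 ** -0.7, ..., 8 ** -0.7, and the loop above yields
    # strata_ends = {1: 0, 2: 2, 3: 8}. sample() then draws one rank uniformly
    # from ranks 1..2 (about 0.45 of the probability mass) and one from ranks
    # 3..8, which is the stratified sampling scheme used by rank-based
    # prioritized experience replay (Schaul et al., 2015).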

    def fix_index(self):
        """
        get the next insert index (1-based; wraps back to 1 after self.size)
        :return: index, int
        """
        if self.record_size <= self.size:
            self.record_size += 1
        if self.index % self.size == 0:
            self.index = 1
            return self.index
        else:
            self.index += 1
            return self.index

    def store(self, folded_state, action, reward, done):
        """
        store one experience; each stored entry corresponds to a transition
        (s1, a, r, s2, t), with s1 and s2 folded together into folded_state
        :param folded_state: tensor of stacked frames covering s1 and s2
        :return: bool, indicates insert status
        """
        insert_index = self.fix_index()
        if insert_index > 0:
            self.__m_states[insert_index] = folded_state
            self.__m_actions[insert_index, 0] = action
            self.__m_rewards[insert_index, 0] = reward
            self.__m_dones[insert_index, 0] = done

            # insert into the priority queue with the current maximum priority,
            # so every new transition is sampled at least once
            priority = self.priority_queue.get_max_priority()
            self.priority_queue.update(priority, insert_index)
            return True
        else:
            sys.stderr.write('Insert failed\n')
            return False

    def rebalance(self):
        """
        rebalance the priority queue
        :return: None
        """
        self.priority_queue.balance_tree()

    def update_priority(self, indices, delta):
        """
        update priorities for the given experience indices
        :param indices: list of experience ids
        :param delta: list of deltas (e.g. TD errors), same order as indices
        :return: None
        """
        for i in range(0, len(indices)):
            self.priority_queue.update(math.fabs(delta[i]), indices[i])

    def sample(self, global_step):
        """
        sample a mini batch from the experience replay
        :param global_step: current training step
        :return: experience batch, importance-sampling weights w, and
                 rank_e_id (sample ids, used later to update priorities)
        """
        if self.record_size < self.learn_start:
            sys.stderr.write('Record size less than learn start! Sample failed\n')
            return False, False, False

        dist_index = math.floor(self.record_size / self.size * self.partition_num)
        # issue 1 by @camigord
        partition_size = math.floor(self.size / self.partition_num)
        partition_max = dist_index * partition_size
        distribution = self.distributions[dist_index]
        rank_list = []
        # sample one rank from each of the k segments
        for n in range(1, self.batch_size + 1):
            index = random.randint(distribution['strata_ends'][n] + 1,
                                   distribution['strata_ends'][n + 1])
            rank_list.append(index)

        # beta is annealed towards 1 as global_step grows
        beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
        # look up P(i) for the sampled ranks; note that pdf is a 0-based list
        alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
        # w = (N * P(i)) ^ (-beta) / max w
        w = np.power(np.array(alpha_pow) * partition_max, -beta)
        w_max = max(w)
        w = torch.tensor(np.divide(w, w_max)).to(self.__device).float()
        # rank_list holds priority ids; convert them to experience ids
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        # gather the stored transitions by experience id
        b_state = self.__m_states[rank_e_id, :4].to(self.__device).float()
        b_next = self.__m_states[rank_e_id, 1:].to(self.__device).float()
        b_action = self.__m_actions[rank_e_id].to(self.__device)
        b_reward = self.__m_rewards[rank_e_id].to(self.__device).float()
        b_done = self.__m_dones[rank_e_id].to(self.__device).float()
        return b_state, b_action, b_reward, b_next, b_done, w, rank_e_id

    def __len__(self):
        return self.record_size
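

# Hypothetical usage sketch (added for illustration; not part of the original
# module). The config keys mirror the ones read in __init__, but the concrete
# values, the random transitions, and the 5-channel folded state (4 stacked
# frames plus the next frame, matching the [:4] / [1:] slicing in sample())
# are assumptions for demonstration only.
if __name__ == '__main__':
    conf = {
        'size': 10000,        # replay capacity
        'batch_size': 32,
        'learn_start': 100,   # allow sampling once this many transitions exist
        'steps': 100000,
        'alpha': 0.7,
        'beta_zero': 0.5,
        'partition_num': 100,
        'channels': 5,        # folded state: 4 stacked frames + next frame
        'device': 'cpu',
    }
    memory = Experience(conf)

    # fill the buffer with random transitions
    for _ in range(500):
        folded_state = torch.randint(0, 256, (5, 84, 84), dtype=torch.uint8)
        memory.store(folded_state, action=0, reward=1, done=False)

    # sample a prioritized mini batch, then feed (dummy) TD errors back
    state, action, reward, next_state, done, w, rank_e_id = memory.sample(global_step=500)
    print(state.shape, w.shape)   # expected: (32, 4, 84, 84) and (32,)
    memory.update_priority(rank_e_id, [1.0] * len(rank_e_id))
    memory.rebalance()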