Comparing changes

base repository: lukeluocn/dqn-breakout (base: main)
head repository: guzy0324/dqn-breakout (compare: main)

Can’t automatically merge. Don’t worry, you can still create the pull request.

Commits on Nov 6, 2020

  1. some comments
     guzy0324 committed Nov 6, 2020 (7b28217)
  2. some comments
     guzy0324 committed Nov 6, 2020 (7282edb)
  3. test git
     gggxxl committed Nov 6, 2020 (4b96f43)
  4. test git
     gggxxl committed Nov 6, 2020 (0b19009)
  5. 7241445
  6. test git
     gggxxl committed Nov 6, 2020 (e714732)
  7. some comments
     guzy0324 committed Nov 6, 2020 (0d8f59d)
  8. some comments
     guzy0324 committed Nov 6, 2020 (e6fb2b9)
  9. some comments
     guzy0324 committed Nov 6, 2020 (32b9c21)

Commits on Nov 7, 2020

  1. some comments
     guzy0324 committed Nov 7, 2020 (b81e30f)
  2. f387f24
  3. some comments
     gggxxl committed Nov 7, 2020 (4b3a72c)

Commits on Nov 8, 2020

  1. 7e4bdeb
  2. ddqn
     guzy0324 committed Nov 8, 2020 (b64be4e)
  3. ddqn
     guzy0324 committed Nov 8, 2020 (f408bce)
  4. Merge remote-tracking branch 'origin/main' into main

     # Conflicts:
     #   utils_drl.py

     gggxxl committed Nov 8, 2020 (449525f)
  5. f461988
  6. reformat
     gggxxl committed Nov 8, 2020 (ec409d4)
  7. reformat
     gggxxl committed Nov 8, 2020 (eec144e)
  8. fix index
     gggxxl committed Nov 8, 2020 (aaf8407)
  9. dueling dqn
     gggxxl committed Nov 8, 2020 (4137084)

Commits on Nov 12, 2020

  1. remove .idea
     guzy0324 committed Nov 12, 2020 (60b77a2)
  2. stable rewards
     gggxxl committed Nov 12, 2020 (a6d9969)
  3. bb8fdd0

Showing with 522 additions and 138 deletions.
  1. +1 −0 .gitignore
  2. +0 −23 .vscode/settings.json
  3. +21 −0 LICENSE
  4. +27 −19 main.py
  5. +1 −1 run.sh
  6. +21 −0 utils.py
  7. +33 −20 utils_drl.py
  8. +6 −6 utils_env.py
  9. +180 −63 utils_memory.py
  10. +11 −6 utils_model.py
  11. +221 −0 utils_prior.py
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ eval_*/
models/
saved_models/
rewards*.txt
.idea
23 changes: 0 additions & 23 deletions .vscode/settings.json

This file was deleted.

21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 Damcy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
46 changes: 27 additions & 19 deletions main.py
@@ -7,8 +7,7 @@

from utils_drl import Agent
from utils_env import MyEnv
from utils_memory import ReplayMemory

from utils_memory import Experience

GAMMA = 0.99
GLOBAL_SEED = 0
@@ -45,43 +44,52 @@
EPS_END,
EPS_DECAY,
)
memory = ReplayMemory(STACK_SIZE + 1, MEM_SIZE, device)
# memory = ReplayMemory(STACK_SIZE + 1, MEM_SIZE, device)
memory = Experience(
{'size': MEM_SIZE,
'batch_size': BATCH_SIZE,
'learn_start': WARM_STEPS,
'steps': MAX_STEPS,
'device': device,
'channels': STACK_SIZE + 1}
)

#### Training ####
obs_queue: deque = deque(maxlen=5)
obs_queue: deque = deque(maxlen=5)  # when a new element is pushed into a full queue, the oldest (head) is popped
done = True

progressive = tqdm(range(MAX_STEPS), total=MAX_STEPS,
ncols=50, leave=False, unit="b")
ncols=50, leave=False, unit="b")  # progress bar
for step in progressive:
if done:
if done:  # done marks the end of an episode, so the environment must be reset
observations, _, _ = env.reset()
for obs in observations:
obs_queue.append(obs)

training = len(memory) > WARM_STEPS
state = env.make_state(obs_queue).to(device).float()
action = agent.run(state, training)
obs, reward, done = env.step(action)
obs_queue.append(obs)
memory.push(env.make_folded_state(obs_queue), action, reward, done)
state = env.make_state(obs_queue).to(device).float()  # build the state from the 5-obs queue (only the last 4 obs are used)
action = agent.run(state, training)  # get the current action from the policy network
obs, reward, done = env.step(action)  # take one environment step
obs_queue.append(obs)  # the head is popped; the queue keeps the last 4 obs plus the new one

memory.store(env.make_folded_state(obs_queue), action, reward, done)  # folded_state: [:4] is state, [1:] is next_state

if step % POLICY_UPDATE == 0 and training:
agent.learn(memory, BATCH_SIZE)
if step % POLICY_UPDATE == 0 and training:  # while training, update the policy network every POLICY_UPDATE steps
agent.learn(memory, step)

if step % TARGET_UPDATE == 0:
if step % TARGET_UPDATE == 0:  # sync the target network every TARGET_UPDATE steps
agent.sync()

if step % EVALUATE_FREQ == 0:
if step % EVALUATE_FREQ == 0:  # run an evaluation every EVALUATE_FREQ steps
avg_reward, frames = env.evaluate(obs_queue, agent, render=RENDER)
with open("rewards.txt", "a") as fp:
fp.write(f"{step//EVALUATE_FREQ:3d} {step:8d} {avg_reward:.1f}\n")
if RENDER:
prefix = f"eval_{step//EVALUATE_FREQ:03d}"
fp.write(f"{step // EVALUATE_FREQ:3d} {step:8d} {avg_reward:.1f}\n")  # the learning curve can be plotted from rewards.txt
if RENDER:  # if RENDER, save the evaluation frames
prefix = f"eval_{step // EVALUATE_FREQ:03d}"
os.mkdir(prefix)
for ind, frame in enumerate(frames):
with open(os.path.join(prefix, f"{ind:06d}.png"), "wb") as fp:
frame.save(fp, format="png")
agent.save(os.path.join(
SAVE_PREFIX, f"model_{step//EVALUATE_FREQ:03d}"))
SAVE_PREFIX, f"model_{step // EVALUATE_FREQ:03d}"))
done = True
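
For readers following the training loop above, the folded-state layout is the key detail: one stored 5-frame tensor provides both the state and the next state. The snippet below is not part of the diff; it is a minimal PyTorch sketch, with a made-up buffer, of that slicing.

    import torch

    # A hypothetical 5-frame buffer, each frame 84x84 uint8, as in the replay memory.
    folded_state = torch.arange(5, dtype=torch.uint8).view(5, 1, 1).expand(5, 84, 84)

    state = folded_state[:4]       # frames 0..3: the state the action was taken in
    next_state = folded_state[1:]  # frames 1..4: the state after the new observation

    assert state.shape == (4, 84, 84) and next_state.shape == (4, 84, 84)
    # The two views overlap on frames 1..3, so one stored tensor covers both endpoints of a transition.
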
2 changes: 1 addition & 1 deletion run.sh
@@ -10,7 +10,7 @@ fi
rm -r eval_*

if [ -z ${CUDA_VISIBLE_DEVICES} ]; then
export CUDA_VISIBLE_DEVICES="0"
export CUDA_VISIBLE_DEVICES="7"
fi

python main.py
21 changes: 21 additions & 0 deletions utils.py
@@ -0,0 +1,21 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:


def list_to_dict(in_list):
return dict((i, in_list[i]) for i in range(0, len(in_list)))


def exchange_key_value(in_dict):
return dict((in_dict[i], i) for i in in_dict)


def main():
pass


if __name__ == '__main__':
main()
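
The two helpers above just build the position-to-experience maps used by the binary heap in utils_prior.py. A quick check, outside the diff, with made-up values:

    from utils import list_to_dict, exchange_key_value

    p2e = list_to_dict(['e3', 'e1', 'e2'])   # {0: 'e3', 1: 'e1', 2: 'e2'}
    e2p = exchange_key_value(p2e)            # {'e3': 0, 'e1': 1, 'e2': 2}
    assert e2p['e1'] == 1
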
53 changes: 33 additions & 20 deletions utils_drl.py
@@ -13,7 +13,7 @@
TorchDevice,
)

from utils_memory import ReplayMemory
from utils_memory import Experience
from utils_model import DQN


@@ -44,20 +44,24 @@ def __init__(
self.__r = random.Random()
self.__r.seed(seed)

self.__policy = DQN(action_dim, device).to(device)
self.__target = DQN(action_dim, device).to(device)
# copy the tensors to the GPU specified by device; all later computation runs on that device
self.__policy = DQN(action_dim, device).to(device)  # policy network
self.__target = DQN(action_dim, device).to(device)  # target network
if restore is None:
self.__policy.apply(DQN.init_weights)
self.__policy.apply(DQN.init_weights)  # custom weight initialization for the policy network
else:
self.__policy.load_state_dict(torch.load(restore))
self.__target.load_state_dict(self.__policy.state_dict())
self.__optimizer = optim.Adam(
self.__policy.load_state_dict(
torch.load(restore))  # load previously learned parameters into the policy network
self.__target.load_state_dict(
self.__policy.state_dict())  # copy the policy parameters into the target network
self.__optimizer = optim.Adam(  # use the Adam optimizer
self.__policy.parameters(),
lr=0.0000625,
eps=1.5e-4,
)
self.__target.eval()
self.__target.eval()  # switch the target network to evaluation mode so BN and Dropout do not affect inference

# epsilon-greedy
def run(self, state: TensorStack4, training: bool = False) -> int:
"""run suggests an action for the given state."""
if training:
@@ -70,22 +74,31 @@ def run(self, state: TensorStack4, training: bool = False) -> int:
return self.__policy(state).max(1).indices.item()
return self.__r.randint(0, self.__action_dim - 1)

def learn(self, memory: ReplayMemory, batch_size: int) -> float:
def learn(self, memory: Experience, step: int) -> float:
"""learn trains the value network via TD-learning."""
state_batch, action_batch, reward_batch, next_batch, done_batch = \
memory.sample(batch_size)
state_batch, action_batch, reward_batch, next_batch, done_batch, w, rank_e_id = \
memory.sample(step)  # sample a batch: state is the first 4 of the 5 frames, next is the last 4

values = self.__policy(state_batch.float()).gather(1, action_batch)
values_next = self.__target(next_batch.float()).max(1).values.detach()
expected = (self.__gamma * values_next.unsqueeze(1)) * \
(1. - done_batch) + reward_batch
loss = F.smooth_l1_loss(values, expected)

self.__optimizer.zero_grad()
loss.backward()
# values_next = self.__target(next_batch.float()).max(1).values.detach()  # this was still Nature DQN rather than DDQN, even though two networks are used
values_next = self.__target(next_batch.float()).gather(
1, self.__policy(next_batch.float()).max(1).indices.unsqueeze(1)).detach()  # changed to DDQN
reward_batch[action_batch == 0] += 0.1  # stable reward
expected = (self.__gamma * values_next) * \
(1. - done_batch) + reward_batch  # if done, the target is r (a terminal step t has no step t+1); otherwise r + gamma * max Q

td_error = (expected - values).detach()
memory.update_priority(rank_e_id, td_error.cpu().numpy())

values = values.mul(w)
expected = expected.mul(w)
loss = F.smooth_l1_loss(values, expected)  # smooth L1 loss

self.__optimizer.zero_grad()  # reset the parameter gradients to zero
loss.backward()  # compute gradients, stored in the .grad fields of __policy's parameters
for param in self.__policy.parameters():
param.grad.data.clamp_(-1, 1)
self.__optimizer.step()
param.grad.data.clamp_(-1, 1)  # clamp all gradients to [-1, 1]
self.__optimizer.step()  # take one optimization step

return loss.item()
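
The main algorithmic change in Agent.learn is swapping the Nature DQN target for the Double DQN target. The following is not the repository's code, just a self-contained sketch with random stand-in Q-value tensors showing the difference between the two targets.

    import torch

    torch.manual_seed(0)
    batch, n_actions = 32, 4
    q_policy_next = torch.randn(batch, n_actions)   # policy-net Q-values for the next states
    q_target_next = torch.randn(batch, n_actions)   # target-net Q-values for the next states
    reward = torch.randn(batch, 1)
    done = torch.zeros(batch, 1)
    gamma = 0.99

    # Nature DQN: the target network both selects and evaluates the next action.
    values_next_nature = q_target_next.max(1).values.unsqueeze(1)

    # Double DQN: the policy network selects the action, the target network evaluates it.
    best_actions = q_policy_next.max(1).indices.unsqueeze(1)
    values_next_ddqn = q_target_next.gather(1, best_actions)

    expected = gamma * values_next_ddqn * (1. - done) + reward  # TD target, as in the diff above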

12 changes: 6 additions & 6 deletions utils_env.py
@@ -34,29 +34,29 @@ class MyEnv(object):

def __init__(self, device: TorchDevice) -> None:
env_raw = make_atari("BreakoutNoFrameskip-v4")
self.__env_train = wrap_deepmind(env_raw, episode_life=True)
self.__env_train = wrap_deepmind(env_raw, episode_life=True)  # training environment
env_raw = make_atari("BreakoutNoFrameskip-v4")
self.__env_eval = wrap_deepmind(env_raw, episode_life=True)
self.__env_eval = wrap_deepmind(env_raw, episode_life=True)  # evaluation environment
self.__env = self.__env_train
self.__device = device

def reset(
self,
render: bool = False,
render: bool = False,  # whether to render
) -> Tuple[List[TensorObs], float, List[GymImg]]:
"""reset resets and initializes the underlying gym environment."""
self.__env.reset()
init_reward = 0.
observations = []
frames = []
for _ in range(5): # no-op
obs, reward, done = self.step(0)
for _ in range(5): # no-op
obs, reward, done = self.step(0)  # take one step
observations.append(obs)
init_reward += reward
if done:
return self.reset(render)
if render:
frames.append(self.get_frame())
frames.append(self.get_frame())  # if rendering, grab a frame

return observations, init_reward, frames

243 changes: 180 additions & 63 deletions utils_memory.py
@@ -1,68 +1,185 @@
from typing import (
Tuple,
)
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:

import sys
import math
import random
import numpy as np
import torch

from utils_types import (
BatchAction,
BatchDone,
BatchNext,
BatchReward,
BatchState,
TensorStack5,
TorchDevice,
)


class ReplayMemory(object):

def __init__(
self,
channels: int,
capacity: int,
device: TorchDevice,
) -> None:
self.__device = device
self.__capacity = capacity
self.__size = 0
self.__pos = 0
import utils_prior


class Experience(object):

def __init__(self, conf):
self.size = conf['size']
self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
self.total_steps = conf['steps'] if 'steps' in conf else 100000

self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
# partition number N, split total size to N part
self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100

self.index = 0
self.record_size = 0

self.__device = conf['device']
self.__m_states = torch.zeros(
(capacity, channels, 84, 84), dtype=torch.uint8)
self.__m_actions = torch.zeros((capacity, 1), dtype=torch.long)
self.__m_rewards = torch.zeros((capacity, 1), dtype=torch.int8)
self.__m_dones = torch.zeros((capacity, 1), dtype=torch.bool)

def push(
self,
folded_state: TensorStack5,
action: int,
reward: int,
done: bool,
) -> None:
self.__m_states[self.__pos] = folded_state
self.__m_actions[self.__pos, 0] = action
self.__m_rewards[self.__pos, 0] = reward
self.__m_dones[self.__pos, 0] = done

self.__pos = (self.__pos + 1) % self.__capacity
self.__size = max(self.__size, self.__pos)

def sample(self, batch_size: int) -> Tuple[
BatchState,
BatchAction,
BatchReward,
BatchNext,
BatchDone,
]:
indices = torch.randint(0, high=self.__size, size=(batch_size,))
b_state = self.__m_states[indices, :4].to(self.__device).float()
b_next = self.__m_states[indices, 1:].to(self.__device).float()
b_action = self.__m_actions[indices].to(self.__device)
b_reward = self.__m_rewards[indices].to(self.__device).float()
b_done = self.__m_dones[indices].to(self.__device).float()
return b_state, b_action, b_reward, b_next, b_done

def __len__(self) -> int:
return self.__size
(self.size+1, conf['channels'], 84, 84), dtype=torch.uint8)
self.__m_actions = torch.zeros((self.size+1, 1), dtype=torch.long)
self.__m_rewards = torch.zeros((self.size+1, 1), dtype=torch.int8)
self.__m_dones = torch.zeros((self.size+1, 1), dtype=torch.bool)

self.priority_queue = utils_prior.BinaryHeap(self.priority_size)
self.distributions = self.build_distributions()

self.beta_grad = (1 - self.beta_zero) / float(self.total_steps - self.learn_start)

def build_distributions(self):
"""
preprocess pow of rank
(rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
:return: distributions, dict
"""
res = {}
n_partitions = self.partition_num
partition_num = 1
# each part size
partition_size = int(math.floor(self.size / n_partitions))

for n in range(partition_size, self.size + 1, partition_size):
if self.learn_start <= n <= self.priority_size:
distribution = {}
# P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
pdf = list(
map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
)
pdf_sum = math.fsum(pdf)
distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
# split into k segments, then sample uniformly within each segment
# set k = batch_size, so each segment carries total probability 1 / batch_size
# strata_ends keeps each segment's start and end positions
cdf = np.cumsum(distribution['pdf'])
strata_ends = {1: 0, self.batch_size + 1: n}
step = 1 / float(self.batch_size)
index = 1
for s in range(2, self.batch_size + 1):
while cdf[index] < step:
index += 1
strata_ends[s] = index
step += 1 / float(self.batch_size)

distribution['strata_ends'] = strata_ends

res[partition_num] = distribution

partition_num += 1

return res

def fix_index(self):
"""
get next insert index
:return: index, int
"""
if self.record_size <= self.size:
self.record_size += 1
if self.index % self.size == 0:
self.index = 1
return self.index
else:
self.index += 1
return self.index

def store(self, folded_state, action, reward, done):
"""
store experience; an experience is assumed to be a tuple of (s1, a, r, s2, t)
so that each stored experience is valid
:param folded_state: tensor
:return: bool, indicate insert status
"""
insert_index = self.fix_index()
if insert_index > 0:
self.__m_states[insert_index] = folded_state
self.__m_actions[insert_index, 0] = action
self.__m_rewards[insert_index, 0] = reward
self.__m_dones[insert_index, 0] = done

# add to priority queue
priority = self.priority_queue.get_max_priority()
self.priority_queue.update(priority, insert_index)
return True
else:
sys.stderr.write('Insert failed\n')
return False

def rebalance(self):
"""
rebalance priority queue
:return: None
"""
self.priority_queue.balance_tree()

def update_priority(self, indices, delta):
"""
update priorities according to indices and deltas
:param indices: list of experience ids
:param delta: list of deltas, in the same order as indices
:return: None
"""
for i in range(0, len(indices)):
self.priority_queue.update(math.fabs(delta[i]), indices[i])

def sample(self, global_step):
"""
sample a mini batch from experience replay
:param global_step: now training step
:return: experience, list, samples
:return: w, list, weights
:return: rank_e_id, list, samples id, used for update priority
"""
if self.record_size < self.learn_start:
sys.stderr.write('Record size less than learn start! Sample failed\n')
return False, False, False

dist_index = math.floor(self.record_size / self.size * self.partition_num)
# issue 1 by @camigord
partition_size = math.floor(self.size / self.partition_num)
partition_max = dist_index * partition_size
distribution = self.distributions[dist_index]
rank_list = []
# sample from k segments
for n in range(1, self.batch_size + 1):
index = random.randint(distribution['strata_ends'][n] + 1,
distribution['strata_ends'][n + 1])
rank_list.append(index)

# beta, increase by global_step, max 1
beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
# find all alpha pow, notice that pdf is a list, start from 0
alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
# w = (N * P(i)) ^ (-beta) / max w
w = np.power(np.array(alpha_pow) * partition_max, -beta)
w_max = max(w)
w = torch.tensor(np.divide(w, w_max)).to(self.__device).float()
# rank list is priority id
# convert to experience id
rank_e_id = self.priority_queue.priority_to_experience(rank_list)
# gather the stored experiences indexed by rank_e_id
b_state = self.__m_states[rank_e_id, :4].to(self.__device).float()
b_next = self.__m_states[rank_e_id, 1:].to(self.__device).float()
b_action = self.__m_actions[rank_e_id].to(self.__device)
b_reward = self.__m_rewards[rank_e_id].to(self.__device).float()
b_done = self.__m_dones[rank_e_id].to(self.__device).float()
return b_state, b_action, b_reward, b_next, b_done, w, rank_e_id

def __len__(self):
return self.record_size
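
As a numeric sanity check on the importance-sampling weights computed in sample above (w = (N * P(i)) ^ (-beta) / max w for the rank-based distribution), here is a small standalone sketch; the sizes, ranks and beta below are made up for illustration and are not the class defaults.

    import numpy as np

    alpha, beta, n = 0.7, 0.5, 8                 # illustrative values only
    pdf = np.array([r ** -alpha for r in range(1, n + 1)])
    pdf /= pdf.sum()                             # P(i) = rank_i^-alpha / sum_j rank_j^-alpha

    sampled_ranks = np.array([1, 4, 7])          # pretend these ranks came from 3 strata
    w = (n * pdf[sampled_ranks - 1]) ** -beta    # (N * P(i)) ^ (-beta)
    w /= w.max()                                 # normalize by max w so all weights are <= 1
    print(w)                                     # the lowest-priority (largest-rank) sample gets weight 1.0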

17 changes: 11 additions & 6 deletions utils_model.py
@@ -3,24 +3,29 @@
import torch.nn.functional as F


class DQN(nn.Module):
class DQN(nn.Module):  # network structure changed to dueling DQN

def __init__(self, action_dim, device):
super(DQN, self).__init__()
self.__conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4, bias=False)
self.__conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False)
self.__conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False)
self.__fc1 = nn.Linear(64*7*7, 512)
self.__fc2 = nn.Linear(512, action_dim)
self.__fc1 = nn.Linear(64 * 7 * 7, 512)  # first FC layer of the advantage stream
self.__fc2 = nn.Linear(512, action_dim)  # second FC layer of the advantage stream
self.__fc1a = nn.Linear(64 * 7 * 7, 512)  # first FC layer of the value stream
self.__fc2a = nn.Linear(512, 1)  # second FC layer of the value stream
self.__device = device

def forward(self, x):
def forward(self, x):  # input state x (a stack of consecutive frames)
x = x / 255.
x = F.relu(self.__conv1(x))
x = F.relu(self.__conv2(x))
x = F.relu(self.__conv3(x))
x = F.relu(self.__fc1(x.view(x.size(0), -1)))
return self.__fc2(x)
advantagex = F.relu(self.__fc1(x.view(x.size(0), -1)))
advantage = self.__fc2(advantagex)
valuex = F.relu(self.__fc1a(x.view(x.size(0), -1)))
value = self.__fc2a(valuex)
return value + (advantage - advantage.mean(1, keepdim=True))  # dueling DQN: Q = V + (A - mean(A))

@staticmethod
def init_weights(module):
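
To see what the dueling aggregation in forward computes, here is a tiny check that is not part of the diff, using toy value and advantage tensors rather than the network:

    import torch

    value = torch.tensor([[2.0]])                        # V(s), shape (batch, 1)
    advantage = torch.tensor([[1.0, -1.0, 0.5, -0.5]])   # A(s, a), shape (batch, n_actions)

    q = value + (advantage - advantage.mean(1, keepdim=True))  # Q = V + (A - mean(A))
    # Subtracting the mean keeps V and A identifiable: the mean of Q over actions equals V.
    assert torch.allclose(q.mean(1, keepdim=True), value)
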
221 changes: 221 additions & 0 deletions utils_prior.py
@@ -0,0 +1,221 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:

import sys
import math

import utils


class BinaryHeap(object):

def __init__(self, priority_size=100, priority_init=None, replace=True):
self.e2p = {}
self.p2e = {}
self.replace = replace

if priority_init is None:
self.priority_queue = {}
self.size = 0
self.max_size = priority_size
else:
# not yet test
self.priority_queue = priority_init
self.size = len(self.priority_queue)
self.max_size = None or self.size

experience_list = list(map(lambda x: self.priority_queue[x], self.priority_queue))
self.p2e = utils.list_to_dict(experience_list)
self.e2p = utils.exchange_key_value(self.p2e)
for i in range(int(self.size / 2), -1, -1):
self.down_heap(i)

def __repr__(self):
"""
:return: string of the priority queue, with level info
"""
if self.size == 0:
return 'No element in heap!'
to_string = ''
level = -1
max_level = int(math.floor(math.log(self.size, 2)))

for i in range(1, self.size + 1):
now_level = int(math.floor(math.log(i, 2)))
if level != now_level:
to_string = to_string + ('\n' if level != -1 else '') \
+ ' ' * (max_level - now_level)
level = now_level

to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level)

return to_string

def check_full(self):
return self.size > self.max_size

def _insert(self, priority, e_id):
"""
insert new experience id with priority
(maybe don't need get_max_priority and implement it in this function)
:param priority: priority value
:param e_id: experience id
:return: bool
"""
self.size += 1

if self.check_full() and not self.replace:
sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
return False
else:
self.size = min(self.size, self.max_size)

self.priority_queue[self.size] = (priority, e_id)
self.p2e[self.size] = e_id
self.e2p[e_id] = self.size

self.up_heap(self.size)
return True

def update(self, priority, e_id):
"""
update priority value according to its experience id
:param priority: new priority value
:param e_id: experience id
:return: bool
"""
if e_id in self.e2p:
p_id = self.e2p[e_id]
self.priority_queue[p_id] = (priority, e_id)
self.p2e[p_id] = e_id

self.down_heap(p_id)
self.up_heap(p_id)
return True
else:
# this e id is new, do insert
return self._insert(priority, e_id)

def get_max_priority(self):
"""
get max priority, if no experience, return 1
:return: max priority if size > 0 else 1
"""
if self.size > 0:
return self.priority_queue[1][0]
else:
return 1

def pop(self):
"""
pop out the max priority value with its experience id
:return: priority value & experience id
"""
if self.size == 0:
sys.stderr.write('Error: no value in heap, pop failed\n')
return False, False

pop_priority, pop_e_id = self.priority_queue[1]
self.e2p[pop_e_id] = -1
# replace first
last_priority, last_e_id = self.priority_queue[self.size]
self.priority_queue[1] = (last_priority, last_e_id)
self.size -= 1
self.e2p[last_e_id] = 1
self.p2e[1] = last_e_id

self.down_heap(1)

return pop_priority, pop_e_id

def up_heap(self, i):
"""
upward balance
:param i: tree node i
:return: None
"""
if i > 1:
parent = math.floor(i / 2)
if self.priority_queue[parent][0] < self.priority_queue[i][0]:
tmp = self.priority_queue[i]
self.priority_queue[i] = self.priority_queue[parent]
self.priority_queue[parent] = tmp
# change e2p & p2e
self.e2p[self.priority_queue[i][1]] = i
self.e2p[self.priority_queue[parent][1]] = parent
self.p2e[i] = self.priority_queue[i][1]
self.p2e[parent] = self.priority_queue[parent][1]
# up heap parent
self.up_heap(parent)

def down_heap(self, i):
"""
downward balance
:param i: tree node i
:return: None
"""
if i < self.size:
greatest = i
left, right = i * 2, i * 2 + 1
if left < self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]:
greatest = left
if right < self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]:
greatest = right

if greatest != i:
tmp = self.priority_queue[i]
self.priority_queue[i] = self.priority_queue[greatest]
self.priority_queue[greatest] = tmp
# change e2p & p2e
self.e2p[self.priority_queue[i][1]] = i
self.e2p[self.priority_queue[greatest][1]] = greatest
self.p2e[i] = self.priority_queue[i][1]
self.p2e[greatest] = self.priority_queue[greatest][1]
# down heap greatest
self.down_heap(greatest)

def get_priority(self):
"""
get all priority value
:return: list of priority
"""
return list(map(lambda x: x[0], self.priority_queue.values()))[0:self.size]

def get_e_id(self):
"""
get all experience id in priority queue
:return: list of experience ids order by their priority
"""
return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size]

def balance_tree(self):
"""
rebalance priority queue
:return: None
"""
sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True)
# reconstruct priority_queue
self.priority_queue.clear()
self.p2e.clear()
self.e2p.clear()
cnt = 1
while cnt <= self.size:
priority, e_id = sort_array[cnt - 1]
self.priority_queue[cnt] = (priority, e_id)
self.p2e[cnt] = e_id
self.e2p[e_id] = cnt
cnt += 1
# sort the heap
for i in range(int(math.floor(self.size / 2)), 1, -1):
self.down_heap(i)

def priority_to_experience(self, priority_ids):
"""
retrieve experience ids by priority ids
:param priority_ids: list of priority id
:return: list of experience id
"""
return [self.p2e[i] for i in priority_ids]
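
Finally, a short usage sketch (not from the repository) of how this heap is driven by the replay memory: new experiences enter with the current max priority as in Experience.store, priorities are overwritten with absolute TD errors as in Agent.learn, and heap positions are mapped back to experience ids as in Experience.sample. The priority values below are made up.

    from utils_prior import BinaryHeap

    heap = BinaryHeap(priority_size=8)

    # On store: new experiences get the current max priority (1 when the heap is empty).
    for e_id in range(1, 5):
        heap.update(heap.get_max_priority(), e_id)

    # After a learning step: replace priorities with the absolute TD errors.
    for e_id, td_error in zip([1, 2, 3, 4], [0.3, 1.2, 0.05, 0.7]):
        heap.update(abs(td_error), e_id)

    # On sample: heap positions (priority ids) are converted back to experience ids.
    print(heap.priority_to_experience([1, 2]))  # -> [2, 4] for these values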