rlpack
Overview


rlpack is a TensorFlow-based library of reinforcement learning algorithms. It decouples algorithms from environments, making the algorithms easy to call and reuse.

Usage


The following example shows how to use rlpack to run the PPO algorithm in a MuJoCo environment.

# -*- coding: utf-8 -*-


import argparse
import time
from collections import namedtuple

import gym
import numpy as np
import tensorflow as tf

from rlpack.algos import PPO
from rlpack.utils import mlp, mlp_gaussian_policy

parser = argparse.ArgumentParser(description='Run PPO on a MuJoCo environment.')
parser.add_argument('--env',  type=str, default="Reacher-v2")
args = parser.parse_args()

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'done', 'early_stop', 'next_state'))


class Memory(object):
    def __init__(self):
        self.memory = []

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self):
        # Transpose the list of per-step Transition tuples into a single
        # Transition whose fields are tuples spanning the whole batch, so that
        # each field can later be converted into one numpy array.
        return Transition(*zip(*self.memory))


def policy_fn(x, a):
    # Gaussian (diagonal normal) policy network over the continuous action
    # space, with two hidden layers of 64 tanh units.
    return mlp_gaussian_policy(x, a, hidden_sizes=[64, 64], activation=tf.tanh)


def value_fn(x):
    # State-value network: an MLP with hidden sizes [64, 64] and a scalar
    # output, squeezed to shape (batch,).
    v = mlp(x, [64, 64, 1])
    return tf.squeeze(v, axis=1)


def run_main():
    env = gym.make(args.env)
    dim_obs = env.observation_space.shape[0]
    dim_act = env.action_space.shape[0]
    max_ep_len = 1000

    # Build the PPO agent: dim_obs/dim_act give the observation and action
    # dimensions, policy_fn/value_fn are the network builders defined above,
    # and save_path is the directory the agent uses for its output.
    agent = PPO(dim_act=dim_act, dim_obs=dim_obs, policy_fn=policy_fn, value_fn=value_fn, save_path="./log/ppo")

    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    for epoch in range(50):
        memory, ep_ret_list, ep_len_list = Memory(), [], []
        for t in range(1000):
            a = agent.get_action(o[np.newaxis, :])[0]
            nexto, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # early_stop flags steps where the trajectory is cut off by the
            # episode-length cap or the epoch's step budget rather than by a
            # true terminal state.
            memory.push(o, a, r, int(d), int(ep_len == max_ep_len or t == 1000-1), nexto)

            o = nexto

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == 1000-1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' % ep_len)
                if terminal:
                    # Record the episode return and length when a terminal state or the maximum episode length is reached.
                    ep_ret_list.append(ep_ret)
                    ep_len_list.append(ep_len)
                o, ep_ret, ep_len = env.reset(), 0, 0

        print(f"{epoch}th epoch. average_return={np.mean(ep_ret_list)}, average_len={np.mean(ep_len_list)}")

        # Update the policy on the batch collected in this epoch.
        batch = memory.sample()
        agent.update([np.array(x) for x in batch])

    elapsed_time = time.time() - start_time
    print("elapsed time:", elapsed_time)


if __name__ == "__main__":
    run_main()
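
Assuming the script above is saved as, say, run_ppo.py (the filename is arbitrary), it can be run directly, with the target environment selected through the --env flag defined by the argparse parser:

    $ python run_ppo.py --env Reacher-v2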

Installation


  1. Install the dependencies

The required dependency packages are listed in environment.yml. We recommend using Anaconda to set up the Python environment; it can be installed with the following commands.

    $ git clone https://github.com/liber145/rlpack
    $ cd rlpack
    $ conda env create -f environment.yml
    $ conda activate py36
  2. Install rlpack
    $ python setup.py install

The steps above also install gym, a commonly used package of reinforcement learning environments. gym additionally supports more complex environments such as MuJoCo; see the gym documentation for details.
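
To verify the installation, a quick sanity check is to import rlpack's PPO class and step a plain gym environment with random actions. This is only a sketch: it uses CartPole-v1 so that no MuJoCo setup is required, and the script name check_install.py is just a suggestion.

# check_install.py -- a hypothetical sanity check, not part of rlpack
import gym
from rlpack.algos import PPO  # should import without error once rlpack is installed

env = gym.make("CartPole-v1")  # plain gym environment; no MuJoCo needed
obs = env.reset()
for _ in range(10):
    # Take random actions just to confirm the environment steps correctly.
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
print("rlpack and gym are installed and working")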

Algorithms


| Algorithm | Paper | Type |
| --- | --- | --- |
| DQN | Playing Atari with Deep Reinforcement Learning | off-policy |
| DoubleDQN | Deep Reinforcement Learning with Double Q-learning | off-policy |
| DuelDQN | Dueling Network Architectures for Deep Reinforcement Learning | off-policy |
| DistDQN | A Distributional Perspective on Reinforcement Learning | off-policy |
| PG | Introduction to Reinforcement Learning | on-policy |
| A2C | Asynchronous Methods for Deep Reinforcement Learning | on-policy |
| TRPO | Trust Region Policy Optimization | on-policy |
| PPO | Proximal Policy Optimization Algorithms | on-policy |
| TD3 | Addressing Function Approximation Error in Actor-Critic Methods | off-policy |
| DDPG | Continuous control with deep reinforcement learning | off-policy |
| SAC | Soft Actor-Critic | off-policy |
| GAIL | Generative Adversarial Imitation Learning | on-policy |

Explanations of some of the algorithms are available in the documentation.
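
To see which algorithm classes the installed package actually exposes (the names should correspond to the table above, though the exact set depends on the installed version), the public attributes of rlpack.algos can be listed:

# Print the algorithm classes exported by rlpack.algos (PPO, used in the example above, is one of them).
import rlpack.algos as algos

print([name for name in dir(algos) if not name.startswith("_")])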

Reference code


The implementation drew on a number of other excellent open-source codebases.

Learning resources

