| 1 | +# Accelerating Reinforcement Learning with Learned Skill Priors |
| 2 | +#### [[Project Website]](https://clvrai.github.io/spirl/) [[Paper]](https://arxiv.org/abs/2010.11944) |
| 3 | + |
| 4 | +[Karl Pertsch](https://kpertsch.github.io/)<sup>1</sup>, [Youngwoon Lee](https://youngwoon.github.io/)<sup>1</sup>, |
| 5 | +[Joseph Lim](https://www.clvrai.com/)<sup>1</sup> |
| 6 | + |
| 7 | +<sup>1</sup>CLVR Lab, University of Southern California |
| 8 | + |
| 9 | +<p align="center"> |
| 10 | +<a href="https://kpertsch.github.io/spirl/"> |
| 11 | +<img src="docs/resources/spirl_teaser.png" width="800"> |
| 12 | +</a> |
| 13 | +</p> |
| 14 | + |
| 15 | +This is the official PyTorch implementation of the paper "**Accelerating Reinforcement Learning with Learned Skill Priors**" |
| 16 | +(CoRL 2020). |
| 17 | + |
| 18 | +## Requirements |
| 19 | + |
| 20 | +- python 3.7+ |
| 21 | +- mujoco 2.0 (for RL experiments) |
| 22 | +- Ubuntu 18.04 |
| 23 | + |
| 24 | +## Installation Instructions |
| 25 | + |
| 26 | +Create a virtual environment and install all required packages. |
| 27 | +``` |
| 28 | +cd spirl |
| 29 | +pip3 install virtualenv |
| 30 | +virtualenv -p $(which python3) ./venv |
| 31 | +source ./venv/bin/activate |
| 32 | + |
| 33 | +# Install dependencies and package |
| 34 | +pip3 install -r requirements.txt |
| 35 | +pip3 install -e . |
| 36 | +``` |
| 37 | + |
| 38 | +Set the environment variables that specify the root experiment and data directories. For example: |
| 39 | +``` |
| 40 | +mkdir ./experiments |
| 41 | +mkdir ./data |
| 42 | +export EXP_DIR=./experiments |
| 43 | +export DATA_DIR=./data |
| 44 | +``` |
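
As a quick sanity check before training, the following sketch verifies that both variables are set and creates the directories if they are missing. The helper name `check_spirl_dirs` is hypothetical and not part of the repository; only the variable names `EXP_DIR` and `DATA_DIR` come from the instructions above.

```python
import os


def check_spirl_dirs(env=None):
    """Return the experiment and data directories, creating them if missing.

    Hypothetical helper for illustration; not part of the spirl package.
    """
    env = os.environ if env is None else env
    dirs = {}
    for name in ("EXP_DIR", "DATA_DIR"):
        if name not in env:
            raise RuntimeError(f"{name} is not set; export it before training")
        path = os.path.abspath(env[name])
        os.makedirs(path, exist_ok=True)  # no-op if the directory already exists
        dirs[name] = path
    return dirs
```

Calling `check_spirl_dirs()` without arguments reads the variables from the current shell environment.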
| 45 | + |
| 46 | +Finally, install our fork of the [D4RL benchmark](https://github.com/kpertsch/d4rl) repository by following its installation instructions. |
| 47 | +It provides both the kitchen environment and the training data for the skill prior model in the kitchen and maze environments. |
| 48 | + |
| 49 | +## Example Commands |
| 50 | +To train a skill prior model for the kitchen environment, run: |
| 51 | +``` |
| 52 | +python3 spirl/train.py --path=spirl/configs/skill_prior_learning/kitchen/hierarchical --val_data_size=160 |
| 53 | +``` |
| 54 | +Results can be visualized using tensorboard in the experiment directory: `tensorboard --logdir=$EXP_DIR`. |
| 55 | + |
| 56 | +To train a SPIRL agent on the kitchen environment using the pre-trained skill prior from above, run: |
| 57 | +``` |
| 58 | +python3 spirl/rl/train.py --path=spirl/configs/hrl/kitchen/spirl --seed=0 --prefix=SPIRL_kitchen_seed0 |
| 59 | +``` |
| 60 | +Results will be written to [WandB](https://www.wandb.com/). Before running RL, |
| 61 | +create an account and then change the WandB entity and project name at the top of [rl/train.py](spirl/rl/train.py) to match your account. |
| 62 | + |
| 63 | +In both commands, `kitchen` can be replaced with `maze` or `block_stacking` to run on the respective environment. Before training models |
| 64 | +on these environments, download the corresponding datasets (the kitchen dataset is downloaded automatically); |
| 65 | +download links are provided below. |
| 66 | +Additional commands for training baseline models / agents are also provided below. |
| 67 | + |
| 68 | +### Baseline Commands |
| 69 | + |
| 70 | +- Train **Single-step action prior**: |
| 71 | +``` |
| 72 | +python3 spirl/train.py --path=spirl/configs/skill_prior_learning/kitchen/flat --val_data_size=160 |
| 73 | +``` |
| 74 | + |
| 75 | +- Run **Vanilla SAC**: |
| 76 | +``` |
| 77 | +python3 spirl/rl/train.py --path=spirl/configs/rl/kitchen/SAC --seed=0 --prefix=SAC_kitchen_seed0 |
| 78 | +``` |
| 79 | + |
| 80 | +- Run **SAC w/ single-step action prior**: |
| 81 | +``` |
| 82 | +python3 spirl/rl/train.py --path=spirl/configs/rl/kitchen/prior_initialized/flat_prior/ --seed=0 --prefix=flatPrior_kitchen_seed0 |
| 83 | +``` |
| 84 | + |
| 85 | +- Run **BC + finetune**: |
| 86 | +``` |
| 87 | +python3 spirl/rl/train.py --path=spirl/configs/rl/kitchen/prior_initialized/bc_finetune/ --seed=0 --prefix=bcFinetune_kitchen_seed0 |
| 88 | +``` |
| 89 | + |
| 90 | +- Run **Skill Space Policy w/o prior**: |
| 91 | +``` |
| 92 | +python3 spirl/rl/train.py --path=spirl/configs/hrl/kitchen/no_prior/ --seed=0 --prefix=SSP_noPrior_kitchen_seed0 |
| 93 | +``` |
| 94 | + |
| 95 | +Again, all commands can be run on `maze` or `block_stacking` by replacing `kitchen` with the respective environment name in the paths |
| 96 | +(after downloading the datasets). |
| 97 | + |
| 98 | + |
| 99 | +## Starting to Modify the Code |
| 100 | + |
| 101 | +### Modifying the hyperparameters |
| 102 | +The default hyperparameters are defined in the respective model files, e.g. in [```skill_prior_mdl.py```](spirl/models/skill_prior_mdl.py#L47) |
| 103 | +for the SPIRL model. These defaults can be overridden in the experiment config files (passed to the respective |
| 104 | +command via the `--path` argument). For an example, see [```kitchen/hierarchical/conf.py```](spirl/configs/skill_prior_learning/kitchen/hierarchical/conf.py). |
| 105 | + |
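The override pattern can be pictured as a dictionary of defaults that the experiment config updates. The sketch below is illustrative only: the parameter names and the `merge_hyperparameters` helper are assumptions, not the actual contents of `skill_prior_mdl.py` or `conf.py`.

```python
def merge_hyperparameters(defaults, overrides):
    """Return a copy of `defaults` updated with `overrides`.

    Unknown keys raise an error so that typos in a config file are caught early.
    """
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"unknown hyperparameters: {sorted(unknown)}")
    merged = dict(defaults)
    merged.update(overrides)
    return merged


# Hypothetical defaults, standing in for those defined in a model file.
model_defaults = {"n_rollout_steps": 10, "latent_dim": 10, "kl_weight": 1.0}

# Hypothetical overrides, standing in for an experiment config file.
experiment_overrides = {"kl_weight": 5e-4}

hparams = merge_hyperparameters(model_defaults, experiment_overrides)
```

Keeping the defaults untouched and returning a merged copy means each experiment config only has to list the values it changes.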
| 106 | + |
| 107 | +### Adding a new dataset for model training |
| 108 | +All code that is dataset-specific should be placed in a corresponding subfolder in `spirl/data`. |
| 109 | +To add a data loader for a new dataset, subclass the `Dataset` classes from [```data_loader.py```](spirl/components/data_loader.py) |
| 110 | +and override the `__getitem__` function to load a single data sample. The output `dict` should include the following |
| 111 | +keys: |
| 112 | + |
| 113 | +``` |
| 114 | +dict({ |
| 115 | + 'states': (time, state_dim) # state sequence (for state-based prior inputs) |
| 116 | + 'actions': (time, action_dim) # action sequence (as skill input for training prior model) |
| 117 | + 'images': (time, channels, width, height) # image sequence (for image-based prior inputs) |
| 118 | +}) |
| 119 | +``` |
| 120 | + |
| 121 | +All datasets used with the codebase so far have been based on `HDF5` files. The `GlobalSplitDataset` class provides functionality to read all |
| 122 | +HDF5 files in a directory and split them into `train/val/test` based on percentages. The `VideoDataset` class provides |
| 123 | +utilities for manipulating sequences, such as randomly cropping subsequences and padding. |
| 124 | + |
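To make the percentage-based split concrete, here is a minimal sketch of the idea. It is not the `GlobalSplitDataset` implementation; the function name `split_by_percentage`, its parameters, and the file names are hypothetical.

```python
def split_by_percentage(filenames, train_pct=90, val_pct=5):
    """Split a list of file names into train/val/test partitions.

    Files are sorted first so the split is deterministic; whatever is left
    after the train and val shares goes to test.
    """
    if train_pct < 0 or val_pct < 0 or train_pct + val_pct > 100:
        raise ValueError("percentages must be non-negative and sum to at most 100")
    files = sorted(filenames)
    n_train = len(files) * train_pct // 100  # integer arithmetic avoids float rounding
    n_val = len(files) * val_pct // 100
    return {
        "train": files[:n_train],
        "val": files[n_train:n_train + n_val],
        "test": files[n_train + n_val:],
    }


# Hypothetical file names; in practice these would be the HDF5 files in a directory.
example_files = [f"rollout_{i:03d}.h5" for i in range(20)]
splits = split_by_percentage(example_files, train_pct=80, val_pct=10)
```

Sorting before splitting keeps the partitions stable across runs, so validation files never leak into training when the script is restarted.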
| 125 | +### Adding a new RL environment |
| 126 | +To add a new RL environment, simply define a new environment class in `spirl/rl/envs` that inherits from the environment interface |
| 127 | +in [```spirl/rl/components/environment.py```](spirl/rl/components/environment.py). |
| 128 | + |
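As a rough illustration, the sketch below shows a toy environment with Gym-style `reset` and `step` methods. It is a standalone, hypothetical example: the class name `PointReachEnv` and its parameters are made up, and the method signatures required by the interface in `environment.py` may differ from the ones assumed here.

```python
import random


class PointReachEnv:
    """Toy 1D environment: move a point toward a goal position.

    Hypothetical example. In the repository this class would inherit from the
    environment interface in spirl/rl/components/environment.py, whose exact
    method signatures may differ from the Gym-style ones assumed here.
    """

    def __init__(self, goal=1.0, max_steps=50, seed=0):
        self.goal = goal
        self.max_steps = max_steps
        self._rng = random.Random(seed)
        self._pos = 0.0
        self._t = 0

    def reset(self):
        """Start a new episode and return the initial observation."""
        self._pos = self._rng.uniform(-1.0, 0.0)
        self._t = 0
        return [self._pos]

    def step(self, action):
        """Apply a clipped action; return (observation, reward, done, info)."""
        delta = max(-0.1, min(0.1, float(action)))
        self._pos += delta
        self._t += 1
        reached = abs(self._pos - self.goal) < 0.05
        reward = 1.0 if reached else 0.0
        done = reached or self._t >= self.max_steps
        return [self._pos], reward, done, {}
```

A rollout then consists of calling `reset()` once and `step(action)` until `done` is true.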
| 129 | + |
| 130 | +### Modifying the skill prior model architecture |
| 131 | +Start by defining a model class in the `spirl/models` directory that inherits from the `BaseModel` or `SkillPriorMdl` class. |
| 132 | +The new model needs to define the architecture in the constructor (e.g. by overriding the `build_network()` function), |
| 133 | +implement the forward pass and loss functions, |
| 134 | +as well as model-specific logging functionality if desired. For an example, see [```spirl/models/skill_prior_mdl.py```](spirl/models/skill_prior_mdl.py). |
| 135 | + |
| 136 | +Note that most basic architecture components (MLPs, CNNs, LSTMs, flow models, etc.) are defined in `spirl/modules` and can be |
| 137 | +reused when defining new architectures. Below are some links to the most important classes. |
| 138 | + |
| 139 | +|Component | File | Description | |
| 140 | +|:------------- |:-------------|:-------------| |
| 141 | +| MLP | [```Predictor```](spirl/modules/subnetworks.py#L33) | Basic N-layer fully-connected network. Defines number of inputs, outputs, layers and hidden units. | |
| 142 | +| CNN-Encoder | [```ConvEncoder```](spirl/modules/subnetworks.py#L66) | Convolutional encoder, number of layers determined by input dimensionality (resolution halved per layer). Number of channels doubles per layer. Returns encoded vector + skip activations. | |
| 143 | +| CNN-Decoder | [```ConvDecoder```](spirl/modules/subnetworks.py#L145) | Mirrors architecture of conv. encoder. Can take skip connections as input, also versions that copy pixels etc. | |
| 144 | +| Processing-LSTM | [```BaseProcessingLSTM```](spirl/modules/recurrent_modules.py#L70) | Basic N-layer LSTM for processing an input sequence. Produces one output per timestep, number of layers / hidden size configurable.| |
| 145 | +| Prediction-LSTM | [```RecurrentPredictor```](spirl/modules/recurrent_modules.py#L241) | Same as processing LSTM, but for autoregressive prediction. | |
| 146 | +| Mixture-Density Network | [```MDN```](spirl/modules/mdn.py#L10) | MLP that outputs GMM distribution. | |
| 147 | +| Normalizing Flow Model | [```NormalizingFlowModel```](spirl/modules/flow_models.py#L9) | Implements normalizing flow model that stacks multiple flow blocks. Implementation for RealNVP block provided. | |
| 148 | + |
| 149 | +### Adding a new RL algorithm |
| 150 | +The core RL algorithms are implemented within the `Agent` class. To add a new algorithm, create a new file in |
| 151 | +`spirl/rl/agents` and subclass [```BaseAgent```](spirl/rl/components/agent.py#L19). In particular, any required |
| 152 | +networks (actor, critic, etc.) need to be constructed and the `update(...)` function needs to be overridden. For an example, |
| 153 | +see the SAC implementation in [```SACAgent```](spirl/rl/agents/ac_agent.py#L67). |
| 154 | + |
| 155 | +The main SPIRL skill prior regularized RL algorithm is implemented in [```ActionPriorSACAgent```](spirl/rl/agents/prior_sac_agent.py#L12). |
| 156 | + |
| 157 | + |
| 158 | +## Detailed Code Structure Overview |
| 159 | +``` |
| 160 | +spirl |
| 161 | + |- components # reusable infrastructure for model training |
| 162 | + | |- base_model.py # basic model class that all models inherit from |
| 163 | + | |- checkpointer.py # handles storing + loading of model checkpoints |
| 164 | + | |- data_loader.py # basic dataset classes, new datasets need to inherit from here |
| 165 | + | |- evaluator.py # defines basic evaluation routines, eg top-of-N evaluation, + eval logging |
| 166 | + | |- logger.py # implements core logging functionality using tensorboardX |
| 167 | + | |- params.py # definition of command line params for model training |
| 168 | + | |- trainer_base.py # basic training utils used in main trainer file |
| 169 | + | |
| 170 | + |- configs # all experiment configs should be placed here |
| 171 | + | |- data_collect # configs for data collection runs |
| 172 | + | |- default_data_configs # defines one default data config per dataset, e.g. state/action dim etc |
| 173 | + | |- hrl # configs for hierarchical downstream RL |
| 174 | + | |- rl # configs for non-hierarchical downstream RL |
| 175 | + | |- skill_prior_learning # configs for skill embedding and prior training (both hierarchical and flat) |
| 176 | + | |
| 177 | + |- data # any dataset-specific code (like data generation scripts, custom loaders etc) |
| 178 | + |- models # holds all model classes that implement forward, loss, visualization |
| 179 | + |- modules # reusable architecture components (like MLPs, CNNs, LSTMs, Flows etc) |
| 180 | + |- rl # all code related to RL |
| 181 | + | |- agents # implements core algorithms in agent classes, like SAC etc |
| 182 | + | |- components # reusable infrastructure for RL experiments |
| 183 | + | |- agent.py # basic agent and hierarchical agent classes - do not implement any specific RL algo |
| 184 | + | |- critic.py # basic critic implementations (eg MLP-based critic) |
| 185 | + | |- environment.py # defines environment interface, basic gym env |
| 186 | + | |- normalization.py # observation normalization classes, only optional |
| 187 | + | |- params.py # definition of command line params for RL training |
| 188 | + | |- policy.py # basic policy interface definition |
| 189 | + | |- replay_buffer.py # simple numpy-array replay buffer, uniform sampling and versions |
| 190 | + | |- sampler.py # rollout sampler for collecting experience, for flat and hierarchical agents |
| 191 | + | |- envs # all custom RL environments should be defined here |
| 192 | + | |- policies # policy implementations go here, MLP-policy and RandomAction are implemented |
| 193 | + | |- utils # utilities for RL code like MPI, WandB related code |
| 194 | + | |- train.py # main RL training script, builds all components + runs training |
| 195 | + | |
| 196 | + |- utils # general utilities, pytorch / visualization utilities etc |
| 197 | + |- train.py # main model training script, builds all components + runs training loop and logging |
| 198 | +``` |
| 199 | + |
| 200 | +The general philosophy is that each new experiment gets a new config file that captures all hyperparameters etc. so that experiments |
| 201 | +themselves are version controllable. |
| 202 | + |
| 203 | +## Datasets |
| 204 | + |
| 205 | +|Dataset | Link | Size | |
| 206 | +|:------------- |:-------------|:-----| |
| 207 | +| Maze | [https://drive.google.com/file/d/1pXM-EDCwFrfgUjxITBsR48FqW9gMoXYZ/view?usp=sharing](https://drive.google.com/file/d/1pXM-EDCwFrfgUjxITBsR48FqW9gMoXYZ/view?usp=sharing) | 12GB | |
| 208 | +| Block Stacking |[https://drive.google.com/file/d/1VobNYJQw_Uwax0kbFG7KOXTgv6ja2s1M/view?usp=sharing](https://drive.google.com/file/d/1VobNYJQw_Uwax0kbFG7KOXTgv6ja2s1M/view?usp=sharing)| 11GB| |
| 209 | + |
| 210 | +The datasets used for the experiments in the paper can be downloaded using the links above. |
| 211 | + |
| 212 | +If you want to generate more data |
| 213 | +or make other modifications to the data generating procedure, we provide instructions for regenerating the |
| 214 | +`maze` and `block stacking` datasets [here](spirl/data/). |
| 215 | + |
| 216 | + |
| 217 | +## Citation |
| 218 | +If you find this work useful in your research, please consider citing: |
| 219 | +``` |
| 220 | +@inproceedings{pertsch2020spirl, |
| 221 | + title={Accelerating Reinforcement Learning with Learned Skill Priors}, |
| 222 | + author={Karl Pertsch and Youngwoon Lee and Joseph J. Lim}, |
| 223 | + booktitle={Conference on Robot Learning (CoRL)}, |
| 224 | + year={2020}, |
| 225 | +} |
| 226 | +``` |
| 227 | + |
| 228 | +## Acknowledgements |
| 229 | +The model architecture and training code build on a code base which we jointly developed with [Oleh Rybkin](https://www.seas.upenn.edu/~oleh/) for our previous project on [hierarchical prediction](https://github.com/orybkin/video-gcp). |
| 230 | + |
| 231 | +We also published many of the utils / architectural building blocks in a stand-alone package for easy import into your |
| 232 | +own research projects: check out the [blox](https://github.com/orybkin/blox-nn) python module. |
| 233 | + |
| 234 | + |
| 235 | + |
| 236 | + |