-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a887a52
commit 0d81d65
Showing
8 changed files
with
546 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,40 @@ | ||
# GP-VAE | ||
Code for the GP-VAE model described in https://arxiv.org/abs/1907.04155 | ||
# GP-VAE: Deep Probabilistic Time Series Imputation | ||
|
||
Code for [paper](http://arxiv.org/abs/1907.04155) | ||
|
||
## Overview | ||
Our approach utilizes non-autoregressive Variational Autoencoders with Gaussian Process prior for time series imputation. | ||
|
||
* The inference model takes time series with missingness and predicts variational parameters for multivariate Gaussian variational distribution. | ||
|
||
* The Gaussian Process prior encourages latent representations to capture the temporal correlations in data. | ||
|
||
* The generative model takes the sample from posterior approximation and reconstructs the original time series with imputed missing values. | ||
|
||
 | ||
|
||
## Dependencies | ||
|
||
* Python >= 3.6 | ||
* TensorFlow = 1.14 | ||
* Some more packages: see `requirements.txt` | ||
|
||
## Run | ||
1. Clone or download this repo. `cd` yourself to it's root directory. | ||
2. Grab or build a working python enviromnent. [Anaconda](https://www.anaconda.com/) works fine. | ||
3. Install packages from `requirements.txt` | ||
4. Download data: 'bash data/load_{hmnist, sprites, physionet}.sh'. | ||
5. Run command `CUDA_VISIBLE_DEVICES=* python train.py --model_type {vae, hi-vae, gp-vae} --data_type {hmnist, sprites, physionet} --exp_name <your_name> ...` | ||
|
||
To see all available flags run: `python train.py --help` | ||
|
||
## Reproducibility | ||
|
||
We provide a set of hyperparameters used in our final runs. Some flags have common values for all datasets by default. For reproducibility of reported results run: | ||
* HMNIST: `python train.py --model_type gp-vae --data_type hmnist --exp_name reproduce_hmnist --seed $RANDOM --testing --banded_covar | ||
--latent_dim 256 --encoder_sizes=256,256 --decoder_sizes=256,256,256 --window_size 3 --sigma 1 --length_scale 2 --beta 0.8 --num_epochs 20` | ||
* SPRITES: `python train.py --model_type gp-vae --data_type sprites --exp_name reproduce_sprites --seed $RANDOM --testing --banded_covar | ||
--latent_dim 256 --encoder_sizes=32,256,256 --decoder_sizes=256,256,256 --window_size 3 --sigma 1 --length_scale 2 --beta 0.1 --num_epochs 20` | ||
* Physionet: `python train.py --model_type gp-vae --data_type physionet --exp_name reproduce_physionet --seed $RANDOM --testing --banded_covar | ||
--latent_dim 35 --encoder_sizes=128,128 --decoder_sizes=256,256 --window_size 24 --sigma 1.005 --length_scale 7 --beta 0.2 --num_epochs 40` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
DATA_DIR="data/hmnist" | ||
random_mechanism="mnar" | ||
|
||
mkdir -p ${DATA_DIR} | ||
|
||
if [ "$random_mechanism" == "mnar" ] ; then | ||
wget https://www.dropbox.com/s/aidkzh525mvwf44/hmnist_mnar.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz | ||
elif [ "$random_mechanism" == "spatial"] ; then | ||
wget https://www.dropbox.com/s/ccxlqvu80hk0jfn/hmnist_spatial.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz | ||
elif [ "$random_mechanism" == "random" ] ; then | ||
#wget https://www.dropbox.com/s/7iudp0q7fed5map/hmnist_random.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz | ||
elif [ "$random_mechanism" == "temporal_neg" ] ; then | ||
wget https://www.dropbox.com/s/aw2dj0ikd48zf89/hmnist_temporal_neg.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz | ||
elif [ "$random_mechanism" == "temporal_pos" ] ; then | ||
wget https://www.dropbox.com/s/qktos9t0i6i2ee3/hmnist_temporal_pos.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
DATA_DIR="data/physionet" | ||
|
||
mkdir -p ${DATA_DIR} | ||
wget https://www.dropbox.com/s/651d86winb4cy9n/physionet.npz?dl=1 -O ${DATA_DIR}/physionet.npz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
DATA_DIR="data/sprites" | ||
|
||
mkdir -p ${DATA_DIR} | ||
wget https://www.dropbox.com/s/cjuzj71v5sgwcge/sprites.npz?dl=1 -O ${DATA_DIR}/sprites.zip |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
absl-py==0.7.0 | ||
numpy==1.16.4 | ||
scipy==1.2.0 | ||
tensorflow==1.14.0 | ||
tensorflow-gpu==1.14.0 | ||
tensorflow_probsbility==0.7.0 | ||
matplotlib | ||
sklearn |
Oops, something went wrong.