Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaranchuk committed Jan 11, 2020
1 parent cfc9085 commit 6e06c93
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Code for [paper](http://arxiv.org/abs/1907.04155)

## Overview
Our approach utilizes non-autoregressive Variational Autoencoders with Gaussian Process prior for time series imputation.
Our approach utilizes Variational Autoencoders with Gaussian Process prior for time series imputation.

* The inference model takes time series with missingness and predicts variational parameters for multivariate Gaussian variational distribution.

Expand All @@ -16,7 +16,7 @@ Our approach utilizes non-autoregressive Variational Autoencoders with Gaussian
## Dependencies

* Python >= 3.6
* TensorFlow = 1.14
* TensorFlow = 1.15
* Some more packages: see `requirements.txt`

## Run
Expand Down
25 changes: 11 additions & 14 deletions lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# Encoders

class DiagonalEncoder(tf.keras.Model):
def __init__(self, z_size, hidden_sizes=(64, 64)):
def __init__(self, z_size, hidden_sizes=(64, 64), **kwargs):
""" Encoder with factorized Normal posterior over temporal dimension
Used by disjoint VAE and HI-VAE with Standard Normal prior
:param z_size: latent space dimensionality
Expand All @@ -33,7 +33,7 @@ def __call__(self, x):


class JointEncoder(tf.keras.Model):
def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, transpose=False):
def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, transpose=False, **kwargs):
""" Encoder with 1d-convolutional network and factorized Normal posterior
Used by joint VAE and HI-VAE with Standard Normal prior or GP-VAE with factorized Normal posterior
:param z_size: latent space dimensionality
Expand Down Expand Up @@ -62,17 +62,21 @@ def __call__(self, x):


class BandedJointEncoder(tf.keras.Model):
def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3):
def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, data_type=None, **kwargs):
""" Encoder with 1d-convolutional network and multivariate Normal posterior
Used by GP-VAE with proposed banded covariance matrix
:param z_size: latent space dimensionality
:param hidden_sizes: tuple of hidden layer sizes.
The tuple length sets the number of hidden layers.
:param window_size: kernel size for Conv1D layer
:param data_type: needed for some data specific modifications, e.g:
tf.nn.softplus is a more common and correct choice, however
tf.nn.sigmoid provides more stable performance on Physionet dataset
"""
super(BandedJointEncoder, self).__init__()
self.z_size = int(z_size)
self.net = make_cnn(3*z_size, hidden_sizes, window_size)
self.data_type = data_type

def __call__(self, x):
mapped = self.net(x)
Expand All @@ -87,8 +91,8 @@ def __call__(self, x):
mapped_mean = mapped_transposed[:, :self.z_size]
mapped_covar = mapped_transposed[:, self.z_size:]

# Hard coded. Sigmoid produces more stable results on Physionet data.
if time_length == 48:
# tf.nn.sigmoid provides more stable performance on Physionet dataset
if self.data_type == 'physionet':
mapped_covar = tf.nn.sigmoid(mapped_covar)
else:
mapped_covar = tf.nn.softplus(mapped_covar)
Expand Down Expand Up @@ -176,7 +180,7 @@ class VAE(tf.keras.Model):
def __init__(self, latent_dim, data_dim, time_length,
encoder_sizes=(64, 64), encoder=DiagonalEncoder,
decoder_sizes=(64, 64), decoder=BernoulliDecoder,
image_preprocessor=None, window_size=3., beta=1.0, M=1, K=1):
image_preprocessor=None, beta=1.0, M=1, K=1, **kwargs):
""" Basic Variational Autoencoder with Standard Normal prior
:param latent_dim: latent space dimensionality
:param data_dim: original data dimensionality
Expand All @@ -188,7 +192,6 @@ def __init__(self, latent_dim, data_dim, time_length,
:param decoder: decoder model class {Bernoulli, Gaussian}Decoder
:param image_preprocessor: 2d-convolutional network used for image data preprocessing
:param window_size: kernel size for 1d-convolution in {Joint, BandedJoint}Encoder models
:param beta: tradeoff coefficient between reconstruction and KL terms in ELBO
:param M: number of Monte Carlo samples for ELBO estimation
:param K: number of importance weights for IWAE model (see: https://arxiv.org/abs/1509.00519)
Expand All @@ -198,13 +201,7 @@ def __init__(self, latent_dim, data_dim, time_length,
self.data_dim = data_dim
self.time_length = time_length

if issubclass(encoder, DiagonalEncoder):
self.encoder = encoder(latent_dim, encoder_sizes)
elif issubclass(encoder, (JointEncoder, BandedJointEncoder)):
self.encoder = encoder(latent_dim, encoder_sizes, window_size=window_size)
else:
raise NotImplementedError("Such encoder class is not implemented")

self.encoder = encoder(latent_dim, encoder_sizes, **kwargs)
self.decoder = decoder(data_dim, decoder_sizes)
self.preprocessor = image_preprocessor

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
absl-py==0.7.0
numpy==1.16.4
scipy==1.2.0
tensorflow==1.14.0
tensorflow-gpu==1.14.0
tensorflow==1.15.0
tensorflow-gpu==1.15.0
tensorflow_probability==0.7.0
matplotlib
sklearn
scikit-learn
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def main(argv):
kernel=FLAGS.kernel, sigma=FLAGS.sigma,
length_scale=FLAGS.length_scale, kernel_scales = FLAGS.kernel_scales,
image_preprocessor=image_preprocessor, window_size=FLAGS.window_size,
beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K)
beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K, data_type=FLAGS.data_type)
else:
raise ValueError("Model type must be one of ['vae', 'hi-vae', 'gp-vae']")

Expand Down

0 comments on commit 6e06c93

Please sign in to comment.