Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaranchuk committed Oct 21, 2019
1 parent 0d81d65 commit f0334b9
Show file tree
Hide file tree
Showing 8 changed files with 719 additions and 2 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
Expand Down
5 changes: 5 additions & 0 deletions lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .healing_mnist import *
from .utils import *
from .nn_utils import *
from .gp_kernel import *
from .models import *
49 changes: 49 additions & 0 deletions lib/gp_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import tensorflow as tf

'''
GP kernel functions
'''


def rbf_kernel(T, length_scale):
xs = tf.range(T, dtype=tf.float32)
xs_in = tf.expand_dims(xs, 0)
xs_out = tf.expand_dims(xs, 1)
distance_matrix = tf.math.squared_difference(xs_in, xs_out)
distance_matrix_scaled = distance_matrix / length_scale ** 2
kernel_matrix = tf.math.exp(-distance_matrix_scaled)
return kernel_matrix


def diffusion_kernel(T, length_scale):
assert length_scale < 0.5, "length_scale has to be smaller than 0.5 for the "\
"kernel matrix to be diagonally dominant"
sigmas = tf.ones(shape=[T, T]) * length_scale
sigmas_tridiag = tf.linalg.band_part(sigmas, 1, 1)
kernel_matrix = sigmas_tridiag + tf.eye(T)*(1. - length_scale)
return kernel_matrix


def matern_kernel(T, length_scale):
xs = tf.range(T, dtype=tf.float32)
xs_in = tf.expand_dims(xs, 0)
xs_out = tf.expand_dims(xs, 1)
distance_matrix = tf.math.abs(xs_in - xs_out)
distance_matrix_scaled = distance_matrix / tf.cast(tf.math.sqrt(length_scale), dtype=tf.float32)
kernel_matrix = tf.math.exp(-distance_matrix_scaled)
return kernel_matrix


def cauchy_kernel(T, sigma, length_scale):
xs = tf.range(T, dtype=tf.float32)
xs_in = tf.expand_dims(xs, 0)
xs_out = tf.expand_dims(xs, 1)
distance_matrix = tf.math.squared_difference(xs_in, xs_out)
distance_matrix_scaled = distance_matrix / length_scale ** 2
kernel_matrix = tf.math.divide(sigma, (distance_matrix_scaled + 1.))

alpha = 0.001
eye = tf.eye(num_rows=kernel_matrix.shape.as_list()[-1])
return kernel_matrix + alpha * eye
87 changes: 87 additions & 0 deletions lib/healing_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Data loader for the Healing MNIST data set (c.f. https://arxiv.org/abs/1511.05121)
Adapted from https://github.com/Nikita6000/deep_kalman_filter_for_BM/blob/master/healing_mnist.py
"""


import numpy as np
import scipy.ndimage
from tensorflow.keras.datasets import mnist


def apply_square(img, square_size):
img = np.array(img)
img[:square_size, :square_size] = 255
return img


def apply_noise(img, bit_flip_ratio):
img = np.array(img)
mask = np.random.random(size=(28,28)) < bit_flip_ratio
img[mask] = 255 - img[mask]
return img


def get_rotations(img, rotation_steps):
for rot in rotation_steps:
img = scipy.ndimage.rotate(img, rot, reshape=False)
yield img


def binarize(img):
return (img > 127).astype(np.int)


def heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle):
squares_begin = np.random.randint(0, seq_len - square_count)
squares_end = squares_begin + square_count

rotations = []
rotation_steps = np.random.normal(size=seq_len, scale=max_angle)

for idx, rotation in enumerate(get_rotations(img, rotation_steps)):
# Don't add the squares right now
# if idx >= squares_begin and idx < squares_end:
# rotation = apply_square(rotation, square_size)

# Don't add noise for now
# noisy_img = apply_noise(rotation, noise_ratio)
noisy_img = rotation
binarized_img = binarize(noisy_img)
rotations.append(binarized_img)

return rotations, rotation_steps


class HealingMNIST():
def __init__(self, seq_len=5, square_count=3, square_size=5, noise_ratio=0.15, digits=range(10), max_angle=180):
(x_train, y_train),(x_test, y_test) = mnist.load_data()
mnist_train = [(img,label) for img, label in zip(x_train, y_train) if label in digits]
mnist_test = [(img, label) for img, label in zip(x_test, y_test) if label in digits]

train_images = []
test_images = []
train_rotations = []
test_rotations = []
train_labels = []
test_labels = []

for img, label in mnist_train:
train_img, train_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle)
train_images.append(train_img)
train_rotations.append(train_rot)
train_labels.append(label)

for img, label in mnist_test:
test_img, test_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle)
test_images.append(test_img)
test_rotations.append(test_rot)
test_labels.append(label)

self.train_images = np.array(train_images)
self.test_images = np.array(test_images)
self.train_rotations = np.array(train_rotations)
self.test_rotations = np.array(test_rotations)
self.train_labels = np.array(train_labels)
self.test_labels = np.array(test_labels)
Loading

0 comments on commit f0334b9

Please sign in to comment.