# TraVeLGAN.py
import os

import tensorflow as tf
import tensorflow_datasets as tfds  # only used by the commented-out celeb_a loader below
import hdf5storage                  # only used by the commented-out .mat loader below

from Discriminator import Discriminator
from DataAugmentation import DataAugmentation
from config import cfg
from Losses import MaxMarginLoss, DistanceLoss
from U_Net import UNet
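
# Reference sketches only: Losses.py is not shown here, so these are NOT the
# actual MaxMarginLoss/DistanceLoss implementations. In the TraVeLGAN
# formulation, the distance (TraVeL) loss keeps the "transformation vectors"
# between pairs of real embeddings aligned with the corresponding pairs of
# generated embeddings, and the max-margin loss pushes embeddings of distinct
# images apart. Names, shapes, and reductions below are assumptions.
def _reference_travel_distance(real_emb, fake_emb):
    # Pairwise difference ("transformation") vectors within a batch: (B, B, D).
    real_vecs = real_emb[:, None, :] - real_emb[None, :, :]
    fake_vecs = fake_emb[:, None, :] - fake_emb[None, :, :]
    # tf.keras.losses.cosine_similarity returns the negative cosine, so
    # minimising the mean aligns corresponding transformation vectors.
    return tf.reduce_mean(
        tf.keras.losses.cosine_similarity(real_vecs, fake_vecs, axis=-1))


def _reference_max_margin(embeddings, delta):
    # Keep embeddings of distinct images at least `delta` apart.
    dists = tf.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    mask = 1.0 - tf.eye(tf.shape(embeddings)[0])  # ignore self-distances
    return tf.reduce_mean(tf.nn.relu(delta - dists) * mask)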

class TraVeLGAN:
    """A U-Net generator trained with an adversarial loss plus the siamese
    transformation-vector (TraVeL) loss."""

    def __init__(self):
        self.data_augmentation = DataAugmentation(cfg.train)
        self.__load_data()
        self.__create_graph()
    def __load_data(self):
        # Earlier dataset experiments, kept for reference:
        # (self.x_train, self.y_train), (self.x_test, self.y_test) = tf.keras.datasets.cifar10.load_data()
        # self.dataset = tfds.load(name="celeb_a", split=tfds.Split.TRAIN)
        # with h5py.File('/home/firiuza/PycharmProjects/TraVeLGAN/shoes_128.hdf5', 'r') as f:
        #     dset = f
        # The UT Zappos50K index was loaded into an unused variable, so it is disabled:
        # x = hdf5storage.loadmat('/home/firiuza/MachineLearning/zap_dataset/ut-zap50k-data/image-path.mat')

        # NOTE: the [:10] slice limits training to the first ten images for
        # quick smoke tests; remove it for a full run.
        self.data_paths = [os.path.join(cfg.dataset.data_dir, el)
                           for el in os.listdir(cfg.dataset.data_dir)[:10]]
    def __create_graph(self):
        # Three sub-networks: a discriminator, a siamese embedding network
        # (same backbone as the discriminator), and a U-Net generator.
        self.discriminator = Discriminator(cfg.train.discriminator_size)
        self.discriminator.build((None, cfg.train.image_size, cfg.train.image_size, cfg.train.image_channels))

        self.siamese = Discriminator(cfg.train.siamese_size)
        self.siamese.build((None, cfg.train.image_size, cfg.train.image_size, cfg.train.image_channels))

        self.unet = UNet()
        self.unet.build((None, cfg.train.image_size, cfg.train.image_size, cfg.train.image_channels))

        # A single Adam instance is shared across the three variable sets;
        # per-variable slots make this safe in TF2.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

        self.BC_loss = tf.keras.losses.BinaryCrossentropy()
        self.max_margin_loss = MaxMarginLoss(cfg.train.delta)
        self.distance_loss = DistanceLoss()
    def run_train_epoch(self):
        dataset = tf.data.Dataset.from_tensor_slices(self.data_paths)
        dataset = dataset.shuffle(buffer_size=len(self.data_paths))
        dataset = dataset.map(map_func=self.data_augmentation.preprocess,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.batch(batch_size=cfg.train.batch_size).prefetch(buffer_size=cfg.train.prefetch_buffer_size)

        for image in dataset:
            # A persistent tape lets us take gradients of three different
            # losses with respect to three different variable sets.
            with tf.GradientTape(persistent=True) as tape:
                # Debug stub left in the original, which overrode the real batch:
                # image = tf.ones(shape=(cfg.train.batch_size, cfg.train.image_size, cfg.train.image_size, cfg.train.image_channels))
                generated_image = self.unet(image, True)

                real_predictions = tf.sigmoid(self.discriminator(image, True))
                fake_predictions = tf.sigmoid(self.discriminator(generated_image, True))

                # Keras losses take (y_true, y_pred): real images are labelled
                # 1, generated images 0.
                D_real = self.BC_loss(tf.ones_like(real_predictions), real_predictions)
                D_fake = self.BC_loss(tf.zeros_like(fake_predictions), fake_predictions)
                D_loss = D_real + D_fake

                # The generator tries to make the discriminator output 1.
                G_adv = self.BC_loss(tf.ones_like(fake_predictions), fake_predictions)

                real_embeddings = self.siamese(image, True)
                fake_embeddings = self.siamese(generated_image, True)

                TraVeL_loss = self.distance_loss(real_embeddings, fake_embeddings)
                siamese_loss = self.max_margin_loss(real_embeddings, None)

                G_loss = G_adv + TraVeL_loss
                S_loss = siamese_loss + TraVeL_loss

            # Update each sub-network with its own loss.
            for loss, model in ((D_loss, self.discriminator),
                                (G_loss, self.unet),
                                (S_loss, self.siamese)):
                grads = tape.gradient(loss, model.trainable_variables)
                self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            del tape
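

# A possible multi-epoch driver with checkpointing. The epoch count and
# checkpoint directory below are illustrative assumptions, not part of the
# original training script.
def train_many_epochs(num_epochs=100, ckpt_dir='/tmp/travelgan'):
    gan = TraVeLGAN()
    # tf.train.Checkpoint tracks all three sub-networks so a run can resume.
    ckpt = tf.train.Checkpoint(unet=gan.unet,
                               discriminator=gan.discriminator,
                               siamese=gan.siamese)
    for _ in range(num_epochs):
        gan.run_train_epoch()
        ckpt.save(os.path.join(ckpt_dir, 'ckpt'))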

if __name__ == '__main__':
    gan = TraVeLGAN()
    gan.run_train_epoch()