diff --git a/README.rst b/README.rst
index 0bf99e7..1d75292 100644
--- a/README.rst
+++ b/README.rst
@@ -94,7 +94,7 @@ Pre-trained models are provided in the GitHub releases. Training your own is a
The easiest way to get up-and-running is to `install Docker `_. Then, you should be able to download and run the pre-built image using the ``docker`` command line tool. Find out more about the ``alexjc/neural-enhance`` image on its `Docker Hub `_ page.
-Here's the simplest way you can call the script using ``docker``, assuming you're familiar with using ``-v`` argument to mount folders you can use this directly to specify files to enhance:
+Here's the simplest way you can call the script using ``docker``. Assuming you're familiar with using the ``-v`` argument to mount folders (see `documentation `_), you can use this directly to specify files to enhance:
.. code:: bash
@@ -161,7 +161,9 @@ This code uses a combination of techniques from the following papers, as well as
Special thanks for their help and support in various ways:
+* Roelof Pieters — Provided a rack of TitanX GPUs for training model variations on the OpenImages dataset.
* Eder Santana — Discussions, encouragement, and his ideas on `sub-pixel deconvolution `_.
+* Wenzhe Shi — Practical advice and feedback on training procedures for the super-resolution GAN [4].
* Andrew Brock — This sub-pixel layer code is based on `his project repository `_ using Lasagne.
* Casper Kaae Sønderby — For suggesting a more stable alternative to sigmoid + log as GAN loss functions.
diff --git a/enhance.py b/enhance.py
index 5d704d3..4e9c75e 100755
--- a/enhance.py
+++ b/enhance.py
@@ -14,7 +14,7 @@
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
-__version__ = '0.3'
+__version__ = '0.4'
import io
import os
@@ -32,7 +32,7 @@
# Configure all options first so we can later custom-load other libraries (Theano) based on device specified by user.
-parser = argparse.ArgumentParser(description='Generate a new image by applying style onto a content image.',
+parser = argparse.ArgumentParser(description='Enhance a low-res image into high-def using neural networks.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
add_arg = parser.add_argument
add_arg('files', nargs='*', default=[])
@@ -43,10 +43,11 @@
add_arg('--type', default='photo', type=str, help='Name of the neural network to load/save.')
add_arg('--model', default='default', type=str, help='Specific trained version of the model.')
add_arg('--train', default=False, type=str, help='File pattern to load for training.')
-add_arg('--train-scales', default=0, type=int, help='Randomly resize images this many times.')
-add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.')
-add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.')
-add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level & range in preproc.')
+add_arg('--train-scales', default=[0], nargs='+', type=int, help='Randomly resize images, specify min/max.')
+add_arg('--train-blur', default=[], nargs='+', type=int, help='Sigma value for gaussian blur, min/max.')
+add_arg('--train-noise', default=None, type=float, help='Distribution for gaussian noise preprocess.')
+add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level, specify min/max.')
+add_arg('--train-plugin', default=None, type=str, help='Filename for python pre-processing script.')
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
@@ -139,14 +140,24 @@ def __init__(self):
self.data_ready = threading.Event()
self.data_copied = threading.Event()
+ if args.train_plugin is not None:
+ import importlib.util
+ spec = importlib.util.spec_from_file_location('enhance.plugin', 'plugins/{}.py'.format(args.train_plugin))
+ plugin = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(plugin)
+
+ self.iterate_files = plugin.iterate_files
+ self.load_original = plugin.load_original
+ self.load_seed = plugin.load_seed
+
self.orig_shape, self.seed_shape = args.batch_shape, args.batch_shape // args.zoom
self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
self.files = glob.glob(args.train)
if len(self.files) == 0:
- error("There were no files found to train from searching for `{}`".format(args.train),
- " - Try putting all your images in one folder and using `--train=data/*.jpg`")
+ error('There were no files found to train from searching for `{}`'.format(args.train),
+ ' - Try putting all your images in one folder and using `--train="data/*.jpg"`')
self.available = set(range(args.buffer_size))
self.ready = set()
@@ -154,43 +165,58 @@ def __init__(self):
self.cwd = os.getcwd()
self.start()
-    def run(self):
+    def iterate_files(self):
+        """Yield training filenames, reshuffling `self.files` before every pass.
+
+        Each pass walks a snapshot of the list, so entries removed from `self.files` while
+        the pass is in progress (e.g. images that failed to load) do not cause neighbouring
+        files to be skipped.  The generator finishes once `self.files` is empty instead of
+        spinning forever on an empty list, which lets the consuming loop fall through to
+        its own error handling.
+        """
-        while True:
+        while self.files:
             random.shuffle(self.files)
-            for f in self.files:
-                self.add_to_buffer(f)
+            for f in list(self.files):
+                yield f
-    def add_to_buffer(self, f):
-        filename = os.path.join(self.cwd, f)
+    def load_original(self, filename):
+        """Load `filename` as an RGB image and return it as a float32 array.
+
+        The image may be shrunk by a power of two drawn from `args.train_scales` (first and
+        last entry are the exponent bounds), but only while both sides stay at least
+        `args.batch_shape` pixels; the resampling filter is picked at random from PIL's
+        filter codes 0-3.
+
+        Returns None when the file cannot be opened or is smaller than `args.batch_shape`;
+        in that case a warning is printed and the file is dropped from `self.files`.
+        """
         try:
             orig = PIL.Image.open(filename).convert('RGB')
-            scale = 2 ** random.randint(0, args.train_scales)
+            scale = 2 ** random.randint(args.train_scales[0], args.train_scales[-1])
             if scale > 1 and all(s//scale >= args.batch_shape for s in orig.size):
-                orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=PIL.Image.LANCZOS)
+                orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=random.randint(0,3))
             if any(s < args.batch_shape for s in orig.size):
                 raise ValueError('Image is too small for training with size {}'.format(orig.size))
+            return scipy.misc.fromimage(orig).astype(np.float32)
         except Exception as e:
             warn('Could not load `{}` as image.'.format(filename),
                  ' - Try fixing or removing the file before next run.')
-            self.files.remove(f)
-            return
+            # `filename` may arrive already joined with `self.cwd` while `self.files` holds the
+            # raw glob results, so match both spellings.  A plain `remove()` would raise
+            # ValueError from inside this handler whenever the two differ.
+            self.files[:] = [f for f in self.files
+                             if f != filename and os.path.join(self.cwd, f) != filename]
+            return None
-        seed = orig
-        if args.train_blur is not None:
-            seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2)))
+    def load_seed(self, filename, original, zoom):
+        """Build the degraded low-resolution input ("seed") for one original image.
+
+        Steps, each applied only when its option is set: gaussian blur with a radius drawn
+        from `args.train_blur`, downscale by `zoom` with a random PIL resampling filter,
+        a JPEG round-trip at a quality drawn from `args.train_jpeg`, and additive gaussian
+        noise with scale `args.train_noise` (one noise value per pixel, shared by all channels).
+
+        `filename` is not used here; it is part of the signature shared with plug-in loaders.
+        Returns the seed as a float32 array.
+        """
+        seed = scipy.misc.toimage(original)
+        if len(args.train_blur):
+            seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(args.train_blur[0], args.train_blur[-1])))
-        if args.zoom > 1:
-            seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
+        # Use the `zoom` argument for the test as well as the division, so both always agree.
+        if zoom > 1:
+            seed = seed.resize((seed.size[0]//zoom, seed.size[1]//zoom), resample=random.randint(0,3))
+
         if len(args.train_jpeg) > 0:
-            buffer, rng = io.BytesIO(), args.train_jpeg[-1] if len(args.train_jpeg) > 1 else 15
-            seed.save(buffer, format='jpeg', quality=args.train_jpeg[0]+random.randrange(-rng, +rng))
+            buffer = io.BytesIO()
+            # randint() includes both bounds: a single value (min == max) is valid and the
+            # maximum quality is reachable.  randrange(a, a) would raise ValueError instead.
+            seed.save(buffer, format='jpeg', quality=random.randint(args.train_jpeg[0], args.train_jpeg[-1]))
             seed = PIL.Image.open(buffer)
-        orig = scipy.misc.fromimage(orig).astype(np.float32)
         seed = scipy.misc.fromimage(seed).astype(np.float32)
-
         if args.train_noise is not None:
             seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))
+        return seed
+ def run(self):
+ """Produce training data: turn every filename from `iterate_files()` into an
+ (original, seed) pair and pass it to `enqueue()`.
+
+ Presumably the thread entry point, since `__init__` ends with `self.start()`.
+ Raises ValueError if the filename iterator ever finishes.
+ """
+ for filename in self.iterate_files():
+ # Resolve the name against the working directory recorded in `self.cwd`.
+ f = os.path.join(self.cwd, filename)
+ orig = self.load_original(f)
+ # A loader returning None means "skip this file".
+ if orig is None: continue
+
+ seed = self.load_seed(f, orig, args.zoom)
+ if seed is None: continue
+
+ self.enqueue(orig, seed)
+
+ # Only reached when the iterator above is exhausted.
+ raise ValueError('Insufficient number of files found for training.')
+
+ def enqueue(self, orig, seed):
for _ in range(seed.shape[0] * seed.shape[1] // (args.buffer_fraction * self.seed_shape ** 2)):
h = random.randint(0, seed.shape[0] - self.seed_shape)
w = random.randint(0, seed.shape[1] - self.seed_shape)
@@ -241,7 +267,34 @@ def up(d): return self.upscale * d if d else d
def get_output_for(self, input, deterministic=False, **kwargs):
+ """Rearrange groups of input channels into an output upscaled by `self.upscale`."""
out, r = T.zeros(self.get_output_shape_for(input.shape)), self.upscale
+ # For every offset (y, x) inside an r-by-r cell, input channels r*y+x, r*y+x+r*r, ...
+ # are written to the output positions y::r, x::r.
for y, x in itertools.product(range(r), repeat=2):
- out=T.inc_subtensor(out[:,:,y::r,x::r], input[:,r*y+x::r*r,:,:])
+ out = T.set_subtensor(out[:,:,y::r,x::r], input[:,r*y+x::r*r,:,:])
+ return out
+
+
+class ReflectLayer(lasagne.layers.Layer):
+ """Pad axes 2 and 3 of a 4-D input by mirroring the values next to each border.
+
+ Based on more code by ajbrock: https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f
+ """
+
+ def __init__(self, incoming, pad, batch_ndim=2, **kwargs):
+ # `pad` is unpacked as a pair (p0, p1) in get_output_for(); `batch_ndim` is the number
+ # of leading axes that get_output_shape_for() leaves untouched.
+ super(ReflectLayer, self).__init__(incoming, **kwargs)
+ self.pad = pad
+ self.batch_ndim = batch_ndim
+
+ def get_output_shape_for(self, input_shape):
+ """Return `input_shape` with each padded axis grown by twice its pad; None sizes stay None."""
+ output_shape = list(input_shape)
+ for k, p in enumerate(self.pad):
+ if output_shape[k + self.batch_ndim] is None: continue
+ output_shape[k + self.batch_ndim] += p * 2
+ return tuple(output_shape)
+
+ def get_output_for(self, x, **kwargs):
+ """Return `x` surrounded by a mirrored border of p0 rows and p1 columns per side."""
+ # NOTE(review): slices such as `p1:-p1` are empty when a pad is 0, so both pads
+ # presumably have to be positive - confirm against callers.
+ out = T.zeros(self.get_output_shape_for(x.shape))
+ p0, p1 = self.pad
+ # Top border: input rows p0..1 in reverse order (row 0 itself is not repeated).
+ out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:])
+ # Bottom border: input rows -2 down to -(1+p0).
+ out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:])
+ # Centre: the unmodified input.
+ out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x)
+ # Left and right borders are mirrored from `out` itself, after the rows above were
+ # filled in, so the corner regions are populated as well.
+ out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1])
+ out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1])
return out
@@ -270,17 +323,30 @@ def __init__(self):
def last_layer(self):
+ """Return the final value of `self.network` in iteration order (presumably the layer
+ added most recently - this relies on the mapping preserving insertion order)."""
return list(self.network.values())[-1]
-    def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
-        conv = ConvLayer(input, units, filter_size, stride=stride, pad=pad, nonlinearity=None)
-        prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
+    def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25, reuse=False):
+        """Add a reflect-padded convolution followed by a PReLU and return the PReLU layer.
+
+        The convolution is stored in `self.network` under `name+'x'` and the activation
+        under `name+'>'`.  Padding is done by a ReflectLayer in front of an unpadded
+        convolution, and is skipped unless both entries of `pad` are positive.
+
+        `alpha` is the initial PReLU slope.  With `reuse=True`, parameters (W, b, alpha)
+        are shared with the layer registered under `'0/' + <last path component of name>`,
+        when such a layer already exists.
+        """
+        clone = '0/'+name.split('/')[-1]
+        if reuse and clone+'x' in self.network:
+            extra = {'W': self.network[clone+'x'].W, 'b': self.network[clone+'x'].b}
+        else:
+            extra = {}
+
+        padded = ReflectLayer(input, pad) if pad[0] > 0 and pad[1] > 0 else input
+        conv = ConvLayer(padded, units, filter_size, stride=stride, pad=0, nonlinearity=None, **extra)
         self.network[name+'x'] = conv
-        self.network[name+'>'] = prelu
-        return prelu
+
+        if reuse and clone+'>' in self.network:
+            extra = {'alpha': self.network[clone+'>'].alpha}
+        else:
+            # Honour the caller's initial slope; without this the `alpha` argument is
+            # silently ignored and every PReLU starts from the library default.
+            extra = {'alpha': lasagne.init.Constant(alpha)}
+        prelu = lasagne.layers.ParametricRectifierLayer(conv, **extra)
+        self.network[name+'>'] = prelu
+        # Return the new layer itself: looking up the "last" entry of the mapping would be
+        # wrong if `name+'>'` had already been registered.
+        return prelu
def make_block(self, name, input, units):
+ """Add two stacked layers (`name-A`, then `name-B`) and return the last layer added.
+
+ When `args.generator_residual` is set, an element-wise sum of the block input and the
+ `-B` output is also registered under `name-R`, and that sum is what gets returned.
+ """
- self.make_layer(name+'-A', input, units, alpha=0.1)
- # self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0)
- return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
+ self.make_layer(name+'-A', input, units, alpha=0.25)
+ self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0)
+ if args.generator_residual:
+ self.network[name+'-R'] = ElemwiseSumLayer([input, self.last_layer()])
+ return self.last_layer()
def setup_generator(self, input, config):
+ # Builds the generator layers in `self.network`: an input convolution, optional
+ # downscaling layers, a run of blocks, upscaling stages, and an output convolution.
+ # NOTE(review): a hunk boundary falls inside this function, so lines may be missing here.
+ # Every entry of `config` overrides the attribute of the same name on the global `args`.
for k, v in config.items(): setattr(args, k, v)
@@ -288,21 +354,22 @@ def setup_generator(self, input, config):
+ # `extend` is defined outside this view; it presumably yields one filter count per stage.
units_iter = extend(args.generator_filters)
units = next(units_iter)
- self.make_layer('iter.0', input, units, filter_size=(7,7), pad=(3,3))
+ self.make_layer('encode', input, units, filter_size=(7,7), pad=(3,3))
for i in range(0, args.generator_downscale):
- self.make_layer('downscale%i'%i, self.last_layer(), next(units_iter), filter_size=(4,4), stride=(2,2))
+ self.make_layer('%i/downscale'%i, self.last_layer(), next(units_iter), filter_size=4, stride=2, reuse=True)
units = next(units_iter)
for i in range(0, args.generator_blocks):
- self.make_block('iter.%i'%(i+1), self.last_layer(), units)
+ self.make_block('default.%i'%i, self.last_layer(), units)
for i in range(0, args.generator_upscale):
u = next(units_iter)
- self.make_layer('upscale%i.2'%i, self.last_layer(), u*4)
- self.network['upscale%i.1'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
+ # The convolution emits u*4 channels, which the reshuffle layer is then given with
+ # `u` and an upscale factor of 2.
+ self.make_layer('%i/upscale.2'%i, self.last_layer(), u*4, reuse=True)
+ self.network['%i/upscale.1'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
- self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(7,7), pad=(3,3), nonlinearity=None)
+ # 'out' is registered as a second key for the same layer object as 'decode'.
+ self.network['decode'] = ConvLayer(self.last_layer(), 3, filter_size=(7,7), pad=(3,3), nonlinearity=None)
+ self.network['out'] = self.last_layer()
def setup_perceptual(self, input):
"""Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
@@ -477,12 +544,10 @@ def show_progress(self, orign, scald, repro):
self.imsave('valid/%s_%03i_reprod.png' % (args.model, i), repro[i])
def decay_learning_rate(self):
+ """Generate learning rates without end: start at `args.learning_rate` and multiply by
+ `args.learning_decay` after every `args.learning_period` yielded values."""
- l_r, t_cur = args.learning_rate, 0
-
- while True:
+ l_r = args.learning_rate
+ for t_cur in itertools.count():
yield l_r
- t_cur += 1
- if t_cur % args.learning_period == 0: l_r *= args.learning_decay
+ # `t_cur` counts from 0, so `t_cur+1` is the number of values yielded so far.
+ if (t_cur+1) % args.learning_period == 0: l_r *= args.learning_decay
def train(self):
seed_size = args.batch_shape // args.zoom
diff --git a/plugins/simple.py b/plugins/simple.py
new file mode 100644
index 0000000..5e759e3
--- /dev/null
+++ b/plugins/simple.py
@@ -0,0 +1,16 @@
+import glob
+import itertools
+
+import scipy.misc
+import scipy.ndimage
+
+
+def iterate_files():
+ # Plug-in hook, called without arguments: repeats the files matching 'data/*.jpg'
+ # (relative to the working directory) in an endless cycle. The glob runs once, and
+ # the cycle ends immediately if nothing matched.
+ return itertools.cycle(glob.glob('data/*.jpg'))
+
+def load_original(filename):
+ # Plug-in hook: read `filename` from disk as an RGB image array.
+ return scipy.ndimage.imread(filename, mode='RGB')
+
+def load_seed(filename, original, zoom):
+ # Plug-in hook: shrink `original` to 1/zoom of its height and width (integer division)
+ # with bilinear interpolation. `filename` is accepted but not used.
+ target_shape = (original.shape[0]//zoom, original.shape[1]//zoom)
+ return scipy.misc.imresize(original, target_shape, interp='bilinear')