Skip to content

Commit

Permalink
Add custom generation script
Browse files Browse the repository at this point in the history
  • Loading branch information
myungsub committed Mar 20, 2020
1 parent cbbf0d6 commit 56d2122
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 6 deletions.
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def add_argument_group(name):
data_arg.add_argument('--dataset', type=str, default='vimeo90k')
data_arg.add_argument('--num_frames', type=int, default=3)
data_arg.add_argument('--data_root', type=str, default='data/vimeo_triplet')
data_arg.add_argument('--img_fmt', type=str, default='png')

# Model
model_arg = add_argument_group('Model')
Expand Down
48 changes: 48 additions & 0 deletions data/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class Video(Dataset):
    """Dataset of consecutive frame pairs from a flat directory of images.

    Every image in ``data_root`` matching ``*.<fmt>`` is given a float
    "time index" filename suffix (``<stem>_<float>.<fmt>``) if it does not
    already have one, then consecutive images (in sorted order) are paired
    so an in-between frame can be interpolated for each pair.
    """

    def __init__(self, data_root, fmt='png'):
        """
        Args:
            data_root: directory containing the frame images.
            fmt: image file extension without the dot (e.g. 'png').
        """
        images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt)))
        for im in images:
            # Files whose stem does not already end in a parseable float
            # index are renamed with a 0.0 index appended
            # (e.g. frame3.png -> frame3_0.000000.png) so interpolated
            # frames can later be named by midpoint index and still sort.
            try:
                float(im.split('_')[-1][:-4])
            except ValueError:
                os.rename(im, '%s_%.06f.%s' % (im[:-4], 0.0, fmt))
        # Re-glob to pick up the renamed files.
        images = sorted(glob.glob(os.path.join(data_root, '*.%s' % fmt)))
        self.imglist = [[images[i], images[i + 1]] for i in range(len(images) - 1)]
        print('[%d] images ready to be loaded' % len(self.imglist))

    def __getitem__(self, index):
        """Return ([img1, img2], meta): two CHW float tensors and their paths."""
        imgpaths = self.imglist[index]

        # Load images; the context managers ensure the underlying file
        # handles are closed (ToTensor forces the pixel data to be read).
        to_tensor = transforms.ToTensor()
        with Image.open(imgpaths[0]) as im1, Image.open(imgpaths[1]) as im2:
            img1 = to_tensor(im1)
            img2 = to_tensor(im2)

        meta = {'imgpath': imgpaths}
        return [img1, img2], meta

    def __len__(self):
        return len(self.imglist)


def get_loader(mode, data_root, batch_size, img_fmt='png', shuffle=False, num_workers=0, n_frames=1):
    """Build a DataLoader over consecutive frame pairs found in ``data_root``.

    Args:
        mode: accepted for signature compatibility with the other dataset
            loaders but unused here (this loader is only used for inference).
        data_root: directory of frame images, passed to ``Video``.
        batch_size: DataLoader batch size.
        img_fmt: image file extension without the dot.
        shuffle: whether the DataLoader shuffles samples.
        num_workers: DataLoader worker process count.
        n_frames: accepted for signature compatibility; unused.

    Returns:
        torch.utils.data.DataLoader over a ``Video`` dataset.
    """
    # NOTE(review): the original computed an ``is_training`` flag from
    # ``mode`` but never used it; removed as dead code.
    dataset = Video(data_root, fmt=img_fmt)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=True)
124 changes: 124 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
import sys
import time
import copy
import shutil
import random

import torch
import numpy as np
from tqdm import tqdm

import config
import utils


##### Parse CmdLine Arguments #####
# Build the run configuration from config.py; ``unparsed`` collects any
# unrecognized command-line flags.
args, unparsed = config.get_args()
cwd = os.getcwd()
print(args)


device = torch.device('cuda' if args.cuda else 'cpu')
# benchmark mode lets cudnn autotune conv algorithms for the input sizes.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Seed the CPU (and, when enabled, GPU) RNGs for reproducibility.
torch.manual_seed(args.random_seed)
if args.cuda:
    torch.cuda.manual_seed(args.random_seed)




##### Build Model #####
# Select the network variant by name; every variant takes two frames and
# is constructed from the same ``depth`` argument.
if args.model.lower() == 'cain_encdec':
    from model.cain_encdec import CAIN_EncDec
    print('Building model: CAIN_EncDec')
    model = CAIN_EncDec(depth=args.depth, start_filts=32)
elif args.model.lower() == 'cain':
    from model.cain import CAIN
    print("Building model: CAIN")
    model = CAIN(depth=args.depth)
elif args.model.lower() == 'cain_noca':
    from model.cain_noca import CAIN_NoCA
    print("Building model: CAIN_NoCA")
    model = CAIN_NoCA(depth=args.depth)
else:
    raise NotImplementedError("Unknown model!")
# Just make every model to DataParallel
model = torch.nn.DataParallel(model).to(device)
#print(model)

print('# of parameters: %d' % sum(p.numel() for p in model.parameters()))


# If resume, load checkpoint: model
if args.resume:
    #utils.load_checkpoint(args, model, optimizer=None)
    # NOTE(review): loads a fixed filename from the working directory,
    # bypassing utils.load_checkpoint's checkpoint/<exp_name>/ layout —
    # confirm pretrained_cain.pth is shipped alongside this script.
    checkpoint = torch.load('pretrained_cain.pth')
    args.start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['state_dict'])
    del checkpoint

def test(args, epoch):
    """Run the model over every consecutive frame pair in the dataset and,
    in 'test' mode, write each interpolated frame back into args.data_root
    named with the midpoint of its neighbors' float time indices.

    Args:
        args: parsed command-line configuration (dataset, paths, mode, ...).
        epoch: epoch number, used only for the progress printout.
    """
    print('Evaluating for epoch = %d' % epoch)
    ##### Load Dataset #####
    test_loader = utils.load_dataset(
        args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, img_fmt=args.img_fmt)
    model.eval()

    t = time.time()
    with torch.no_grad():
        for i, (images, meta) in enumerate(tqdm(test_loader)):

            # Build input batch
            im1, im2 = images[0].to(device), images[1].to(device)

            # Forward: interpolate the frame between im1 and im2.
            out, _ = model(im1, im2)

            # Save result images
            if args.mode == 'test':
                for b in range(images[0].size(0)):
                    paths = meta['imgpath'][0][b].split('/')
                    fp = args.data_root
                    fp = os.path.join(fp, paths[-1][:-4]) # remove '.png' extension

                    # Decide float index: frame filenames are expected to
                    # end in "_<float>" (see data/video.py); frames without
                    # a parseable index fall back to the 0.0 / 1.0 endpoints.
                    i1_str = paths[-1][:-4]
                    i2_str = meta['imgpath'][1][b].split('/')[-1][:-4]
                    try:
                        i1 = float(i1_str.split('_')[-1])
                    except ValueError:
                        i1 = 0.0
                    try:
                        i2 = float(i2_str.split('_')[-1])
                        # A second frame indexed 0.0 is treated as the far
                        # endpoint so the midpoint lands between the pair.
                        if i2 == 0.0:
                            i2 = 1.0
                    except ValueError:
                        i2 = 1.0
                    # Strip the first frame's "_<float>" suffix, then append
                    # the midpoint index as the new frame's time stamp.
                    # NOTE(review): if fp contains no '_', fp[:fpos] is
                    # empty and the file is written relative to cwd —
                    # confirm inputs always carry an index suffix.
                    fpos = max(0, fp.rfind('_'))
                    fInd = (i1 + i2) / 2
                    savepath = "%s_%06f.%s" % (fp[:fpos], fInd, args.img_fmt)
                    utils.save_image(out[b], savepath)

            # Print progress
            print('im_processed: {:d}/{:d} {:.3f}s \r'.format(i + 1, len(test_loader), time.time() - t))

    return


""" Entry Point """
def main(args):

num_iter = 2 # x2**num_iter interpolation
for _ in range(num_iter):

# run test
test(args, args.start_epoch)


if __name__ == "__main__":
main(args)
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@

# Learning Rate Scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=3, verbose=True)
optimizer, mode='min', factor=0.5, patience=5, verbose=True)


# Initialize LPIPS model if used for evaluation
Expand Down Expand Up @@ -137,6 +137,7 @@ def train(args, epoch):


def test(args, epoch, eval_alpha=0.5):
print('Evaluating for epoch = %d' % epoch)
losses, psnrs, ssims, lpips = utils.init_meters(args.loss)
model.eval()
criterion.eval()
Expand Down
4 changes: 2 additions & 2 deletions model/cain_encdec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def forward(self, x1, x2):


class Decoder(nn.Module):
def __init__(self, in_channels=256, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
super(Decoder, self).__init__()
self.device = torch.device('cuda')

Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'):
self.depth = depth

self.encoder = Encoder(in_channels=3, depth=depth, norm=False)
self.decoder = Decoder(in_channels=256, depth=depth, norm=False, up_mode=up_mode)
self.decoder = Decoder(in_channels=start_filts*6, depth=depth, norm=False, up_mode=up_mode)

def forward(self, x1, x2):
x1, m1 = sub_mean(x1)
Expand Down
14 changes: 14 additions & 0 deletions test_custom.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
# Generate in-between frames for a custom frame sequence using a
# pretrained CAIN model (driven by generate.py).
#
#   --dataset custom   selects data/video.py's loader, which pairs the
#                      consecutive images found under --data_root
#   --resume           makes generate.py load pretrained_cain.pth from
#                      the working directory
#   --mode test        writes interpolated frames back into --data_root,
#                      named with midpoint float indices

CUDA_VISIBLE_DEVICES=0 python generate.py \
    --exp_name CAIN_fin \
    --dataset custom \
    --data_root data/frame_seq \
    --img_fmt png \
    --batch_size 32 \
    --test_batch_size 16 \
    --model cain \
    --depth 3 \
    --loss 1*L1 \
    --resume \
    --mode test
13 changes: 10 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Training Helper Functions for making main.py clean
##########################

def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers):
def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, img_fmt='png'):

if dataset_str == 'snufilm':
from data.snufilm import get_loader
Expand All @@ -39,6 +39,10 @@ def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_worker
from data.vimeo90k import get_loader
elif dataset_str == 'aim':
from data.aim import get_loader
elif dataset_str == 'custom':
from data.video import get_loader
test_loader = get_loader('test', data_root, test_batch_size, img_fmt=img_fmt, shuffle=False, num_workers=num_workers, n_frames=1)
return test_loader
else:
raise NotImplementedError('Training / Testing for this dataset is not implemented.')

Expand Down Expand Up @@ -71,8 +75,11 @@ def build_input(images, meta, is_training=True, include_edge=False, device=torch
def load_checkpoint(args, model, optimizer, fix_loaded=False):
if args.resume_exp is None:
args.resume_exp = args.exp_name
#load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
if args.mode == 'test':
load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
else:
#load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
print("loading checkpoint %s" % load_name)
checkpoint = torch.load(load_name)
args.start_epoch = checkpoint['epoch'] + 1
Expand Down

0 comments on commit 56d2122

Please sign in to comment.