diff --git a/README.md b/README.md
index 3d41e7a..fe85cde 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,10 @@
 # Channel Attention Is All You Need for Video Frame Interpolation
 
-[Project](https://myungsub.github.io/CAIN) | [Paper]()
+#### Myungsub Choi, Heewon Kim, Bohyung Han, Ning Xu, Kyoung Mu Lee
+
+#### 2nd place in [[AIM 2019 ICCV Workshop](http://www.vision.ee.ethz.ch/aim19/)] - Video Temporal Super-Resolution Challenge
+
+[Project](https://myungsub.github.io/CAIN) | [Paper]() | [FactSheet]()
@@ -11,7 +15,11 @@ project
 │   README.md
 |   run.sh - main script to train CAIN model
 |   run_noca.sh - script to train CAIN_NoCA model
+|   eval.sh - script to evaluate on SNU-FILM benchmark
 |   main.py - main file to run train/val
+|   config.py - check & change training/testing configurations here
+|   loss.py - defines different loss functions
+|   utils.py - misc.
 └───model
 │   │   common.py
 │   │   cain.py - main model
@@ -42,16 +50,28 @@ conda activate cain
 conda install python=3.7
 conda install pip numpy
 conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
-conda install tqdm opencv
+conda install tqdm opencv tensorboard
 ```
 
+## Dataset Preparation
+
+- We use **[Vimeo90K Triplet dataset](http://toflow.csail.mit.edu/)** for training + testing.
+  - After downloading the full dataset, make symbolic links in `data/` folder :
+    - `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
+  - Then you're done!
+- For more thorough evaluation, we built **[SNU-FILM (SNU Frame Interpolation with Large Motion)](https://myungsub.github.io/CAIN)** benchmark.
+  - Download links can be found in the [project page](https://myungsub.github.io/CAIN).
+  - Also make symbolic links after download :
+    - `ln -s /path/to/SNU-FILM_data/ ./data/SNU-FILM`
+  - Done!
+
 ## Usage
 
-- First make symbolic links in `data/` folder : `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
-  - [Vimeo90K dataset]()
-- For training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
+- Training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
   - Or, just run `./run.sh`
 - For testing performance on Vimeo90K dataset, just add `--mode test` option
+- For testing on SNU-FILM dataset, run `./eval.sh`
+  - Testing mode (choose from ['easy', 'medium', 'hard', 'extreme']) can be modified by changing `--test_mode` option in `eval.sh`.
 
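The snippet below is a minimal sanity check (not part of the patch) for the `data/` symlinks described above. The `test-<mode>.txt` naming and the space-separated triplet format follow `data/snufilm.py` introduced later in this diff; treating the middle frame as ground truth mirrors how `utils.build_input` and `main.py` use the triplet.

```python
# Hedged sanity check for the dataset symlinks above; run from the repository root.
import os

for link in ['data/vimeo_triplet', 'data/SNU-FILM']:
    print(link, '->', 'found' if os.path.isdir(link) else 'MISSING')

# data/snufilm.py reads data/SNU-FILM/test-<mode>.txt: one space-separated triplet of
# frame paths per line, with the middle frame used as the interpolation ground truth.
with open('data/SNU-FILM/test-hard.txt') as f:
    first_triplet = f.readline().strip().split(' ')
print(first_triplet)
```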
 ## Results
 
diff --git a/config.py b/config.py
index f283915..b569d9e 100644
--- a/config.py
+++ b/config.py
@@ -35,6 +35,7 @@ def add_argument_group(name):
 learn_arg.add_argument('--batch_size', type=int, default=16)
 learn_arg.add_argument('--val_batch_size', type=int, default=4)
 learn_arg.add_argument('--test_batch_size', type=int, default=1)
+learn_arg.add_argument('--test_mode', type=str, default='hard', help='Test mode to evaluate on SNU-FILM dataset')
 learn_arg.add_argument('--start_epoch', type=int, default=0)
 learn_arg.add_argument('--max_epoch', type=int, default=200)
 learn_arg.add_argument('--resume', action='store_true')
diff --git a/data/snufilm.py b/data/snufilm.py
new file mode 100644
index 0000000..ad2b0e4
--- /dev/null
+++ b/data/snufilm.py
@@ -0,0 +1,55 @@
+import os
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+
+class SNUFILM(Dataset):
+    def __init__(self, data_root, mode='hard'):
+        '''
+        :param data_root:   ./data/SNU-FILM
+        :param mode:        ['easy', 'medium', 'hard', 'extreme']
+        '''
+        test_root = os.path.join(data_root, 'test')
+        test_fn = os.path.join(data_root, 'test-%s.txt' % mode)
+        with open(test_fn, 'r') as f:
+            self.frame_list = f.read().splitlines()
+        self.frame_list = [v.split(' ') for v in self.frame_list]
+
+        self.transforms = transforms.Compose([
+            transforms.ToTensor()
+        ])
+
+        print("[%s] Test dataset has %d triplets" % (mode, len(self.frame_list)))
+
+
+    def __getitem__(self, index):
+
+        # Use self.test_all_images:
+        imgpaths = self.frame_list[index]
+
+        img1 = Image.open(imgpaths[0])
+        img2 = Image.open(imgpaths[1])
+        img3 = Image.open(imgpaths[2])
+
+        img1 = self.transforms(img1)
+        img2 = self.transforms(img2)
+        img3 = self.transforms(img3)
+
+        imgs = [img1, img2, img3]
+
+        return imgs, imgpaths
+
+    def __len__(self):
+        return len(self.frame_list)
+
+
+def check_already_extracted(vid):
+    return bool(os.path.exists(vid + '/0001.png'))
+
+
+def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode='hard'):
+    # data_root = 'data/SNUFILM'
+    dataset = SNUFILM(data_root, mode=test_mode)
+    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
diff --git a/data/vimeo90k.py b/data/vimeo90k.py
index d4afcd3..27bd4bd 100644
--- a/data/vimeo90k.py
+++ b/data/vimeo90k.py
@@ -60,8 +60,8 @@ def __getitem__(self, index):
         img3 = T(img3)
 
         imgs = [img1, img2, img3]
-        meta = {'imgpath': imgpaths}
-        return imgs, meta
+
+        return imgs, imgpaths
 
     def __len__(self):
         if self.training:
@@ -71,12 +71,10 @@ def __len__(self):
         return 0
 
 
-def get_loader(mode, data_root, batch_size, shuffle, num_workers, n_frames=1):
+def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
     if mode == 'train':
         is_training = True
     else:
         is_training = False
     dataset = VimeoTriplet(data_root, is_training=is_training)
-    #dataset = VimeoFineTune('data/vimeo_triplet', n_int_frames, is_training=is_training)
-    meta = [dataset.trainlist, dataset.testlist]
-    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True), meta
+    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
diff --git a/eval.sh b/eval.sh
new file mode 100755
index 0000000..f893217
--- /dev/null
+++ b/eval.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=1 python main.py \
+    --exp_name CAIN_eval \
+    --dataset snufilm \
+    --data_root data/SNU-FILM \
+    --test_batch_size 1 \
+    --model cain \
+    --depth 3 \
+    --mode test \
+    --resume \
+    --resume_exp CAIN_train \
+    --test_mode hard
\ No newline at end of file
diff --git a/figures/overall_architecture.png b/figures/overall_architecture.png
index 53f2c16..c7a40e5 100644
Binary files a/figures/overall_architecture.png and b/figures/overall_architecture.png differ
diff --git a/main.py b/main.py
index 35ad12e..2bd48af 100644
--- a/main.py
+++ b/main.py
@@ -36,7 +36,7 @@
 
 ##### Load Dataset #####
 train_loader, test_loader = utils.load_dataset(
-    args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers)
+    args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, args.test_mode)
 
 
 ##### Build Model #####
@@ -94,10 +94,10 @@ def train(args, epoch):
     criterion.train()
 
     t = time.time()
-    for i, (images, meta) in enumerate(train_loader):
+    for i, (images, imgpaths) in enumerate(train_loader):
 
         # Build input batch
-        im1, im2, gt = utils.build_input(images, meta)
+        im1, im2, gt = utils.build_input(images, imgpaths)
 
         # Forward
         optimizer.zero_grad()
@@ -142,6 +142,10 @@ def test(args, epoch, eval_alpha=0.5):
     criterion.eval()
 
     save_folder = 'test%03d' % epoch
+    if args.dataset == 'snufilm':
+        save_folder = os.path.join(save_folder, args.dataset, args.test_mode)
+    else:
+        save_folder = os.path.join(save_folder, args.dataset)
     save_dir = os.path.join('checkpoint', args.exp_name, save_folder)
     utils.makedirs(save_dir)
     save_fn = os.path.join(save_dir, 'results.txt')
@@ -151,10 +155,10 @@ def test(args, epoch, eval_alpha=0.5):
 
     t = time.time()
     with torch.no_grad():
-        for i, (images, meta) in enumerate(tqdm(test_loader)):
+        for i, (images, imgpaths) in enumerate(tqdm(test_loader)):
 
             # Build input batch
-            im1, im2, gt = utils.build_input(images, meta, is_training=False)
+            im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)
 
             # Forward
             out, feats = model(im1, im2)
@@ -171,16 +175,17 @@
 
             # Log examples that have bad performance
             if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50:
+                print(imgpaths)
                 print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" %
                       (losses['total'].val, psnrs.val, ssims.val, lpips.val))
-                print(meta['imgpath'][1][-1])
+                print(imgpaths[1][-1])
 
             # Save result images
             if ((epoch + 1) % 1 == 0 and i < 20) or args.mode == 'test':
                 savepath = os.path.join('checkpoint', args.exp_name, save_folder)
 
                 for b in range(images[0].size(0)):
-                    paths = meta['imgpath'][1][b].split('/')
+                    paths = imgpaths[1][b].split('/')
                     fp = os.path.join(savepath, paths[-3], paths[-2])
                     if not os.path.exists(fp):
                         os.makedirs(fp)
diff --git a/run.sh b/run.sh
index a12411d..db30e3b 100755
--- a/run.sh
+++ b/run.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 CUDA_VISIBLE_DEVICES=0 python main.py \
-    --exp_name CAIN_test \
+    --exp_name CAIN_train \
     --dataset vimeo90k \
     --batch_size 16 \
     --test_batch_size 16 \
@@ -11,7 +11,4 @@ CUDA_VISIBLE_DEVICES=0 python main.py \
     --max_epoch 200 \
     --lr 0.0002 \
     --log_iter 100 \
-#    --mode test
-#    --resume True \
-#    --resume_exp SH_5_12
-#    --fix_encoder
\ No newline at end of file
+#    --mode test
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 63260df..b7b5b10 100644
--- a/utils.py
+++ b/utils.py
@@ -30,11 +30,12 @@
 # Training Helper Functions for making main.py clean
 ##########################
 
-def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers):
+def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, test_mode):
     if dataset_str == 'snufilm':
         from data.snufilm import get_loader
-        from data.vimeo90k import get_loader as get_test_loader
+        test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, test_mode=test_mode)
+        return None, test_loader
     elif dataset_str == 'vimeo90k':
         from data.vimeo90k import get_loader
     elif dataset_str == 'aim':
         from data.aim import get_loader
@@ -42,24 +43,22 @@ def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_worker
     else:
         raise NotImplementedError('Training / Testing for this dataset is not implemented.')
 
-    train_loader, _ = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers, n_frames=2)
-    if dataset_str == 'snufilm':
-        test_loader, _ = get_test_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
-    elif dataset_str == 'aim':
-        test_loader, _ = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
+    train_loader = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers)
+    if dataset_str == 'aim':
+        test_loader = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
     else:
-        test_loader, _ = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
+        test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
     return train_loader, test_loader
 
 
-def build_input(images, meta, is_training=True, include_edge=False, device=torch.device('cuda')):
+def build_input(images, imgpaths, is_training=True, include_edge=False, device=torch.device('cuda')):
     if isinstance(images[0], list):
         images_gathered = [None, None, None]
         for j in range(len(images[0])):  # 3
             _images = [images[k][j] for k in range(len(images))]
             images_gathered[j] = torch.cat(_images, 0)
-        meta['imgpath'] = [p for _ in images for p in meta['imgpath']]
+        imgpaths = [p for _ in images for p in imgpaths]
         images = images_gathered
 
     im1, im2 = images[0].to(device), images[2].to(device)
@@ -71,8 +70,8 @@
 def load_checkpoint(args, model, optimizer, fix_loaded=False):
     if args.resume_exp is None:
         args.resume_exp = args.exp_name
-    #load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
-    load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
+    load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
+    #load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
     print("loading checkpoint %s" % load_name)
     checkpoint = torch.load(load_name)
     args.start_epoch = checkpoint['epoch'] + 1
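Taken together, these changes route SNU-FILM evaluation through the same `main.py` entry point. Below is a minimal, hedged sketch (not part of the patch) of how the refactored helpers might be exercised directly as a smoke test; it assumes the repository root as working directory, the `data/SNU-FILM` symlink from the README, and a CUDA device, since `build_input` moves tensors to `cuda` by default.

```python
# Hedged smoke test for the new SNU-FILM evaluation path.
import torch
import utils

# For 'snufilm', the refactored load_dataset returns (None, test_loader): no training split.
train_loader, test_loader = utils.load_dataset(
    'snufilm', 'data/SNU-FILM', batch_size=1, test_batch_size=1,
    num_workers=1, test_mode='hard')
assert train_loader is None

with torch.no_grad():
    for images, imgpaths in test_loader:
        # build_input takes the outer frames as inputs and (as in main.py) the middle
        # frame as ground truth, moving all three tensors to the GPU.
        im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)
        print(imgpaths[1][0], im1.shape, gt.shape)
        break  # one triplet is enough for a smoke test
```

In normal use this path is driven by `eval.sh`, which forwards the same options (`--dataset snufilm`, `--test_mode`, etc.) through `main.py`.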