diff --git a/README.md b/README.md
index 3d41e7a..fe85cde 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,10 @@
# Channel Attention Is All You Need for Video Frame Interpolation
-[Project](https://myungsub.github.io/CAIN) | [Paper]()
+#### Myungsub Choi, Heewon Kim, Bohyung Han, Ning Xu, Kyoung Mu Lee
+
+#### 2nd place in [[AIM 2019 ICCV Workshop](http://www.vision.ee.ethz.ch/aim19/)] - Video Temporal Super-Resolution Challenge
+
+[Project](https://myungsub.github.io/CAIN) | [Paper]() | [FactSheet]()
@@ -11,7 +15,11 @@ project
│ README.md
| run.sh - main script to train CAIN model
| run_noca.sh - script to train CAIN_NoCA model
+| eval.sh - script to evaluate on SNU-FILM benchmark
| main.py - main file to run train/val
+| config.py - check & change training/testing configurations here
+| loss.py - defines different loss functions
+| utils.py - misc. utility functions (dataset loading, input batching, checkpointing)
└───model
│ │ common.py
│ │ cain.py - main model
@@ -42,16 +50,28 @@ conda activate cain
conda install python=3.7
conda install pip numpy
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
-conda install tqdm opencv
+conda install tqdm opencv tensorboard
```
+## Dataset Preparation
+
+- We use the **[Vimeo90K Triplet dataset](http://toflow.csail.mit.edu/)** for both training and testing.
+  - After downloading the full dataset, make a symbolic link in the `data/` folder:
+ - `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
+ - Then you're done!
+- For a more thorough evaluation, we built the **[SNU-FILM (SNU Frame Interpolation with Large Motion)](https://myungsub.github.io/CAIN)** benchmark.
+  - Download links can be found on the [project page](https://myungsub.github.io/CAIN).
+  - After downloading, make a symbolic link for it as well:
+ - `ln -s /path/to/SNU-FILM_data/ ./data/SNU-FILM`
+ - Done!
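+- After the links are in place, the `data/` folder should look roughly like this (a sketch inferred from `data/snufilm.py`, which reads `test-[easy|medium|hard|extreme].txt` from the linked SNU-FILM root):
+
+  ```
+  data
+  ├── vimeo_triplet/   -> /path/to/vimeo_triplet_data/
+  └── SNU-FILM/        -> /path/to/SNU-FILM_data/
+      ├── test/        (frame folders referenced by the list files)
+      ├── test-easy.txt
+      ├── test-medium.txt
+      ├── test-hard.txt
+      └── test-extreme.txt
+  ```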
+
## Usage
-- First make symbolic links in `data/` folder : `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
- - [Vimeo90K dataset]()
-- For training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
+- Training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
- Or, just run `./run.sh`
- For testing performance on Vimeo90K dataset, just add `--mode test` option
+- For testing on the SNU-FILM dataset, run `./eval.sh`
+  - The test split (one of `easy`, `medium`, `hard`, `extreme`) can be selected via the `--test_mode` option in `eval.sh`.
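+- The SNU-FILM loader can also be called directly from Python. A minimal sketch, assuming the `data/SNU-FILM` symlink above and that each triplet is ordered `[first frame, ground-truth middle frame, last frame]` as in the Vimeo90K loader:
+
+  ```python
+  from data.snufilm import get_loader
+
+  # test_mode selects the SNU-FILM split: 'easy', 'medium', 'hard', or 'extreme'
+  test_loader = get_loader('test', 'data/SNU-FILM', batch_size=1,
+                           shuffle=False, num_workers=2, test_mode='hard')
+  for images, imgpaths in test_loader:
+      im1, gt, im2 = images  # three [B, C, H, W] tensors
+      # run a model on (im1, im2) and compare its output against gt here
+  ```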
## Results
diff --git a/config.py b/config.py
index f283915..b569d9e 100644
--- a/config.py
+++ b/config.py
@@ -35,6 +35,7 @@ def add_argument_group(name):
learn_arg.add_argument('--batch_size', type=int, default=16)
learn_arg.add_argument('--val_batch_size', type=int, default=4)
learn_arg.add_argument('--test_batch_size', type=int, default=1)
+learn_arg.add_argument('--test_mode', type=str, default='hard', help='SNU-FILM test split to evaluate on (easy / medium / hard / extreme)')
learn_arg.add_argument('--start_epoch', type=int, default=0)
learn_arg.add_argument('--max_epoch', type=int, default=200)
learn_arg.add_argument('--resume', action='store_true')
diff --git a/data/snufilm.py b/data/snufilm.py
new file mode 100644
index 0000000..ad2b0e4
--- /dev/null
+++ b/data/snufilm.py
@@ -0,0 +1,55 @@
+import os
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+
+class SNUFILM(Dataset):
+    def __init__(self, data_root, mode='hard'):
+        '''
+        :param data_root: ./data/SNU-FILM
+        :param mode: ['easy', 'medium', 'hard', 'extreme']
+        '''
+        test_root = os.path.join(data_root, 'test')
+        test_fn = os.path.join(data_root, 'test-%s.txt' % mode)
+        with open(test_fn, 'r') as f:
+            self.frame_list = f.read().splitlines()
+        self.frame_list = [v.split(' ') for v in self.frame_list]
+
+        self.transforms = transforms.Compose([
+            transforms.ToTensor()
+        ])
+
+        print("[%s] Test dataset has %d triplets" % (mode, len(self.frame_list)))
+
+
+    def __getitem__(self, index):
+
+        # Each entry of frame_list holds the three frame paths of one triplet
+        imgpaths = self.frame_list[index]
+
+        img1 = Image.open(imgpaths[0])
+        img2 = Image.open(imgpaths[1])
+        img3 = Image.open(imgpaths[2])
+
+        img1 = self.transforms(img1)
+        img2 = self.transforms(img2)
+        img3 = self.transforms(img3)
+
+        imgs = [img1, img2, img3]
+
+        return imgs, imgpaths
+
+    def __len__(self):
+        return len(self.frame_list)
+
+
+def check_already_extracted(vid):
+    return bool(os.path.exists(vid + '/0001.png'))
+
+
+def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode='hard'):
+    # data_root: e.g. 'data/SNU-FILM'
+    dataset = SNUFILM(data_root, mode=test_mode)
+    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
diff --git a/data/vimeo90k.py b/data/vimeo90k.py
index d4afcd3..27bd4bd 100644
--- a/data/vimeo90k.py
+++ b/data/vimeo90k.py
@@ -60,8 +60,8 @@ def __getitem__(self, index):
img3 = T(img3)
imgs = [img1, img2, img3]
- meta = {'imgpath': imgpaths}
- return imgs, meta
+
+ return imgs, imgpaths
def __len__(self):
if self.training:
@@ -71,12 +71,10 @@ def __len__(self):
return 0
-def get_loader(mode, data_root, batch_size, shuffle, num_workers, n_frames=1):
+def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
if mode == 'train':
is_training = True
else:
is_training = False
dataset = VimeoTriplet(data_root, is_training=is_training)
- #dataset = VimeoFineTune('data/vimeo_triplet', n_int_frames, is_training=is_training)
- meta = [dataset.trainlist, dataset.testlist]
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True), meta
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
diff --git a/eval.sh b/eval.sh
new file mode 100755
index 0000000..f893217
--- /dev/null
+++ b/eval.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=1 python main.py \
+ --exp_name CAIN_eval \
+ --dataset snufilm \
+ --data_root data/SNU-FILM \
+ --test_batch_size 1 \
+ --model cain \
+ --depth 3 \
+ --mode test \
+ --resume \
+ --resume_exp CAIN_train \
+ --test_mode hard
\ No newline at end of file
diff --git a/figures/overall_architecture.png b/figures/overall_architecture.png
index 53f2c16..c7a40e5 100644
Binary files a/figures/overall_architecture.png and b/figures/overall_architecture.png differ
diff --git a/main.py b/main.py
index 35ad12e..2bd48af 100644
--- a/main.py
+++ b/main.py
@@ -36,7 +36,7 @@
##### Load Dataset #####
train_loader, test_loader = utils.load_dataset(
- args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers)
+ args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, args.test_mode)
##### Build Model #####
@@ -94,10 +94,10 @@ def train(args, epoch):
criterion.train()
t = time.time()
- for i, (images, meta) in enumerate(train_loader):
+ for i, (images, imgpaths) in enumerate(train_loader):
# Build input batch
- im1, im2, gt = utils.build_input(images, meta)
+ im1, im2, gt = utils.build_input(images, imgpaths)
# Forward
optimizer.zero_grad()
@@ -142,6 +142,10 @@ def test(args, epoch, eval_alpha=0.5):
criterion.eval()
save_folder = 'test%03d' % epoch
+ if args.dataset == 'snufilm':
+ save_folder = os.path.join(save_folder, args.dataset, args.test_mode)
+ else:
+ save_folder = os.path.join(save_folder, args.dataset)
save_dir = os.path.join('checkpoint', args.exp_name, save_folder)
utils.makedirs(save_dir)
save_fn = os.path.join(save_dir, 'results.txt')
@@ -151,10 +155,10 @@ def test(args, epoch, eval_alpha=0.5):
t = time.time()
with torch.no_grad():
- for i, (images, meta) in enumerate(tqdm(test_loader)):
+ for i, (images, imgpaths) in enumerate(tqdm(test_loader)):
# Build input batch
- im1, im2, gt = utils.build_input(images, meta, is_training=False)
+ im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)
# Forward
out, feats = model(im1, im2)
@@ -171,16 +175,17 @@ def test(args, epoch, eval_alpha=0.5):
# Log examples that have bad performance
if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50:
+ print(imgpaths)
print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" %
(losses['total'].val, psnrs.val, ssims.val, lpips.val))
- print(meta['imgpath'][1][-1])
+ print(imgpaths[1][-1])
# Save result images
if ((epoch + 1) % 1 == 0 and i < 20) or args.mode == 'test':
savepath = os.path.join('checkpoint', args.exp_name, save_folder)
for b in range(images[0].size(0)):
- paths = meta['imgpath'][1][b].split('/')
+ paths = imgpaths[1][b].split('/')
fp = os.path.join(savepath, paths[-3], paths[-2])
if not os.path.exists(fp):
os.makedirs(fp)
diff --git a/run.sh b/run.sh
index a12411d..db30e3b 100755
--- a/run.sh
+++ b/run.sh
@@ -1,7 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python main.py \
- --exp_name CAIN_test \
+ --exp_name CAIN_train \
--dataset vimeo90k \
--batch_size 16 \
--test_batch_size 16 \
@@ -11,7 +11,4 @@ CUDA_VISIBLE_DEVICES=0 python main.py \
--max_epoch 200 \
--lr 0.0002 \
--log_iter 100 \
-# --mode test
-# --resume True \
-# --resume_exp SH_5_12
-# --fix_encoder
\ No newline at end of file
+# --mode test
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 63260df..b7b5b10 100644
--- a/utils.py
+++ b/utils.py
@@ -30,11 +30,12 @@
# Training Helper Functions for making main.py clean
##########################
-def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers):
+def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, test_mode):
if dataset_str == 'snufilm':
from data.snufilm import get_loader
- from data.vimeo90k import get_loader as get_test_loader
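+ # SNU-FILM is a test-only benchmark, so only a test loader is returned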
+ test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, test_mode=test_mode)
+ return None, test_loader
elif dataset_str == 'vimeo90k':
from data.vimeo90k import get_loader
elif dataset_str == 'aim':
@@ -42,24 +43,22 @@ def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_worker
else:
raise NotImplementedError('Training / Testing for this dataset is not implemented.')
- train_loader, _ = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers, n_frames=2)
- if dataset_str == 'snufilm':
- test_loader, _ = get_test_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
- elif dataset_str == 'aim':
- test_loader, _ = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
+ train_loader = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers)
+ if dataset_str == 'aim':
+ test_loader = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
else:
- test_loader, _ = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
+ test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
return train_loader, test_loader
-def build_input(images, meta, is_training=True, include_edge=False, device=torch.device('cuda')):
+def build_input(images, imgpaths, is_training=True, include_edge=False, device=torch.device('cuda')):
if isinstance(images[0], list):
images_gathered = [None, None, None]
for j in range(len(images[0])): # 3
_images = [images[k][j] for k in range(len(images))]
images_gathered[j] = torch.cat(_images, 0)
- meta['imgpath'] = [p for _ in images for p in meta['imgpath']]
+ imgpaths = [p for _ in images for p in imgpaths]
images = images_gathered
im1, im2 = images[0].to(device), images[2].to(device)
@@ -71,8 +70,8 @@ def build_input(images, meta, is_training=True, include_edge=False, device=torch
def load_checkpoint(args, model, optimizer, fix_loaded=False):
if args.resume_exp is None:
args.resume_exp = args.exp_name
- #load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
- load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
+ load_name = os.path.join('checkpoint', args.resume_exp, 'model_best.pth')
+ #load_name = os.path.join('checkpoint', args.resume_exp, 'checkpoint.pth')
print("loading checkpoint %s" % load_name)
checkpoint = torch.load(load_name)
args.start_epoch = checkpoint['epoch'] + 1