Skip to content

Commit

Permalink
Resolve merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
myungsub committed Mar 20, 2020
2 parents 136f2fd + 89e4c13 commit e325846
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 31 deletions.
32 changes: 29 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Channel Attention Is All You Need for Video Frame Interpolation

[Project](https://myungsub.github.io/CAIN) | [Paper](https://aaai.org/Papers/AAAI/2020GB/AAAI-ChoiM.4773.pdf)
#### Myungsub Choi, Heewon Kim, Bohyung Han, Ning Xu, Kyoung Mu Lee

#### 2nd place in [[AIM 2019 ICCV Workshop](http://www.vision.ee.ethz.ch/aim19/)] - Video Temporal Super-Resolution Challenge

[Project](https://myungsub.github.io/CAIN) | [Paper](https://aaai.org/Papers/AAAI/2020GB/AAAI-ChoiM.4773.pdf) | [FactSheet]()

<center><img src="./figures/overall_architecture.png" width="90%"></center>

Expand All @@ -12,7 +16,11 @@ project
| run.sh - main script to train CAIN model
| run_noca.sh - script to train CAIN_NoCA model
| test_custom.sh - script to run interpolation on custom dataset
| eval.sh - script to evaluate on SNU-FILM benchmark
| main.py - main file to run train/val
| config.py - check & change training/testing configurations here
| loss.py - defines different loss functions
| utils.py - misc.
└───model
│ │ common.py
│ │ cain.py - main model
Expand Down Expand Up @@ -44,17 +52,35 @@ conda activate cain
conda install python=3.7
conda install pip numpy
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
conda install tqdm opencv
conda install tqdm opencv tensorboard
```

## Dataset Preparation

- We use **[Vimeo90K Triplet dataset](http://toflow.csail.mit.edu/)** for training + testing.
- After downloading the full dataset, make symbolic links in `data/` folder :
- `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
- Then you're done!
- For more thorough evaluation, we built **[SNU-FILM (SNU Frame Interpolation with Large Motion)](https://myungsub.github.io/CAIN)** benchmark.
- Download links can be found in the [project page](https://myungsub.github.io/CAIN).
- Also make symbolic links after download :
- `ln -s /path/to/SNU-FILM_data/ ./data/SNU-FILM`
- Done!

## Usage

#### Training / Testing with Vimeo90K dataset
- First make symbolic links in `data/` folder : `ln -s /path/to/vimeo_triplet_data/ ./data/vimeo_triplet`
- [Vimeo90K dataset](http://toflow.csail.mit.edu/)
- For training: `CUDA_VISIBLE_DEVICES=0 python main.py --exp_name EXPNAME --batch_size 16 --test_batch_size 16 --dataset vimeo90k --model cain --loss 1*L1 --max_epoch 200 --lr 0.0002`
- Or, just run `./run.sh`
- For testing performance on Vimeo90K dataset, just add `--mode test` option
- For testing on SNU-FILM dataset, run `./eval.sh`
- Testing mode (choose from ['easy', 'medium', 'hard', 'extreme']) can be modified by changing `--test_mode` option in `eval.sh`.

#### Interpolating with custom video
- Download pretrained models from [Here](https://drive.google.com/open?id=1BHy1gkejHxy-7vCwKczTb4Jviks8KOd3)
Expand Down Expand Up @@ -85,4 +111,4 @@ Many parts of this code is adapted from:
- [EDSR-Pytorch](https://github.com/thstkdgus35/EDSR-PyTorch)
- [RCAN](https://github.com/yulunzhang/RCAN)

We thank the authors for sharing codes for their great works.
We thank the authors for sharing codes for their great works.
1 change: 1 addition & 0 deletions _config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
theme: jekyll-theme-cayman
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def add_argument_group(name):
learn_arg.add_argument('--batch_size', type=int, default=16)
learn_arg.add_argument('--val_batch_size', type=int, default=4)
learn_arg.add_argument('--test_batch_size', type=int, default=1)
learn_arg.add_argument('--test_mode', type=str, default='hard', help='Test mode to evaluate on SNU-FILM dataset')
learn_arg.add_argument('--start_epoch', type=int, default=0)
learn_arg.add_argument('--max_epoch', type=int, default=200)
learn_arg.add_argument('--resume', action='store_true')
Expand Down
55 changes: 55 additions & 0 deletions data/snufilm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class SNUFILM(Dataset):
    '''
    Test-only dataset over the frame lists shipped with SNU-FILM.

    The list file ``<data_root>/test-<mode>.txt`` is read once at construction
    time; every non-blank line is split on whitespace into image paths, and
    ``__getitem__`` loads the first three paths of the selected line.
    '''

    def __init__(self, data_root, mode='hard'):
        '''
        :param data_root: ./data/SNU-FILM
        :param mode: ['easy', 'medium', 'hard', 'extreme']
        '''
        test_fn = os.path.join(data_root, 'test-%s.txt' % mode)
        self.frame_list = self._read_frame_list(test_fn)

        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

        print("[%s] Test dataset has %d triplets" % (mode, len(self.frame_list)))

    @staticmethod
    def _read_frame_list(list_path):
        '''
        Parse a frame-list file into a list of path lists.

        Blank (or whitespace-only) lines are skipped so that a stray empty
        line does not become a bogus entry that is counted by ``__len__`` and
        then fails in ``__getitem__``. Fields are split on any run of
        whitespace, so repeated or trailing spaces do not create empty paths.

        :param list_path: path of the ``test-<mode>.txt`` file
        :return: list with one ``[path, path, path, ...]`` entry per line
        '''
        with open(list_path, 'r') as f:
            return [line.split() for line in f.read().splitlines() if line.strip()]

    def __getitem__(self, index):
        '''
        Load the three frames listed on line ``index`` of the frame list.

        :param index: position in ``self.frame_list``
        :return: ``(imgs, imgpaths)`` where ``imgs`` holds the three images
                 after ``self.transforms`` and ``imgpaths`` is the path list
                 for that entry
        :raises IndexError: if the entry lists fewer than three paths
        '''
        imgpaths = self.frame_list[index]

        # Explicit indices 0..2 (rather than a slice) keep the failure loud
        # when a line of the list file is malformed.
        imgs = [self.transforms(Image.open(imgpaths[k])) for k in range(3)]

        return imgs, imgpaths

    def __len__(self):
        '''Number of entries read from the frame-list file.'''
        return len(self.frame_list)


def check_already_extracted(vid):
    '''
    Tell whether the file ``0001.png`` exists directly under ``vid``.

    :param vid: directory path as a string (it is concatenated with '/0001.png')
    :return: True if that path exists, False otherwise
    '''
    first_frame = vid + '/0001.png'
    return os.path.exists(first_frame)


def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode='hard'):
    '''
    Build a DataLoader over the SNU-FILM test set.

    :param mode: accepted but not used by this function
    :param data_root: dataset root handed to ``SNUFILM`` (e.g. ./data/SNU-FILM)
    :param batch_size: forwarded to ``DataLoader``
    :param shuffle: forwarded to ``DataLoader``
    :param num_workers: forwarded to ``DataLoader``
    :param test_mode: passed to ``SNUFILM`` as its ``mode`` argument
    :return: ``DataLoader`` created with ``pin_memory=True``
    '''
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(SNUFILM(data_root, mode=test_mode), **loader_options)
10 changes: 4 additions & 6 deletions data/vimeo90k.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __getitem__(self, index):
img3 = T(img3)

imgs = [img1, img2, img3]
meta = {'imgpath': imgpaths}
return imgs, meta

return imgs, imgpaths

def __len__(self):
if self.training:
Expand All @@ -71,12 +71,10 @@ def __len__(self):
return 0


def get_loader(mode, data_root, batch_size, shuffle, num_workers, n_frames=1):
def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
if mode == 'train':
is_training = True
else:
is_training = False
dataset = VimeoTriplet(data_root, is_training=is_training)
#dataset = VimeoFineTune('data/vimeo_triplet', n_int_frames, is_training=is_training)
meta = [dataset.trainlist, dataset.testlist]
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True), meta
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
13 changes: 13 additions & 0 deletions eval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash

# Run main.py in test mode on the SNU-FILM dataset.
#   --resume / --resume_exp CAIN_train : CAIN_train is the --exp_name used by run.sh
#   --test_mode                        : one of easy / medium / hard / extreme
#   --exp_name CAIN_eval               : main.py saves results under checkpoint/<exp_name>/
# NOTE(review): GPU index 1 is hard-coded here, while run.sh and the README
# examples use CUDA_VISIBLE_DEVICES=0 - confirm this is intended.
CUDA_VISIBLE_DEVICES=1 python main.py \
--exp_name CAIN_eval \
--dataset snufilm \
--data_root data/SNU-FILM \
--test_batch_size 1 \
--model cain \
--depth 3 \
--mode test \
--resume \
--resume_exp CAIN_train \
--test_mode hard
Binary file modified figures/overall_architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 12 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

##### Load Dataset #####
train_loader, test_loader = utils.load_dataset(
args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers)
args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, args.test_mode)


##### Build Model #####
Expand Down Expand Up @@ -94,10 +94,10 @@ def train(args, epoch):
criterion.train()

t = time.time()
for i, (images, meta) in enumerate(train_loader):
for i, (images, imgpaths) in enumerate(train_loader):

# Build input batch
im1, im2, gt = utils.build_input(images, meta)
im1, im2, gt = utils.build_input(images, imgpaths)

# Forward
optimizer.zero_grad()
Expand Down Expand Up @@ -143,6 +143,10 @@ def test(args, epoch, eval_alpha=0.5):
criterion.eval()

save_folder = 'test%03d' % epoch
if args.dataset == 'snufilm':
save_folder = os.path.join(save_folder, args.dataset, args.test_mode)
else:
save_folder = os.path.join(save_folder, args.dataset)
save_dir = os.path.join('checkpoint', args.exp_name, save_folder)
utils.makedirs(save_dir)
save_fn = os.path.join(save_dir, 'results.txt')
Expand All @@ -152,10 +156,10 @@ def test(args, epoch, eval_alpha=0.5):

t = time.time()
with torch.no_grad():
for i, (images, meta) in enumerate(tqdm(test_loader)):
for i, (images, imgpaths) in enumerate(tqdm(test_loader)):

# Build input batch
im1, im2, gt = utils.build_input(images, meta, is_training=False)
im1, im2, gt = utils.build_input(images, imgpaths, is_training=False)

# Forward
out, feats = model(im1, im2)
Expand All @@ -172,16 +176,17 @@ def test(args, epoch, eval_alpha=0.5):

# Log examples that have bad performance
if (ssims.val < 0.9 or psnrs.val < 25) and epoch > 50:
print(imgpaths)
print("\nLoss: %f, PSNR: %f, SSIM: %f, LPIPS: %f" %
(losses['total'].val, psnrs.val, ssims.val, lpips.val))
print(meta['imgpath'][1][-1])
print(imgpaths[1][-1])

# Save result images
if ((epoch + 1) % 1 == 0 and i < 20) or args.mode == 'test':
savepath = os.path.join('checkpoint', args.exp_name, save_folder)

for b in range(images[0].size(0)):
paths = meta['imgpath'][1][b].split('/')
paths = imgpaths[1][b].split('/')
fp = os.path.join(savepath, paths[-3], paths[-2])
if not os.path.exists(fp):
os.makedirs(fp)
Expand Down
7 changes: 2 additions & 5 deletions run.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python main.py \
--exp_name CAIN_test \
--exp_name CAIN_train \
--dataset vimeo90k \
--batch_size 16 \
--test_batch_size 16 \
Expand All @@ -11,7 +11,4 @@ CUDA_VISIBLE_DEVICES=0 python main.py \
--max_epoch 200 \
--lr 0.0002 \
--log_iter 100 \
# --mode test
# --resume True \
# --resume_exp SH_5_12
# --fix_encoder
# --mode test
19 changes: 9 additions & 10 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
# Training Helper Functions for making main.py clean
##########################

def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, img_fmt='png'):
def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_workers, test_mode='medium', img_fmt='png'):

if dataset_str == 'snufilm':
from data.snufilm import get_loader
from data.vimeo90k import get_loader as get_test_loader
test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, test_mode=test_mode)
return None, test_loader
elif dataset_str == 'vimeo90k':
from data.vimeo90k import get_loader
elif dataset_str == 'aim':
Expand All @@ -46,24 +47,22 @@ def load_dataset(dataset_str, data_root, batch_size, test_batch_size, num_worker
else:
raise NotImplementedError('Training / Testing for this dataset is not implemented.')

train_loader, _ = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers, n_frames=2)
if dataset_str == 'snufilm':
test_loader, _ = get_test_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
elif dataset_str == 'aim':
test_loader, _ = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
train_loader = get_loader('train', data_root, batch_size, shuffle=True, num_workers=num_workers)
if dataset_str == 'aim':
test_loader = get_loader('val', data_root, test_batch_size, shuffle=False, num_workers=num_workers)
else:
test_loader, _ = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers, n_frames=1)
test_loader = get_loader('test', data_root, test_batch_size, shuffle=False, num_workers=num_workers)

return train_loader, test_loader


def build_input(images, meta, is_training=True, include_edge=False, device=torch.device('cuda')):
def build_input(images, imgpaths, is_training=True, include_edge=False, device=torch.device('cuda')):
if isinstance(images[0], list):
images_gathered = [None, None, None]
for j in range(len(images[0])): # 3
_images = [images[k][j] for k in range(len(images))]
images_gathered[j] = torch.cat(_images, 0)
meta['imgpath'] = [p for _ in images for p in meta['imgpath']]
imgpaths = [p for _ in images for p in imgpaths]
images = images_gathered

im1, im2 = images[0].to(device), images[2].to(device)
Expand Down

0 comments on commit e325846

Please sign in to comment.