import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm


def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
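
# Illustrative shape check (a sketch, not part of the original file):
# pixel_unshuffle trades spatial resolution for channels, so a (1, 3, 8, 8)
# tensor unshuffled with scale=2 becomes (1, 12, 4, 4):
#
#     t = torch.randn(1, 3, 8, 8)
#     assert pixel_unshuffle(t, 2).shape == (1, 12, 4, 4)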


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
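
# Usage sketch (illustrative): keyword arguments are forwarded to every
# block's constructor, so the RRDB trunk below is built as
#
#     trunk = make_layer(RRDB, 23, num_feat=64, num_grow_ch=32)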


def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
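
# Usage sketch (illustrative, `my_conv` is a hypothetical module): extra
# keyword arguments pass straight through to init.kaiming_normal_, and the
# 0.1 scale shrinks the weights for residual branches, as done below:
#
#     default_init_weights(my_conv, scale=0.1, mode='fan_in')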


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance.
        return x5 * 0.2 + x
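
# Shape sketch (illustrative): the block is shape-preserving, since conv5
# projects the densely concatenated features back to num_feat channels:
#
#     rdb = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
#     y = rdb(torch.randn(1, 64, 16, 16))  # -> (1, 64, 16, 16)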


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance.
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """Network consisting of Residual in Residual Dense Blocks, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixel shuffle)
    to reduce the spatial size and enlarge the channel size before feeding
    inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Scale mode; inputs are pixel-unshuffled for scale 1
            and 2. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # For scale 2 and 1, inputs are pixel-unshuffled in forward(), which
        # multiplies the channel count by 4 and 16 respectively.
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample: three nearest-neighbor x2 stages (8x total on the trunk features)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    def forward(self, x):
        # Accept an unbatched CHW input by adding a batch dimension.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
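
# Net scale behavior (a sketch, read off the code above): the trunk features
# are always upsampled 8x by the three interpolate stages, while the
# pixel-unshuffle first shrinks inputs 2x (scale=2) or 4x (scale=1), giving
# a net output enlargement of 8x, 4x, or 2x for scale = 4, 2, or 1:
#
#     net = RRDBNet(3, 3, scale=2)
#     out = net(torch.randn(1, 3, 32, 32))  # -> (1, 3, 128, 128)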


if __name__ == '__main__':
    from thop import profile
    from thop import clever_format

    x = torch.randn(1, 3, 28, 28)

    rrdb = RRDBNet(3, 3)
    # thop.profile expects the example inputs as a tuple.
    flops, params = profile(rrdb, inputs=(x,))
    flops, params = clever_format([flops, params], "%.3f")
    print(flops, params)

    print('done')