Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dontLoveBugs committed May 23, 2019
0 parents commit 2d36249
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 0 deletions.
122 changes: 122 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
.idea/
/.idea
*.iml

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# JPG PNG
*.jpg
*.png
131 changes: 131 additions & 0 deletions Module/SideWindowFilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2019-05-23 16:32
@Author : Wang Xin
@Email : [email protected]
@File : SideWindowFilter.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SideWindowFilter(nn.Module):

def __init__(self, radius, iteration):
super(SideWindowFilter, self).__init__()
self.radius = radius
self.iteration = iteration
self.kernel_size = 2 * self.radius + 1
self.filter = nn.Parameter(torch.Tensor(1, 1, self.kernel_size, self.kernel_size))

def forward(self, im):
b, c, h, w = im.size()

d = torch.zeros(b, 8, h, w, dtype=torch.float)
res = im.clone()

L, R, U, D = [self.filter.clone() for _ in range(4)]

L[:, :, :, self.radius + 1:] = 0
R[:, :, :, 0: self.radius] = 0
U[:, :, self.radius + 1:, :] = 0
D[:, :, 0: self.radius, :] = 0

NW, NE, SW, SE = U.clone(), U.clone(), D.clone(), D.clone()

L, R, U, D = L / ((self.radius + 1) * self.kernel_size), R / ((self.radius + 1) * self.kernel_size), \
U / ((self.radius + 1) * self.kernel_size), D / ((self.radius + 1) * self.kernel_size)

NW[:, :, :, self.radius + 1:] = 0
NE[:, :, :, 0: self.radius] = 0
SW[:, :, :, self.radius + 1:] = 0
SE[:, :, :, 0: self.radius] = 0

NW, NE, SW, SE = NW / ((self.radius + 1) ** 2), NE / ((self.radius + 1) ** 2), \
SW / ((self.radius + 1) ** 2), SE / ((self.radius + 1) ** 2)

print('L:', L)
print('R:', R)
print('U:', U)
print('D:', D)
print('NW:', NW)
print('NE:', NE)
print('SW:', SW)
print('SE:', SE)

for ch in range(c):
im_ch = im[:, ch, ::].clone().view(b, 1, h, w)
# print('im size in each channel:', im_ch.size())

for i in range(self.iteration):
# print('###', (F.conv2d(input=im_ch, weight=L, padding=(self.radius, self.radius)) / sum_L -
# im_ch).size(), d[:, 0,::].size())
d[:, 0, ::] = F.conv2d(input=im_ch, weight=L, padding=(self.radius, self.radius)) - im_ch
d[:, 1, ::] = F.conv2d(input=im_ch, weight=R, padding=(self.radius, self.radius)) - im_ch
d[:, 2, ::] = F.conv2d(input=im_ch, weight=U, padding=(self.radius, self.radius)) - im_ch
d[:, 3, ::] = F.conv2d(input=im_ch, weight=D, padding=(self.radius, self.radius)) - im_ch
d[:, 4, ::] = F.conv2d(input=im_ch, weight=NW, padding=(self.radius, self.radius)) - im_ch
d[:, 5, ::] = F.conv2d(input=im_ch, weight=NE, padding=(self.radius, self.radius)) - im_ch
d[:, 6, ::] = F.conv2d(input=im_ch, weight=SW, padding=(self.radius, self.radius)) - im_ch
d[:, 7, ::] = F.conv2d(input=im_ch, weight=SE, padding=(self.radius, self.radius)) - im_ch

d_abs = torch.abs(d)
print('im_ch', im_ch)
print('dm = ', d_abs.shape, d_abs)
mask_min = torch.argmin(d_abs, dim=1, keepdim=True)
print('mask min = ', mask_min.shape, mask_min)
dm = torch.gather(input=d, dim=1, index=mask_min)
im_ch = dm + im_ch

res[:, ch, ::] = im_ch
return res


class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.filter = SideWindowFilter(radius=1, iteration=1)

def forward(self, x):
return self.filter(x)


if __name__ == '__main__':
s = SideWindowFilter(radius=1, iteration=1)
from PIL import Image

import cv2
img = cv2.imread('../lena.png', flags=0)
img = torch.tensor(img, dtype=torch.float)
print('img ori = ', img)

print(len(img.size()))
if len(img.size()) == 2:
h, w = img.size()
img = img.view(-1, 1, h, w)
else:
c, h, w = img.size()
img = img.view(-1, c, h, w)
print('img = ', img.shape)

model = Net()
res = model(img)
res.mean().backward()

print('res = ', res.shape, res)
import numpy as np
if res.size(1) == 3:
img_res = np.transpose(np.squeeze(res.data.numpy()), (1, 2, 0))
else:
img_res = np.squeeze(res.data.numpy())

# print(img_res.shape, img_res)
img_res = img_res
img_res = img_res.astype(np.uint8)
print('img res:', img_res)
img_res = Image.fromarray(img_res) # numpy to image
img_res.show()
8 changes: 8 additions & 0 deletions Module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2019-05-23 16:32
@Author : Wang Xin
@Email : [email protected]
@File : __init__.py.py
"""
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SideWindowFilter
39 changes: 39 additions & 0 deletions SideWindowFilter.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function result=SideWindowBoxFilter(im, radius, iteration)
%papers: 1) Sub-window Box Filter, Y.Gong, B.Liu, X.Hou, G.Qiu, VCIP2018, Dec.09, Taiwan
% 2) Side Window Filtering, H.Yin, Y.Gong, G.Qiu. CVPR2019
%implemented by Yuanhao Gong

r = radius;
k = ones(2*r+1,1)/(2*r+1); %separable kernel
k_L=k; k_L(r+2:end)=0; k_L = k_L/sum(k_L); %half kernel
k_R=flipud(k_L);
m = size(im,1)+2*r; n = size(im,2)+2*r; total = m*n;
[row, col]=ndgrid(1:m,1:n);
offset = row + m*(col-1) - total;
im = single(im);
result = im;
d = zeros(m,n,8,'single');

for ch=1:size(im,3)
U = padarray(im(:,:,ch),[r,r],'replicate');
for i = 1:iteration
%all projection distances
d(:,:,1) = conv2(k_L, k_L, U,'same') - U;
d(:,:,2) = conv2(k_L, k_R, U,'same') - U;
d(:,:,3) = conv2(k_R, k_L, U,'same') - U;
d(:,:,4) = conv2(k_R, k_R, U,'same') - U;
d(:,:,5) = conv2(k_L, k, U,'same') - U;
d(:,:,6) = conv2(k_R, k, U,'same') - U;
d(:,:,7) = conv2(k, k_L, U,'same') - U;
d(:,:,8) = conv2(k, k_R, U,'same') - U;

%find the minimal signed distance
tmp = abs(d);
[~,ind] = min(tmp,[],3);
index = offset+total*ind;
dm = d(index); %signed minimal distance
%update
U = U + dm;
end
result(:,:,ch) = U(r+1:end-r,r+1:end-r);
end
Loading

0 comments on commit 2d36249

Please sign in to comment.