# model.py

import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class Model(nn.Module):
    def __init__(self, height, width, backbone):
        super(Model, self).__init__()
        self.height = height
        self.width = width
        # Correlation window size used at each of the five feature scales.
        self.n_size = (5, 5, 5, 5, 5)
        # ImageNet normalization statistics.
        self.image_normalization_mean = [0.485, 0.456, 0.406]
        self.image_normalization_std = [0.229, 0.224, 0.225]
        # Split the backbone's feature extractor into five stages (the slice
        # indices correspond to the conv blocks of a VGG16-style `features` module).
        self.stage1_1 = nn.Sequential(*list(backbone.features.children())[:4])
        self.stage1_2 = nn.Sequential(*list(backbone.features.children())[4:9])
        self.stage1_3 = nn.Sequential(*list(backbone.features.children())[9:16])
        self.stage1_4 = nn.Sequential(*list(backbone.features.children())[16:23])
        self.stage1_5 = nn.Sequential(*list(backbone.features.children())[23:30])
        # One decoder module per scale: deeper scales get more transposed
        # convolutions, so every module outputs 64 channels at full resolution.
        self.module1 = self.get_module(n_size=self.n_size[0], n_conv=3, n_dconv=0, ch=[64, 64, 64])
        self.module2 = self.get_module(n_size=self.n_size[1], n_conv=3, n_dconv=1, ch=[64, 64, 64, 64])
        self.module3 = self.get_module(n_size=self.n_size[2], n_conv=3, n_dconv=2, ch=[64, 64, 64, 64, 64])
        self.module4 = self.get_module(n_size=self.n_size[3], n_conv=3, n_dconv=3, ch=[64, 64, 64, 64, 64, 64])
        self.module5 = self.get_module(n_size=self.n_size[4], n_conv=3, n_dconv=4, ch=[64, 64, 64, 64, 64, 64, 64])
        # Fusion head: 5 x 64 = 320 concatenated channels -> 10 output channels.
        self.ct_conv1 = nn.Sequential(
            nn.Conv2d(320, 128, kernel_size=1, stride=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU())
        self.ct_conv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU())
        self.ct_conv3 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU())
        self.ct_conv4 = nn.Sequential(
            nn.Conv2d(32, 10, kernel_size=5, stride=1, padding=2))

    def forward(self, img1, img2):
        # Scale 1: full resolution.
        layer1_1 = self.stage1_1(img1)
        layer1_2 = self.stage1_1(img2)
        corr1, depth1 = self.correlation_layer(layer1_1, layer1_2, self.n_size[0], int(self.height), int(self.width))
        # Scale 2: 1/2 resolution.
        layer2_1 = self.stage1_2(layer1_1)
        layer2_2 = self.stage1_2(layer1_2)
        corr2, depth2 = self.correlation_layer(layer2_1, layer2_2, self.n_size[1], int(self.height / 2), int(self.width / 2))
        # Scale 3: 1/4 resolution.
        layer3_1 = self.stage1_3(layer2_1)
        layer3_2 = self.stage1_3(layer2_2)
        corr3, depth3 = self.correlation_layer(layer3_1, layer3_2, self.n_size[2], int(self.height / 4), int(self.width / 4))
        # Scale 4: 1/8 resolution.
        layer4_1 = self.stage1_4(layer3_1)
        layer4_2 = self.stage1_4(layer3_2)
        corr4, depth4 = self.correlation_layer(layer4_1, layer4_2, self.n_size[3], int(self.height / 8), int(self.width / 8))
        # Scale 5: 1/16 resolution.
        layer5_1 = self.stage1_5(layer4_1)
        layer5_2 = self.stage1_5(layer4_2)
        corr5, depth5 = self.correlation_layer(layer5_1, layer5_2, self.n_size[4], int(self.height / 16), int(self.width / 16))
        # Decode each correlation map back to full resolution with 64 channels.
        l_mod1 = self.module1(corr1)
        l_mod2 = self.module2(corr2)
        l_mod3 = self.module3(corr3)
        l_mod4 = self.module4(corr4)
        l_mod5 = self.module5(corr5)
        # Fuse the five scales and predict the 10-channel output.
        ct_layer = torch.cat((l_mod1, l_mod2, l_mod3, l_mod4, l_mod5), dim=1)
        out = self.ct_conv1(ct_layer)
        out = self.ct_conv2(out)
        out = self.ct_conv3(out)
        out = self.ct_conv4(out)
        return out

    def get_depths(self, n_size):
        """Calculates the depth of the correlation map from the window size.

        Args:
            n_size (int): Window size for correlation to determine depth of correlation map
        Returns:
            depth (int): Depth of the correlation map produced
        """
        max_displacement = int(math.ceil(n_size / 2.0))
        stride_2 = 2
        assert stride_2 <= n_size
        depth = int(math.floor(((2.0 * max_displacement) + 1) / stride_2) ** 2)
        return depth

    def get_module(self, n_size, n_conv, n_dconv, ch):
        """Builds one decoder module for a correlation map.

        Args:
            n_size (int): Window size for correlation to determine depth of correlation map
            n_conv (int): Number of convolution layers in the module
            n_dconv (int): Number of deconvolution layers in the module
            ch (list): List of output channels for each layer
        Returns:
            Sequential module with n_dconv deconvolutions followed by n_conv convolutions
        """
        depth = self.get_depths(n_size)
        assert n_conv + n_dconv == len(ch)
        # The correlation map (depth channels) is the input to the first layer.
        ch = [depth, *ch]
        module = []
        for i in range(len(ch) - 1):
            if i < n_dconv:
                # Transposed convolutions upsample by a factor of 2 per layer.
                module.append(nn.ConvTranspose2d(ch[i], ch[i + 1], kernel_size=2, stride=2))
                module.append(nn.BatchNorm2d(ch[i + 1]))
            else:
                module.append(nn.Conv2d(ch[i], ch[i + 1], kernel_size=5, stride=1, padding=2))
                module.append(nn.BatchNorm2d(ch[i + 1]))
            module.append(nn.LeakyReLU())
        return nn.Sequential(*module)
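
    # Illustrative sketch (not part of the original code) of what this builds
    # for module3, which uses n_dconv = 2, n_conv = 3 and a depth-9 input:
    #   ConvTranspose2d(9, 64, 2, stride=2)  -> BN -> LeakyReLU   (H/4 -> H/2)
    #   ConvTranspose2d(64, 64, 2, stride=2) -> BN -> LeakyReLU   (H/2 -> H)
    #   Conv2d(64, 64, 5, padding=2)         -> BN -> LeakyReLU   (x3)
    # so every scale ends up at full resolution with 64 channels before fusion.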

    def correlation_layer(self, map1, map2, n_size, h, w):
        """Returns the correlation map between map1 and map2 as well as its depth value.

        Args:
            map1 (Tensor): Feature map 1
            map2 (Tensor): Feature map 2
            n_size (int): Window size for correlation to determine depth of correlation map
            h (int): Height of the feature maps
            w (int): Width of the feature maps
        Returns:
            corr (Tensor): Correlation map of shape (batch, depth, h, w)
            depth (int): Number of displacement channels in the correlation map
        """
        HEIGHT = int(h)
        WIDTH = int(w)
        max_displacement = int(math.ceil(n_size / 2.0))
        stride_2 = 2
        assert stride_2 <= n_size
        depth = int(math.floor(((2.0 * max_displacement) + 1) / stride_2) ** 2)
        out = []
        # For every (i, j) displacement, shift map2 relative to map1 by zero-padding
        # the two maps on opposite sides, multiply elementwise, crop back to the
        # original size, and sum over the channel dimension.
        for i in range(-max_displacement + 1, max_displacement, stride_2):
            for j in range(-max_displacement + 1, max_displacement, stride_2):
                padded_a = F.pad(map1, (0, abs(j), 0, abs(i)), mode='constant', value=0)
                padded_b = F.pad(map2, (abs(j), 0, abs(i), 0), mode='constant', value=0)
                m = padded_a * padded_b
                height_start_idx = 0 if i <= 0 else i
                height_end_idx = height_start_idx + HEIGHT
                width_start_idx = 0 if j <= 0 else j
                width_end_idx = width_start_idx + WIDTH
                cut = m[:, :, height_start_idx:height_end_idx, width_start_idx:width_end_idx]
                final = torch.sum(cut, 1)
                out.append(final)
        corr = torch.stack(out, 1)
        return corr, depth
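

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a torchvision VGG16 backbone, whose `features` indices line up with
# the stage slices in __init__, and an input size divisible by 16; the 64 x 64
# image size here is an arbitrary choice for the example.
if __name__ == "__main__":
    import torchvision

    backbone = torchvision.models.vgg16()
    model = Model(height=64, width=64, backbone=backbone)
    model.eval()
    img1 = torch.randn(1, 3, 64, 64)
    img2 = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        out = model(img1, img2)
    print(out.shape)  # expected: torch.Size([1, 10, 64, 64])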