-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel_resnet.py
75 lines (51 loc) · 2.45 KB
/
model_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#pylint: disable=E1101
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
class Net(nn.Module):
    """Encoder-decoder segmentation network on a torchvision ResNet-50 backbone.

    The ResNet-50 stem and stages ``layer1``..``layer4`` form the encoder. The
    decoder upsamples the deepest feature map with three stride-2 transposed
    convolutions. After each one it concatenates the matching encoder stage
    along the channel dimension (skip connection). A final stride-4 transposed
    convolution maps to ``num_classes`` channels, and a sigmoid is applied.

    The whole backbone is frozen at construction. Use ``setTrainableLayers`` to
    unfreeze selected ResNet children. The decoder layers are always trainable.

    Args:
        num_classes: Number of output channels (per-pixel scores).
        pretrained: Forwarded to ``torchvision.models.resnet50``. Defaults to
            ``True`` (original behaviour, ImageNet weights). Pass ``False`` when
            a full checkpoint will be loaded with ``load`` right afterwards, to
            avoid downloading weights that are immediately overwritten.
    """

    def __init__(self, num_classes: int = 1, pretrained: bool = True):
        super().__init__()
        self.num_classes = num_classes
        self.resnet = models.resnet50(pretrained=pretrained)
        # Start with the entire backbone frozen; only the decoder trains until
        # setTrainableLayers() unlocks specific ResNet children.
        for param in self.resnet.parameters():
            param.requires_grad = False

        # Decoder. Each in_channels value is the previous decoder output plus
        # the concatenated encoder skip: 256 + 1024, 128 + 512, 64 + 256.
        self.a_convT2d = nn.ConvTranspose2d(in_channels=2048, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.b_convT2d = nn.ConvTranspose2d(in_channels=1280, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.c_convT2d = nn.ConvTranspose2d(in_channels=640, out_channels=64, kernel_size=4, stride=2, padding=1)
        # Final x4 upsampling back to input resolution (undoes the stem's
        # conv1 stride-2 + maxpool stride-2).
        self.convT2d3 = nn.ConvTranspose2d(in_channels=320, out_channels=self.num_classes, kernel_size=4, stride=4, padding=0)

    def setTrainableLayers(self, trainable_layers):
        """Make exactly the named ResNet children trainable and freeze the rest.

        Args:
            trainable_layers: Iterable of child names as reported by
                ``self.resnet.named_children()`` (e.g. ``"layer4"``). A single
                string is treated as one name, not as a sequence of characters.

        Only ``self.resnet`` is affected; decoder layers keep their current
        ``requires_grad`` flags.
        """
        if isinstance(trainable_layers, str):
            trainable_layers = (trainable_layers,)
        wanted = set(trainable_layers)
        for name, child in self.resnet.named_children():
            unlock = name in wanted
            for param in child.parameters():
                param.requires_grad = unlock

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run encoder and decoder; return sigmoid scores reshaped to 3-D.

        Shape comments assume an (N, 3, 224, 224) input. Other sizes
        presumably need H and W divisible by 32 so that the upsampled maps
        line up with the skip connections. TODO confirm for your data.

        Returns:
            Tensor of shape (N, H*W, num_classes) built with ``view`` from the
            (N, num_classes, H, W) decoder output.
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        skip1 = self.resnet.layer1(x)      # [N, 256, 56, 56]
        skip2 = self.resnet.layer2(skip1)  # [N, 512, 28, 28]
        skip3 = self.resnet.layer3(skip2)  # [N, 1024, 14, 14]
        x = self.resnet.layer4(skip3)      # [N, 2048, 7, 7]

        x = torch.cat((self.a_convT2d(x), skip3), 1)  # [N, 256 + 1024, 14, 14]
        x = torch.cat((self.b_convT2d(x), skip2), 1)  # [N, 128 + 512, 28, 28]
        x = torch.cat((self.c_convT2d(x), skip1), 1)  # [N, 64 + 256, 56, 56]
        x = torch.sigmoid(self.convT2d3(x))           # [N, num_classes, 224, 224]

        # NOTE(review): view() on an NCHW tensor does not move channels last.
        # For num_classes > 1, element [n, k, c] is NOT pixel k's class-c score
        # (it would need x.permute(0, 2, 3, 1) first). The result is identical
        # for num_classes == 1. Kept as-is because the target layout used in
        # training is not visible here; confirm before changing.
        return x.view(x.size(0), -1, self.num_classes)

    def save(self, filename):
        """Write this module's ``state_dict`` to ``filename`` with ``torch.save``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename, map_location=None):
        """Load a ``state_dict`` previously written by ``save``.

        Args:
            filename: Path to the checkpoint file.
            map_location: Forwarded to ``torch.load``. The default ``None``
                keeps the original behaviour. Pass e.g. ``"cpu"`` to load a
                GPU-saved checkpoint on a machine without CUDA.
        """
        self.load_state_dict(torch.load(filename, map_location=map_location))
if __name__ == "__main__":
    # Build the network (ResNet-50 backbone frozen by Net.__init__).
    net = Net()
    # Unlock only the ResNet child "layer4"; every other backbone child stays
    # frozen. The original called net.freeze(...), which Net does not define
    # (AttributeError); setTrainableLayers is the method Net provides.
    # NOTE(review): confirm "layer4" was meant to be trainable, not frozen.
    net.setTrainableLayers(['layer4'])
    # Load previously saved weights from disk.
    net.load('./output/output-resnet_01111617/ultrasound.pth')