Skip to content

Commit 82d5a03

Browse files
committed
refine model.py
1 parent 74f512d commit 82d5a03

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed

wpodnet/model.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
import torch
2-
import torch.nn as nn
32

43

5-
class BasicConvBlock(nn.Module):
4+
class BasicConvBlock(torch.nn.Module):
65
def __init__(self, in_channels: int, out_channels: int):
76
super().__init__()
8-
self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
9-
self.bn_layer = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.001)
10-
self.act_layer = nn.ReLU(inplace=True)
7+
self.conv_layer = torch.nn.Conv2d(
8+
in_channels, out_channels, kernel_size=3, padding=1
9+
)
10+
self.bn_layer = torch.nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.001)
11+
self.act_layer = torch.nn.ReLU(inplace=True)
1112

1213
def forward(self, x):
1314
x = self.conv_layer(x)
1415
x = self.bn_layer(x)
1516
return self.act_layer(x)
1617

1718

18-
class ResBlock(nn.Module):
19+
class ResBlock(torch.nn.Module):
1920
def __init__(self, channels: int):
2021
super().__init__()
2122
self.conv_block = BasicConvBlock(channels, channels)
22-
self.sec_layer = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
23-
self.bn_layer = nn.BatchNorm2d(channels, momentum=0.99, eps=0.001)
24-
self.act_layer = nn.ReLU(inplace=True)
23+
self.sec_layer = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
24+
self.bn_layer = torch.nn.BatchNorm2d(channels, momentum=0.99, eps=0.001)
25+
self.act_layer = torch.nn.ReLU(inplace=True)
2526

2627
def forward(self, x):
2728
h = self.conv_block(x)
@@ -30,35 +31,45 @@ def forward(self, x):
3031
return self.act_layer(x + h)
3132

3233

33-
class WPODNet(nn.Module):
34+
class WPODNet(torch.nn.Module):
35+
"""
36+
WPODNet in PyTorch.
37+
38+
The original architecture is built in Keras: https://github.com/sergiomsilva/alpr-unconstrained/blob/master/create-model.py
39+
"""
40+
41+
# https://github.com/sergiomsilva/alpr-unconstrained/blob/master/src/keras_utils.py#L43-L44
42+
stride = 16 # net_stride
43+
scale_factor = 7.75 # side
44+
3445
def __init__(self):
3546
super().__init__()
36-
self.backbone = nn.Sequential(
47+
self.backbone = torch.nn.Sequential(
3748
BasicConvBlock(3, 16),
3849
BasicConvBlock(16, 16),
39-
nn.MaxPool2d(2),
50+
torch.nn.MaxPool2d(2),
4051
BasicConvBlock(16, 32),
4152
ResBlock(32),
42-
nn.MaxPool2d(2),
53+
torch.nn.MaxPool2d(2),
4354
BasicConvBlock(32, 64),
4455
ResBlock(64),
4556
ResBlock(64),
46-
nn.MaxPool2d(2),
57+
torch.nn.MaxPool2d(2),
4758
BasicConvBlock(64, 64),
4859
ResBlock(64),
4960
ResBlock(64),
50-
nn.MaxPool2d(2),
61+
torch.nn.MaxPool2d(2),
5162
BasicConvBlock(64, 128),
5263
ResBlock(128),
5364
ResBlock(128),
5465
ResBlock(128),
55-
ResBlock(128)
66+
ResBlock(128),
5667
)
57-
self.prob_layer = nn.Conv2d(128, 2, kernel_size=3, padding=1)
58-
self.bbox_layer = nn.Conv2d(128, 6, kernel_size=3, padding=1)
68+
self.prob_layer = torch.nn.Conv2d(128, 2, kernel_size=3, padding=1)
69+
self.bbox_layer = torch.nn.Conv2d(128, 6, kernel_size=3, padding=1)
5970

6071
# Registry a dummy tensor for retrieve the attached device
61-
self.register_buffer('dummy', torch.Tensor(), persistent=False)
72+
self.register_buffer("dummy", torch.Tensor(), persistent=False)
6273

6374
@property
6475
def device(self) -> torch.device:

0 commit comments

Comments
 (0)