Skip to content

Commit 533f63a

Browse files
Support adding for maxpooling and Tiny YOLO
1 parent e904d3c commit 533f63a

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

darknet.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ def parse_cfg(cfgfile):
6262
key,value = line.split("=")
6363
block[key.rstrip()] = value.lstrip()
6464
blocks.append(block)
65+
6566
return blocks
6667
# print('\n\n'.join([repr(x) for x in blocks]))
67-
68+
69+
import pickle as pkl
6870

6971
class MaxPoolStride1(nn.Module):
7072
def __init__(self, kernel_size):
@@ -73,9 +75,10 @@ def __init__(self, kernel_size):
7375
self.pad = kernel_size - 1
7476

7577
def forward(self, x):
76-
padded_x = F.pad(x, (0, self.pad, 0, self.pad), mode = "replicate")
77-
pooled_x = F.max_pool2d(padded_x, self.kernel_size, padding = self.pad)
78+
padded_x = F.pad(x, (0,self.pad,0,self.pad), mode="replicate")
79+
pooled_x = nn.MaxPool2d(self.kernel_size, self.pad)(padded_x)
7880
return pooled_x
81+
7982

8083
class EmptyLayer(nn.Module):
8184
def __init__(self):
@@ -237,13 +240,24 @@ def create_modules(blocks):
237240

238241

239242
#shortcut corresponds to skip connection
240-
if x["type"] == "shortcut":
243+
elif x["type"] == "shortcut":
241244
from_ = int(x["from"])
242245
shortcut = EmptyLayer()
243246
module.add_module("shortcut_{}".format(index), shortcut)
247+
248+
249+
elif x["type"] == "maxpool":
250+
stride = int(x["stride"])
251+
size = int(x["size"])
252+
if stride != 1:
253+
maxpool = nn.MaxPool2d(size, stride)
254+
else:
255+
maxpool = MaxPoolStride1(size)
256+
257+
module.add_module("maxpool_{}".format(index), maxpool)
244258

245259
#Yolo is the detection layer
246-
if x["type"] == "yolo":
260+
elif x["type"] == "yolo":
247261
mask = x["mask"].split(",")
248262
mask = [int(x) for x in mask]
249263

@@ -255,15 +269,19 @@ def create_modules(blocks):
255269

256270
detection = DetectionLayer(anchors)
257271
module.add_module("Detection_{}".format(index), detection)
272+
258273

259274

260-
275+
else:
276+
print("Something I dunno")
277+
assert False
261278

262279

263280
module_list.append(module)
264281
prev_filters = filters
265282
output_filters.append(filters)
266283
index += 1
284+
267285

268286
return (net_info, module_list)
269287

@@ -295,9 +313,8 @@ def forward(self, x, CUDA):
295313
write = 0
296314
for i in range(len(modules)):
297315

298-
299316
module_type = (modules[i]["type"])
300-
if module_type == "convolutional" or module_type == "upsample":
317+
if module_type == "convolutional" or module_type == "upsample" or module_type == "maxpool":
301318

302319
x = self.module_list[i](x)
303320
outputs[i] = x
@@ -320,13 +337,16 @@ def forward(self, x, CUDA):
320337
map1 = outputs[i + layers[0]]
321338
map2 = outputs[i + layers[1]]
322339

340+
323341
x = torch.cat((map1, map2), 1)
324342
outputs[i] = x
325343

326344
elif module_type == "shortcut":
327345
from_ = int(modules[i]["from"])
328346
x = outputs[i-1] + outputs[i+from_]
329347
outputs[i] = x
348+
349+
330350

331351
elif module_type == 'yolo':
332352

@@ -353,6 +373,7 @@ def forward(self, x, CUDA):
353373
detections = torch.cat((detections, x), 1)
354374

355375
outputs[i] = outputs[i-1]
376+
356377

357378

358379
try:

0 commit comments

Comments
 (0)