Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
lufficc committed Dec 6, 2018
0 parents commit 0d319e8
Show file tree
Hide file tree
Showing 46 changed files with 2,596 additions and 0 deletions.
23 changes: 23 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# compilation and distribution
__pycache__
*.pyc
*.so

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Pycharm editor settings
.idea
.DS_Store
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 lufficc

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# High quality, fast, modular reference implementation of SSD in PyTorch 1.0


This repository implements [SSD (Single Shot MultiBox Detector)](https://arxiv.org/abs/1512.02325). The implementation is heavily influenced by the projects [ssd.pytorch](https://github.com/amdegroot/ssd.pytorch), [pytorch-ssd](https://github.com/qfgaohao/pytorch-ssd) and [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark). This repository aims to be a code base for research based on SSD.

## Installation
### Requirements
1. Python3
1. PyTorch 1.0
1. yacs
1. GCC >= 4.9
1. OpenCV
### Build
```
# build nms
cd ext
python build.py build_ext develop
```

## Performance
### Original Paper:

| | VOC2007 test |
| :-----: | :----------: |
| SSD300* | 77.2 |
| SSD512* | 79.8 |

### Our Implementation:

| | VOC2007 test |
| :-----: | :----------: |
| SSD300* | 77.8 |
| SSD512* | - |

### Details:

<table>
<thead>
<tr>
<th></th>
<th>VOC2007 test</th>
</tr>
</thead>
<tbody>
<tr>
<td>SSD300*</td>
<td><pre><code>mAP: 0.7783
aeroplane : 0.8252
bicycle : 0.8445
bird : 0.7597
boat : 0.7102
bottle : 0.5275
bus : 0.8643
car : 0.8660
cat : 0.8741
chair : 0.6179
cow : 0.8279
diningtable : 0.7862
dog : 0.8519
horse : 0.8630
motorbike : 0.8515
person : 0.8024
pottedplant : 0.5079
sheep : 0.7685
sofa : 0.7926
train : 0.8704
tvmonitor : 0.7554</code></pre></td>
</tr>
<tr>
<td>SSD512*</td>
<td><pre><code>-</code></pre></td>
</tr>
</tbody></table>
13 changes: 13 additions & 0 deletions configs/ssd300_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
MODEL:
  NUM_CLASSES: 21
INPUT:
  IMAGE_SIZE: 300
DATASETS:
  TRAIN: ("voc_2007_train", "voc_2007_val", "voc_2012_train", "voc_2012_val")
  TEST: ("voc_2007_test", )
SOLVER:
  MAX_ITER: 120000
  LR_STEPS: [80000, 100000]
  GAMMA: 0.1
  BATCH_SIZE: 32
  LR: 1e-3
20 changes: 20 additions & 0 deletions configs/ssd512_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
MODEL:
NUM_CLASSES: 21
PRIORS:
FEATURE_MAPS: [64, 32, 16, 8, 4, 2, 1]
STRIDES: [8, 16, 32, 64, 128, 256, 512]
MIN_SIZES: [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8]
MAX_SIZES: [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.65]
ASPECT_RATIOS: [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]]
BOXES_PER_LOCATION: [4, 6, 6, 6, 6, 4, 4]
INPUT:
IMAGE_SIZE: 512
DATASETS:
TRAIN: ("voc_2007_train", "voc_2007_val", "voc_2012_train", "voc_2012_val")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 120000
LR_STEPS: [80000, 100000]
GAMMA: 0.1
BATCH_SIZE: 24
LR: 1e-3
95 changes: 95 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import glob
import os

import torch
from PIL import Image
from tqdm import tqdm
from ssd.config import cfg
from ssd.modeling.predictor import Predictor
from ssd.modeling.vgg_ssd import build_ssd_model
from ssd.datasets.voc_dataset import VOCDataset
import argparse
import numpy as np

from ssd.utils.viz import draw_bounding_boxes


def run_demo(cfg, weights_file, iou_threshold, score_threshold, images_dir, output_dir, dataset_type):
    """Run SSD detection on every ``*.jpg`` in a directory and save annotated copies.

    Args:
        cfg: Project configuration object. ``cfg.MODEL.DEVICE`` selects the
            torch device; the whole object is also handed to the model
            builder and to ``Predictor``.
        weights_file: Path handed to ``model.load``.
        iou_threshold: Forwarded to ``Predictor`` as ``iou_threshold``.
        score_threshold: Forwarded to ``Predictor`` as ``score_threshold``.
        images_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        output_dir: Directory that receives the annotated images, each saved
            under the base name of its source image. Created if missing.
        dataset_type: Selects the class-name table. Only ``"voc"`` is
            implemented.

    Raises:
        NotImplementedError: If ``dataset_type`` is anything other than
            ``"voc"``.
    """
    # Fail fast, before any model is built, on an unsupported dataset type.
    if dataset_type != "voc":
        raise NotImplementedError('Not implemented now.')
    class_names = VOCDataset.class_names

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg, is_test=True)
    model.load(weights_file)
    print('Loaded weights from {}.'.format(weights_file))
    model = model.to(device)
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=iou_threshold,
                          score_threshold=score_threshold,
                          device=device)
    cpu_device = torch.device("cpu")

    image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(output_dir, exist_ok=True)

    for image_path in tqdm(image_paths):
        image = np.array(Image.open(image_path).convert("RGB"))
        output = predictor.predict(image)
        # The predictor yields three values; move each to the CPU as a numpy array.
        boxes, labels, scores = [o.to(cpu_device).numpy() for o in output]
        drawn_image = draw_bounding_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
        image_name = os.path.basename(image_path)
        Image.fromarray(drawn_image).save(os.path.join(output_dir, image_name))


def main():
    """Parse command-line arguments, load the configuration and run the demo.

    The config file named by ``--config-file`` is merged into the global
    ``cfg``, followed by any trailing ``KEY VALUE`` overrides collected in
    ``opts``; the configuration is then frozen and passed to ``run_demo``.
    """
    # The description and help texts describe the demo script itself; they
    # previously carried text copied from the evaluation script.
    parser = argparse.ArgumentParser(description="SSD demo: draw detections on a directory of images.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--weights", type=str, help="Trained weights.")
    parser.add_argument("--iou_threshold", type=float, default=0.5)
    parser.add_argument("--score_threshold", type=float, default=0.5)
    parser.add_argument("--images_dir", default='demo', type=str, help='Specify an image dir to do prediction.')
    parser.add_argument("--output_dir", default='demo/result', type=str, help='Specify a dir to save the annotated images.')
    # run_demo raises NotImplementedError for anything but "voc", so do not advertise coco.
    parser.add_argument("--dataset_type", default="voc", type=str, help='Specify dataset type. Currently only voc is supported.')

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    print(args)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    print("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        print(config_str)
    print("Running with config:\n{}".format(cfg))

    run_demo(cfg=cfg,
             weights_file=args.weights,
             iou_threshold=args.iou_threshold,
             score_threshold=args.score_threshold,
             images_dir=args.images_dir,
             output_dir=args.output_dir,
             dataset_type=args.dataset_type)


if __name__ == '__main__':
    main()
Binary file added demo/000342.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/000542.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/003123.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/004101.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/008591.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
117 changes: 117 additions & 0 deletions eval_ssd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os

import torch
from tqdm import tqdm
from ssd.config import cfg
from ssd.datasets import build_dataset
from ssd.modeling.predictor import Predictor
from ssd.modeling.vgg_ssd import build_ssd_model
from ssd.utils.eval_detection_voc import eval_detection_voc

import argparse
import numpy as np


def do_evaluation(cfg, model, test_dataset, output_dir):
    """Evaluate ``model`` on ``test_dataset`` and write the report to ``result.txt``.

    Every sample is run through a ``Predictor``; predictions and ground truth
    are accumulated and passed to ``eval_detection_voc`` with
    ``iou_thresh=0.5`` and ``use_07_metric=True``. The mAP and the per-class
    AP values are printed and written to ``<output_dir>/result.txt``.

    Args:
        cfg: Project configuration object. Reads ``cfg.MODEL.DEVICE``,
            ``cfg.TEST.NMS_THRESHOLD`` and ``cfg.TEST.CONFIDENCE_THRESHOLD``.
        model: Detection model; switched to eval mode here.
        test_dataset: Dataset exposing ``class_names``, ``__len__``,
            ``get_annotation(i)`` and ``get_image(i)``.
        output_dir: Directory for ``result.txt``. It is not created here, so
            it must already exist.
    """
    class_names = test_dataset.class_names
    device = torch.device(cfg.MODEL.DEVICE)
    model.eval()
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=cfg.TEST.NMS_THRESHOLD,
                          score_threshold=cfg.TEST.CONFIDENCE_THRESHOLD,
                          device=device)

    cpu_device = torch.device("cpu")

    pred_boxes_list = []
    pred_labels_list = []
    pred_scores_list = []
    gt_boxes_list = []
    gt_labels_list = []
    gt_difficults = []
    # The dataset is read through index-based accessors, so iterate by index.
    for i in tqdm(range(len(test_dataset))):
        _image_id, annotation = test_dataset.get_annotation(i)
        gt_boxes, gt_labels, is_difficult = annotation
        gt_boxes_list.append(gt_boxes)
        gt_labels_list.append(gt_labels)
        # The np.bool alias was removed in NumPy 1.24; builtin bool is the
        # same dtype and works on every NumPy version.
        gt_difficults.append(is_difficult.astype(bool))

        image = test_dataset.get_image(i)
        output = predictor.predict(image)
        boxes, labels, scores = [o.to(cpu_device).numpy() for o in output]

        pred_boxes_list.append(boxes)
        pred_labels_list.append(labels)
        pred_scores_list.append(scores)

    result = eval_detection_voc(pred_bboxes=pred_boxes_list,
                                pred_labels=pred_labels_list,
                                pred_scores=pred_scores_list,
                                gt_bboxes=gt_boxes_list,
                                gt_labels=gt_labels_list,
                                gt_difficults=gt_difficults,
                                iou_thresh=0.5,
                                use_07_metric=True)

    # Index 0 is the background class and is left out of the per-class report.
    per_class_lines = [
        "{:<16}: {:.4f}\n".format(class_names[i], ap)
        for i, ap in enumerate(result["ap"])
        if i != 0
    ]
    result_str = "mAP: {:.4f}\n".format(result["map"]) + "".join(per_class_lines)
    print(result_str)
    prediction_path = os.path.join(output_dir, "result.txt")
    with open(prediction_path, "w") as f:
        f.write(result_str)


def evaluation(cfg, weights_file, output_dir):
    """Build the test dataset and model, load weights, and run the evaluation.

    Args:
        cfg: Project configuration object. Reads ``cfg.DATASETS.TEST`` and
            ``cfg.MODEL.DEVICE``; also passed to the model builder and to
            ``do_evaluation``.
        weights_file: Path handed to ``model.load``.
        output_dir: Directory for evaluation results. Created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(output_dir, exist_ok=True)

    test_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST)
    print("Test dataset size: {}".format(len(test_dataset)))

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg, is_test=True)
    model.load(weights_file)
    print('Loaded weights from {}.'.format(weights_file))
    model.to(device)
    do_evaluation(cfg, model, test_dataset, output_dir)


def main():
    """Command-line entry point for evaluating an SSD model.

    Parses the arguments, merges the config file and any trailing overrides
    into the global ``cfg``, freezes it, echoes the configuration, and then
    calls ``evaluation``.
    """
    arg_parser = argparse.ArgumentParser(description='SSD Evaluation on VOC Dataset.')
    arg_parser.add_argument("--config-file", type=str, default="", metavar="FILE", help="path to config file")
    arg_parser.add_argument("--weights", type=str, help="Trained weights.")
    arg_parser.add_argument(
        "--output_dir",
        type=str,
        default="eval_results",
        help="The directory to store evaluation results.",
    )
    # Everything left over on the command line is treated as config overrides.
    arg_parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    cli_args = arg_parser.parse_args()
    print(cli_args)

    config_path = cli_args.config_file
    cfg.merge_from_file(config_path)
    cfg.merge_from_list(cli_args.opts)
    cfg.freeze()

    print("Loaded configuration file {}".format(config_path))
    # Echo the raw config file, then the fully merged configuration.
    with open(config_path, "r") as handle:
        raw_config = handle.read()
    print("\n" + raw_config)
    print("Running with config:\n{}".format(cfg))
    evaluation(cfg, weights_file=cli_args.weights, output_dir=cli_args.output_dir)


if __name__ == '__main__':
    main()
Empty file added ext/__init__.py
Empty file.
Loading

0 comments on commit 0d319e8

Please sign in to comment.