Add COCO training option and cleanup training script
amdegroot committed Mar 30, 2018
1 parent dd85aff commit 8dd3865
Showing 12 changed files with 302 additions and 163 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -24,10 +24,10 @@ A [PyTorch](http://pytorch.org/) implementation of [Single Shot MultiBox Detecto
- Clone this repository.
* Note: We currently only support Python 3+.
- Then download the dataset by following the [instructions](#datasets) below.
-- We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
-* To use Visdom in the browser:
+- We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
+* To use Visdom in the browser:
```Shell
-# First install Python server and client
+# First install Python server and client
pip install visdom
# Start the server (probably in a screen or tmux)
python -m visdom.server
@@ -40,7 +40,7 @@ To make things easy, we provide bash scripts to handle the dataset downloads and


### COCO
-Microsoft COCO: Common Objects in Context
+Microsoft COCO: Common Objects in Context

##### Download COCO 2014
```Shell
@@ -83,7 +83,7 @@ python train.py
* For training, an NVIDIA GPU is strongly recommended for speed.
* For instructions on Visdom usage/installation, see the <a href='#installation'>Installation</a> section.
* You can pick up training from a checkpoint by specifying the path as one of the training parameters (again, see `train.py` for options).

## Evaluation
To evaluate a trained network:

@@ -107,31 +107,31 @@ You can specify the parameters listed in the `eval.py` file by flagging them or
| 77.2 % | 77.26 % | 58.12% | 77.43 % |

##### FPS
-**GTX 1060:** ~45.45 FPS
+**GTX 1060:** ~45.45 FPS

## Demos

### Use a pre-trained SSD network for detection

#### Download a pre-trained network
- We are trying to provide PyTorch `state_dicts` (dict of weight tensors) of the latest SSD model definitions trained on different datasets.
-- Currently, we provide the following PyTorch models:
+- Currently, we provide the following PyTorch models:
* SSD300 trained on VOC0712 (newest PyTorch weights)
- https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
* SSD300 trained on VOC0712 (original Caffe weights)
- https://s3.amazonaws.com/amdegroot-models/ssd_300_VOC0712.pth
-- Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325)
+- Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325)
<p align="left">
<img src="http://www.cs.unc.edu/~wliu/papers/ssd_results.png" alt="SSD results on multiple datasets" width="800px"></p>

### Try the demo notebook
- Make sure you have [jupyter notebook](http://jupyter.readthedocs.io/en/latest/install.html) installed.
- Two alternatives for installing jupyter notebook:
-1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run):
-`jupyter notebook`
+1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run):
+`jupyter notebook`

2. If using [pip](https://pypi.python.org/pypi/pip):

```Shell
# make sure pip is upgraded
pip3 install --upgrade pip
@@ -169,5 +169,5 @@ We have accumulated the following to-do list, which we hope to complete in the n
- Wei Liu, et al. "SSD: Single Shot MultiBox Detector." [ECCV2016](http://arxiv.org/abs/1512.02325).
- [Original Implementation (CAFFE)](https://github.com/weiliu89/caffe/tree/ssd)
- A huge thank you to [Alex Koltun](https://github.com/alexkoltun) and his team at [Webyclip](webyclip.com) for their help in finishing the data augmentation portion.
-- A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):
-* [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow)
+- A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):
+* [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow)
3 changes: 2 additions & 1 deletion data/__init__.py
@@ -1,5 +1,6 @@
from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
-from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT
+
+from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
from .config import *
import torch
import cv2
24 changes: 12 additions & 12 deletions data/coco.py
@@ -30,6 +30,15 @@
'teddy bear', 'hair drier', 'toothbrush')


+def get_label_map(label_file):
+    label_map = {}
+    labels = open(label_file, 'r')
+    for line in labels:
+        ids = line.split(',')
+        label_map[int(ids[0])] = int(ids[1])
+    return label_map
+
+
class COCOAnnotationTransform(object):
"""Transforms a COCO annotation into a Tensor of bbox coords and label index
Initialized with a dictionary lookup of classnames to indexes
@@ -74,8 +83,8 @@ class COCODetection(data.Dataset):
in the target (bbox) and transforms it.
"""

-def __init__(self, root, image_set, transform=None,
-target_transform=None):
+def __init__(self, root, image_set='trainval35k', transform=None,
+target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'):
sys.path.append(osp.join(root, COCO_API))
from pycocotools.coco import COCO
self.root = osp.join(root, IMAGES, image_set)
@@ -84,7 +93,7 @@ def __init__(self, root, image_set, transform=None,
self.ids = list(self.coco.imgToAnns.keys())
self.transform = transform
self.target_transform = target_transform
-self.name = 'MS COCO ' + image_set
+self.name = dataset_name

def __getitem__(self, index):
"""
@@ -169,12 +178,3 @@ def __repr__(self):
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
-
-
-def get_label_map(label_file):
-    label_map = {}
-    labels = open(label_file, 'r')
-    for line in labels:
-        ids = line.split(',')
-        label_map[int(ids[0])] = int(ids[1])
-    return label_map
80 changes: 80 additions & 0 deletions data/coco_labels.txt
@@ -0,0 +1,80 @@
1,1,person
2,2,bicycle
3,3,car
4,4,motorcycle
5,5,airplane
6,6,bus
7,7,train
8,8,truck
9,9,boat
10,10,traffic light
11,11,fire hydrant
13,12,stop sign
14,13,parking meter
15,14,bench
16,15,bird
17,16,cat
18,17,dog
19,18,horse
20,19,sheep
21,20,cow
22,21,elephant
23,22,bear
24,23,zebra
25,24,giraffe
27,25,backpack
28,26,umbrella
31,27,handbag
32,28,tie
33,29,suitcase
34,30,frisbee
35,31,skis
36,32,snowboard
37,33,sports ball
38,34,kite
39,35,baseball bat
40,36,baseball glove
41,37,skateboard
42,38,surfboard
43,39,tennis racket
44,40,bottle
46,41,wine glass
47,42,cup
48,43,fork
49,44,knife
50,45,spoon
51,46,bowl
52,47,banana
53,48,apple
54,49,sandwich
55,50,orange
56,51,broccoli
57,52,carrot
58,53,hot dog
59,54,pizza
60,55,donut
61,56,cake
62,57,chair
63,58,couch
64,59,potted plant
65,60,bed
67,61,dining table
70,62,toilet
72,63,tv
73,64,laptop
74,65,mouse
75,66,remote
76,67,keyboard
77,68,cell phone
78,69,microwave
79,70,oven
80,71,toaster
81,72,sink
82,73,refrigerator
84,74,book
85,75,clock
86,76,vase
87,77,scissors
88,78,teddy bear
89,79,hair drier
90,80,toothbrush
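
A note on how this file is consumed: the first column holds COCO's original category IDs, which skip some values (12 and 26, for example, do not appear), and the second column holds a contiguous index from 1 to 80. `get_label_map` builds a dict from the first column to the second. The sketch below shows that behaviour on three rows copied from the file. It is a variant of the function in this commit, not the commit's code: it uses a `with` block so the file handle is closed, and the temporary file is purely illustrative.

```python
import os
import tempfile


def get_label_map(label_file):
    """Map COCO category IDs (column 1) to contiguous indices (column 2)."""
    label_map = {}
    # Variant of the committed function: the file is closed when the block exits.
    with open(label_file, 'r') as labels:
        for line in labels:
            ids = line.split(',')
            label_map[int(ids[0])] = int(ids[1])
    return label_map


# Illustrative sample: three rows copied from data/coco_labels.txt.
sample = "11,11,fire hydrant\n13,12,stop sign\n90,80,toothbrush\n"
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write(sample)
    path = f.name

label_map = get_label_map(path)
os.remove(path)
print(label_map)
```

The third column (the class name) is ignored by the parser, so names containing spaces such as `fire hydrant` are not a problem.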
6 changes: 6 additions & 0 deletions data/config.py
@@ -12,6 +12,9 @@

# SSD300 CONFIGS
voc = {
+'num_classes': 21,
+'lr_steps': (80000, 100000, 120000),
+'max_iter': 120000,
'feature_maps': [38, 19, 10, 5, 3, 1],
'min_dim': 300,
'steps': [8, 16, 32, 64, 100, 300],
@@ -24,6 +27,9 @@
}

coco = {
+'num_classes': 201,
+'lr_steps': (280000, 360000, 400000),
+'max_iter': 400000,
'feature_maps': [38, 19, 10, 5, 3, 1],
'min_dim': 300,
'steps': [8, 16, 32, 64, 100, 300],
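
The new `lr_steps` and `max_iter` keys move the training schedule into the per-dataset config. The training script is not part of the hunks shown here, so the following is only a sketch of how such keys are commonly used for step decay. The helper name `lr_at`, the base rate `1e-3`, and the decay factor `gamma=0.1` are assumptions for illustration, not values taken from this diff.

```python
# Values copied from the config hunks above; only the keys used here are included.
voc = {'lr_steps': (80000, 100000, 120000), 'max_iter': 120000}
coco = {'lr_steps': (280000, 360000, 400000), 'max_iter': 400000}


def lr_at(iteration, cfg, base_lr=1e-3, gamma=0.1):
    """Hypothetical helper: step-decay learning rate at a given iteration."""
    # Count how many decay boundaries have been reached so far.
    steps_passed = sum(1 for s in cfg['lr_steps'] if iteration >= s)
    return base_lr * (gamma ** steps_passed)


for it in (0, 80000, 100000):
    print(it, lr_at(it, voc))
```

In both dicts the last entry of `lr_steps` equals `max_iter`. If the training loop iterates over `range(max_iter)`, that final boundary would never be reached, so it may be worth checking whether the third decay is intended.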
19 changes: 9 additions & 10 deletions data/voc0712.py
@@ -6,13 +6,10 @@
Updated by: Ellis Brown, Max deGroot
"""
from .config import HOME
-import os
-import os.path
+import os.path as osp
import sys
import torch
import torch.utils.data as data
-import torchvision.transforms as transforms
-from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
if sys.version_info[0] == 2:
@@ -28,7 +25,7 @@
'sheep', 'sofa', 'train', 'tvmonitor')

# note: if you used our download scripts, this should be right
-VOC_ROOT = os.path.join(HOME, "data/VOCdevkit/")
+VOC_ROOT = osp.join(HOME, "data/VOCdevkit/")


class VOCAnnotationTransform(object):
@@ -97,19 +94,21 @@ class VOCDetection(data.Dataset):
(default: 'VOC2007')
"""

-def __init__(self, root, image_sets, transform=None, target_transform=None,
+def __init__(self, root,
+image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
+transform=None, target_transform=VOCAnnotationTransform(),
dataset_name='VOC0712'):
self.root = root
self.image_set = image_sets
self.transform = transform
self.target_transform = target_transform
self.name = dataset_name
-self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
-self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
+self._annopath = osp.join('%s', 'Annotations', '%s.xml')
+self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
self.ids = list()
for (year, name) in image_sets:
-rootpath = os.path.join(self.root, 'VOC' + year)
-for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
+rootpath = osp.join(self.root, 'VOC' + year)
+for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
self.ids.append((rootpath, line.strip()))

def __getitem__(self, index):
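
One thing worth flagging in the new `VOCDetection.__init__` signature: the list default for `image_sets` and the `VOCAnnotationTransform()` default are both created once, when the `def` statement is evaluated, and are then shared by every call that omits them. That is harmless as long as nothing mutates them, but a `None` sentinel avoids the sharing altogether. A minimal sketch of the pattern follows; `DatasetSketch` and `DEFAULT_IMAGE_SETS` are hypothetical stand-ins, not names from this repository.

```python
DEFAULT_IMAGE_SETS = (('2007', 'trainval'), ('2012', 'trainval'))


class DatasetSketch(object):
    """Hypothetical stand-in for VOCDetection; only the defaults are shown."""

    def __init__(self, root, image_sets=None, dataset_name='VOC0712'):
        # Build a fresh list per instance instead of sharing one default list.
        if image_sets is None:
            image_sets = list(DEFAULT_IMAGE_SETS)
        self.root = root
        self.image_set = image_sets
        self.name = dataset_name


a = DatasetSketch('/tmp/VOCdevkit')
b = DatasetSketch('/tmp/VOCdevkit')
a.image_set.append(('2007', 'test'))
print(len(a.image_set), len(b.image_set))
```

With the list default used in the diff, the same `append` on one instance would also be visible through every other instance created with the default.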
24 changes: 14 additions & 10 deletions demo/demo.ipynb
@@ -26,16 +26,12 @@
"import torch.nn as nn\n",
"import torch.backends.cudnn as cudnn\n",
"from torch.autograd import Variable\n",
-"import torch.utils.data as data\n",
-"import torchvision.transforms as transforms\n",
-"from torch.utils.serialization import load_lua\n",
"import numpy as np\n",
"import cv2\n",
"if torch.cuda.is_available():\n",
" torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
"\n",
-"from ssd import build_ssd\n",
-"# from models import build_ssd as build_ssd_v1 # uncomment for older pool6 model"
+"from ssd import build_ssd"
]
},
{
@@ -52,6 +48,7 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {
+"collapsed": false,
"scrolled": false
},
"outputs": [
@@ -80,7 +77,9 @@
{
"cell_type": "code",
"execution_count": 3,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
"outputs": [
{
"data": {
@@ -97,9 +96,9 @@
"# image = cv2.imread('./data/example.jpg', cv2.IMREAD_COLOR) # uncomment if dataset not downloaded\n",
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
-"from data import VOCDetection, VOCroot, AnnotationTransform\n",
+"from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform\n",
"# here we specify year (07 or 12) and dataset ('test', 'val', 'train') \n",
-"testset = VOCDetection(VOCroot, [('2007', 'val')], None, AnnotationTransform())\n",
+"testset = VOCDetection(VOC_ROOT, [('2007', 'val')], None, VOCAnnotationTransform())\n",
"img_id = 60\n",
"image = testset.pull_image(img_id)\n",
"rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
@@ -123,7 +122,9 @@
{
"cell_type": "code",
"execution_count": 4,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
"outputs": [
{
"data": {
@@ -157,6 +158,7 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {
+"collapsed": true,
"scrolled": true
},
"outputs": [],
@@ -179,7 +181,9 @@
{
"cell_type": "code",
"execution_count": 6,
-"metadata": {},
+"metadata": {
+"collapsed": false
+},
"outputs": [
{
"data": {