Skip to content

Commit 300deff

Browse files
committed
fix bugs
1 parent 9035d52 commit 300deff

File tree

4 files changed

+126
-7
lines changed

4 files changed

+126
-7
lines changed

configs/centripetalnet_mask_hg104.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
type='Centripetal_mask',
1414
num_classes=81,
1515
in_channels=256,
16-
with_mask=False,
16+
with_mask=True,
1717
))
1818
# training and testing settings
1919
train_cfg = dict(
@@ -37,8 +37,7 @@
3737
max_per_img=100)
3838
# dataset settings
3939
dataset_type = 'CocoDataset'
40-
#data_root = 'data/mscoco2017/'
41-
data_root = '/mnt/lustre/share/DSK/datasets/mscoco2017/'
40+
data_root = 'data/mscoco2017/'
4241
img_norm_cfg = dict(
4342
mean=[103.53, 116.28, 123.675], std=[57.375, 57.12, 58.395], to_rgb=False)
4443

src/datasets/coco.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import numpy as np
2+
from pycocotools.coco import COCO
3+
4+
from .custom import CustomDataset
5+
6+
7+
class CocoDataset(CustomDataset):
    """Dataset that reads COCO-style annotations through ``pycocotools``.

    The annotation file is indexed with ``COCO``; category ids are mapped
    to contiguous labels starting at 1 in the order returned by
    ``getCatIds()``.

    NOTE(review): ``self.img_infos`` and ``self.with_mask`` are read here
    but not assigned in this class -- presumably they are set up by
    ``CustomDataset``; confirm against the base class.
    """

    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
               'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
               'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
               'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')

    def load_annotations(self, ann_file):
        """Index ``ann_file`` and return one info dict per image.

        Sets ``self.coco``, ``self.cat_ids``, ``self.cat2label`` and
        ``self.img_ids`` as a side effect.

        Args:
            ann_file (str): Path of the COCO annotation file.

        Returns:
            list[dict]: Image records from ``COCO.loadImgs``, each with an
                extra ``'filename'`` key copied from ``'file_name'``.
        """
        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.getCatIds()
        # Labels start at 1; 0 is left unused by this mapping.
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.img_ids = self.coco.getImgIds()
        img_infos = []
        for img_id in self.img_ids:
            info = self.coco.loadImgs([img_id])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        """Return the parsed annotations of the image at position ``idx``.

        Args:
            idx (int): Index into ``self.img_infos``.

        Returns:
            dict: Result of :meth:`_parse_ann_info` for that image.
        """
        img_id = self.img_infos[idx]['id']
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        ann_info = self.coco.loadAnns(ann_ids)
        return self._parse_ann_info(ann_info, self.with_mask)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths.

        Args:
            min_size (int): Minimum allowed value of the shorter image side.

        Returns:
            list[int]: Indices into ``self.img_infos`` of images that have
                at least one annotation and whose shorter side is at
                least ``min_size``.
        """
        ids_with_ann = {
            ann['image_id'] for ann in self.coco.anns.values()
        }
        valid_inds = []
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True):
        """Parse bbox and mask annotation.

        Annotations flagged ``'ignore'``, with non-positive area, or with a
        box narrower or shorter than one pixel are dropped. Crowd
        annotations contribute only to ``bboxes_ignore``.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the keys ``bboxes`` (float32, N x 4),
                ``labels`` (int64, N) and ``bboxes_ignore`` (float32,
                M x 4). When ``with_mask`` is true it also contains
                ``masks``, ``mask_polys`` and ``poly_lens``.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is
        #    a list of float.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # (x, y, w, h) -> corner form with an inclusive far corner.
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            # Treat a missing ``iscrowd`` flag as "not a crowd" instead of
            # raising KeyError on annotation files that omit it.
            is_crowd = ann.get('iscrowd', False)
            if is_crowd:
                gt_bboxes_ignore.append(bbox)
                continue

            gt_bboxes.append(bbox)
            gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.coco.annToMask(ann))
                segmentation = ann['segmentation']
                # Only a list holds polygons. An RLE dict would otherwise be
                # iterated by key, and the key 'counts' (length 6) would
                # slip through the length filter below.
                if isinstance(segmentation, list):
                    # valid polygons have >= 3 points (6 coordinates)
                    mask_polys = [p for p in segmentation if len(p) >= 6]
                else:
                    mask_polys = []
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(len(p) for p in mask_polys)

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        parsed = dict(
            bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)

        if with_mask:
            parsed['masks'] = gt_masks
            # poly format is not used in the current implementation
            parsed['mask_polys'] = gt_mask_polys
            parsed['poly_lens'] = gt_poly_lens
        return parsed

src/datasets/custom.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ def prepare_test_img(self, idx, gt=True):#keep ratio and padding to desired size
336336
ann = self.get_ann_info(idx)
337337
gt_bboxes = ann['bboxes']
338338
gt_labels = ann['labels']
339-
#gt_masks = ann['masks']
339+
if self.with_mask:
340+
gt_masks = ann['masks']
340341

341342
def prepare_single(img, scale, flip):
342343
_img, border, offset = self.img_transform(
@@ -364,8 +365,9 @@ def prepare_single(img, scale, flip):
364365
imgs.append(_img)
365366
img_metas.append(DC(_img_meta, cpu_only=True))
366367
data = dict(img=imgs, img_meta=img_metas)
367-
h,w=_img.shape[0:2]
368-
gt_masks = [np.zeros([h,w])]
368+
if not self.with_mask:
369+
h, w = _img.shape[0:2]
370+
gt_masks = [np.zeros([h, w])]
369371

370372
if len(gt_labels)==0:
371373
gt_labels = np.array([-1])

src/models/bbox_heads/centripetal_mask.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def forward(self, feats):
184184

185185
def loss(self, tl_result, br_result, mask, mid_tl_result, mid_br_result, mid_mask, gt_bboxes, gt_labels, gt_masks, img_metas, cfg, imgscale):
186186
gt_tl_heatmap, gt_br_heatmap, gt_tl_offsets, gt_br_offsets, gt_tl_off_c, gt_br_off_c,\
187-
gt_tl_off_c2, gt_br_off_c2 = cornerv2_target(gt_bboxes=gt_bboxes, gt_labels=gt_labels, feats=tl_result, imgscale=imgscale, direct=True, scale=1.0, dcn=True)
187+
gt_tl_off_c2, gt_br_off_c2 = corner_target(gt_bboxes=gt_bboxes, gt_labels=gt_labels, feats=tl_result, imgscale=imgscale, direct=True, scale=1.0, dcn=True)
188188
# pred_tl_heatmap = _sigmoid(tl_result[:, :self.num_classes, :, :])
189189
pred_tl_heatmap = tl_result[:, :self.num_classes, :, :].sigmoid()
190190
pred_tl_off_c = tl_result[:, self.num_classes:self.num_classes + 2, :, :]

0 commit comments

Comments
 (0)