HTC and SCNet
chhluo authored and ZwwWayne committed Jul 19, 2022
1 parent 36c1f47 commit 9731c4f
Showing 34 changed files with 1,628 additions and 1,198 deletions.
93 changes: 52 additions & 41 deletions configs/_base_/datasets/coco_instance_semantic.py
@@ -1,54 +1,65 @@
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# file_client_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         './data/': 's3://openmmlab/datasets/detection/',
+#         'data/': 's3://openmmlab/datasets/detection/'
+#     }))
+file_client_args = dict(backend='disk')
+
 train_pipeline = [
-    dict(type='LoadImageFromFile'),
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
     dict(
         type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),
-    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', flip_ratio=0.5),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(type='SegRescale', scale_factor=1 / 8),
-    dict(type='DefaultFormatBundle'),
-    dict(
-        type='Collect',
-        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PackDetInputs')
 ]
 test_pipeline = [
-    dict(type='LoadImageFromFile'),
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
     dict(
-        type='MultiScaleFlipAug',
-        img_scale=(1333, 800),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip', flip_ratio=0.5),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='Pad', size_divisor=32),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
 ]
-data = dict(
-    samples_per_gpu=2,
-    workers_per_gpu=2,
-    train=dict(
+
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    batch_sampler=dict(type='AspectRatioBatchSampler'),
+    dataset=dict(
         type=dataset_type,
-        ann_file=data_root + 'annotations/instances_train2017.json',
-        img_prefix=data_root + 'train2017/',
-        seg_prefix=data_root + 'stuffthingmaps/train2017/',
-        pipeline=train_pipeline),
-    val=dict(
+        data_root=data_root,
+        ann_file='annotations/instances_train2017.json',
+        data_prefix=dict(img='train2017/', seg='stuffthingmaps/train2017/'),
+        filter_cfg=dict(filter_empty_gt=True, min_size=32),
+        pipeline=train_pipeline))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline),
-    test=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
+        data_root=data_root,
+        ann_file='annotations/instances_val2017.json',
+        data_prefix=dict(img='val2017/'),
+        test_mode=True,
         pipeline=test_pipeline))
-evaluation = dict(metric=['bbox', 'segm'])
+
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    type='CocoMetric',
+    ann_file=data_root + 'annotations/instances_val2017.json',
+    metric=['bbox', 'segm'],
+    format_only=False)
+test_evaluator = val_evaluator
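
For reference, a minimal sketch of how a config written in this new style can be loaded and inspected (assuming an MMDetection 3.x environment with mmengine installed; the path is the file shown above):

# Sketch only: read the new-style dataloader and evaluator fields.
from mmengine.config import Config

cfg = Config.fromfile('configs/_base_/datasets/coco_instance_semantic.py')
print(cfg.train_dataloader.batch_size)        # 2
print(cfg.train_dataloader.dataset.ann_file)  # 'annotations/instances_train2017.json'
print(cfg.val_evaluator.metric)               # ['bbox', 'segm']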
5 changes: 1 addition & 4 deletions configs/htc/htc_r101_fpn_20e_coco.py
@@ -1,9 +1,6 @@
-_base_ = './htc_r50_fpn_1x_coco.py'
+_base_ = './htc_r50_fpn_20e_coco.py'
 model = dict(
     backbone=dict(
         depth=101,
         init_cfg=dict(type='Pretrained',
                       checkpoint='torchvision://resnet101')))
-# learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
47 changes: 13 additions & 34 deletions configs/htc/htc_r50_fpn_1x_coco.py
@@ -1,5 +1,6 @@
 _base_ = './htc_without_semantic_r50_fpn_1x_coco.py'
 model = dict(
+    data_preprocessor=dict(pad_seg=True),
     roi_head=dict(
         semantic_roi_extractor=dict(
             type='SingleRoIExtractor',
@@ -10,47 +11,25 @@
             type='FusedSemanticHead',
             num_ins=5,
             fusion_level=1,
+            seg_scale_factor=1 / 8,
             num_convs=4,
             in_channels=256,
             conv_out_channels=256,
             num_classes=183,
             loss_seg=dict(
                 type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2))))
-data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
 train_pipeline = [
-    dict(type='LoadImageFromFile'),
     dict(
-        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),
-    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', flip_ratio=0.5),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(type='SegRescale', scale_factor=1 / 8),
-    dict(type='DefaultFormatBundle'),
-    dict(
-        type='Collect',
-        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
-]
-test_pipeline = [
-    dict(type='LoadImageFromFile'),
+        type='LoadImageFromFile',
+        file_client_args={{_base_.file_client_args}}),
     dict(
-        type='MultiScaleFlipAug',
-        img_scale=(1333, 800),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip', flip_ratio=0.5),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='Pad', size_divisor=32),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
+        type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PackDetInputs')
 ]
-data = dict(
-    train=dict(
-        seg_prefix=data_root + 'stuffthingmaps/train2017/',
-        pipeline=train_pipeline),
-    val=dict(pipeline=test_pipeline),
-    test=dict(pipeline=test_pipeline))
+train_dataloader = dict(
+    dataset=dict(
+        data_prefix=dict(img='train2017/', seg='stuffthingmaps/train2017/'),
+        pipeline=train_pipeline))
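
The `{{_base_.file_client_args}}` reference above pulls a variable defined in the inherited base config rather than redefining it. A minimal, hypothetical illustration of the pattern (file names are placeholders, not part of this commit):

# base.py (hypothetical): defines the shared variable.
file_client_args = dict(backend='disk')

# child.py (hypothetical): inherits base.py and reuses its variable.
_base_ = './base.py'
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        # substituted with dict(backend='disk') when the config is parsed
        file_client_args={{_base_.file_client_args}}),
]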
16 changes: 14 additions & 2 deletions configs/htc/htc_r50_fpn_20e_coco.py
@@ -1,4 +1,16 @@
 _base_ = './htc_r50_fpn_1x_coco.py'
+
 # learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
+max_epochs = 20
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[16, 19],
+        gamma=0.1)
+]
+train_cfg = dict(max_epochs=max_epochs)
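
The new `param_scheduler` expresses the same 20-epoch schedule the old `lr_config`/`runner` pair encoded: a linear warmup over the first 500 iterations, then the base LR multiplied by 0.1 after epoch 16 and again after epoch 19. A rough sketch of the resulting multiplier (plain Python, for illustration only):

# Illustrative only: LR multiplier applied by the MultiStepLR piece of the
# schedule above (warmup ignored).
def lr_factor(epoch, milestones=(16, 19), gamma=0.1):
    return gamma ** sum(epoch >= m for m in milestones)

assert lr_factor(0) == 1.0                  # epochs 0-15: base LR
assert lr_factor(16) == 0.1                 # epochs 16-18: base LR * 0.1
assert abs(lr_factor(19) - 0.01) < 1e-12    # epoch 19: base LR * 0.01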
25 changes: 6 additions & 19 deletions configs/htc/htc_without_semantic_r50_fpn_1x_coco.py
@@ -5,6 +5,12 @@
 # model settings
 model = dict(
     type='HybridTaskCascade',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True,
+        pad_size_divisor=32),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -215,22 +221,3 @@
             nms=dict(type='nms', iou_threshold=0.5),
             max_per_img=100,
             mask_thr_binary=0.5)))
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-test_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(
-        type='MultiScaleFlipAug',
-        img_scale=(1333, 800),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip', flip_ratio=0.5),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='Pad', size_divisor=32),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
-]
-data = dict(
-    val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))
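
The `data_preprocessor` added above takes over what the deleted `Normalize` and `Pad` transforms did in the old test pipeline: the same per-channel mean/std, BGR-to-RGB conversion, and padding to a multiple of 32 are now applied on the model side. A sketch of the correspondence, with values copied from this diff:

# Rough mapping from the removed 2.x transforms to 3.x data_preprocessor fields:
#   Normalize(mean, std, to_rgb=True)  ->  mean, std, bgr_to_rgb=True
#   Pad(size_divisor=32)               ->  pad_size_divisor=32
data_preprocessor = dict(
    type='DetDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_size_divisor=32)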
19 changes: 16 additions & 3 deletions configs/htc/htc_x101_32x4d_fpn_16x1_20e_coco.py
@@ -13,7 +13,20 @@
         style='pytorch',
         init_cfg=dict(
             type='Pretrained', checkpoint='open-mmlab://resnext101_32x4d')))
-data = dict(samples_per_gpu=1, workers_per_gpu=1)
+
+train_dataloader = dict(batch_size=1, num_workers=1)
+
 # learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
+max_epochs = 20
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[16, 19],
+        gamma=0.1)
+]
+train_cfg = dict(max_epochs=max_epochs)
14 changes: 1 addition & 13 deletions configs/htc/htc_x101_64x4d_fpn_16x1_20e_coco.py
@@ -1,19 +1,7 @@
-_base_ = './htc_r50_fpn_1x_coco.py'
+_base_ = './htc_x101_32x4d_fpn_16x1_20e_coco.py'
 model = dict(
     backbone=dict(
-        type='ResNeXt',
-        depth=101,
         groups=64,
         base_width=4,
-        num_stages=4,
-        out_indices=(0, 1, 2, 3),
-        frozen_stages=1,
-        norm_cfg=dict(type='BN', requires_grad=True),
-        norm_eval=True,
-        style='pytorch',
         init_cfg=dict(
             type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d')))
-data = dict(samples_per_gpu=1, workers_per_gpu=1)
-# learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
@@ -1,43 +1,20 @@
-_base_ = './htc_r50_fpn_1x_coco.py'
+_base_ = './htc_x101_64x4d_fpn_16x1_20e_coco.py'
+
 model = dict(
     backbone=dict(
-        type='ResNeXt',
-        depth=101,
-        groups=64,
-        base_width=4,
-        num_stages=4,
-        out_indices=(0, 1, 2, 3),
-        frozen_stages=1,
-        norm_cfg=dict(type='BN', requires_grad=True),
-        norm_eval=True,
-        style='pytorch',
         dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False),
-        stage_with_dcn=(False, True, True, True),
-        init_cfg=dict(
-            type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d')))
+        stage_with_dcn=(False, True, True, True)))
+
 # dataset settings
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 train_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(
         type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),
     dict(
-        type='Resize',
-        img_scale=[(1600, 400), (1600, 1400)],
-        multiscale_mode='range',
+        type='RandomResize',
+        scale=[(1600, 400), (1600, 1400)],
         keep_ratio=True),
-    dict(type='RandomFlip', flip_ratio=0.5),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(type='SegRescale', scale_factor=1 / 8),
-    dict(type='DefaultFormatBundle'),
-    dict(
-        type='Collect',
-        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PackDetInputs')
 ]
-data = dict(
-    samples_per_gpu=1, workers_per_gpu=1, train=dict(pipeline=train_pipeline))
-# learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
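
`RandomResize` with `scale=[(1600, 400), (1600, 1400)]` replaces the old `Resize` in `multiscale_mode='range'`: each image is resized toward a target scale sampled between the two endpoints while keeping the aspect ratio. A rough sketch of the sampling idea (illustrative only, not the library implementation):

import random

# Illustrative only: draw one target scale between the two endpoints,
# as a 'range'-style multi-scale resize does.
def sample_scale(low=(1600, 400), high=(1600, 1400)):
    long_edge = random.randint(min(low[0], high[0]), max(low[0], high[0]))
    short_edge = random.randint(min(low[1], high[1]), max(low[1], high[1]))
    return (long_edge, short_edge)

print(sample_scale())  # e.g. (1600, 973)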
2 changes: 2 additions & 0 deletions configs/scnet/scnet_r50_fpn_1x_coco.py
@@ -90,6 +90,7 @@
             type='SCNetSemanticHead',
             num_ins=5,
             fusion_level=1,
+            seg_scale_factor=1 / 8,
             num_convs=4,
             in_channels=256,
             conv_out_channels=256,
@@ -112,6 +113,7 @@
             roi_feat_size=7,
             scale_factor=2)))
 
+# TODO
 # uncomment below code to enable test time augmentations
 # img_norm_cfg = dict(
 #     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
15 changes: 13 additions & 2 deletions configs/scnet/scnet_r50_fpn_20e_coco.py
@@ -1,4 +1,15 @@
 _base_ = './scnet_r50_fpn_1x_coco.py'
 # learning policy
-lr_config = dict(step=[16, 19])
-runner = dict(type='EpochBasedRunner', max_epochs=20)
+max_epochs = 20
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[16, 19],
+        gamma=0.1)
+]
+train_cfg = dict(max_epochs=max_epochs)
4 changes: 2 additions & 2 deletions configs/scnet/scnet_x101_64x4d_fpn_8x1_20e_coco.py
@@ -1,7 +1,7 @@
 _base_ = './scnet_x101_64x4d_fpn_20e_coco.py'
-data = dict(samples_per_gpu=1, workers_per_gpu=1)
-optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+train_dataloader = dict(batch_size=1, num_workers=1)
 
+optim_wrapper = dict(optimizer=dict(lr=0.01))
 # NOTE: `auto_scale_lr` is for automatically scaling LR,
 # USER SHOULD NOT CHANGE ITS VALUES.
 # base_batch_size = (8 GPUs) x (1 samples per GPU)
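
`optim_wrapper` is the 3.x counterpart of the old bare `optimizer` dict; here only the learning rate is overridden, with the remaining SGD settings expected to come from the inherited schedule. A hedged sketch of what the merged setting would look like, assuming the base config keeps the momentum and weight decay the deleted line used:

# Assumed merged result (the base schedule is not shown in this commit):
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))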
9 changes: 7 additions & 2 deletions mmdet/datasets/coco.py
@@ -111,10 +111,15 @@ def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
         data_info = {}
 
         img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
-        seg_map_path = img_info['file_name'].replace('jpg', 'png')
+        if self.data_prefix.get('seg', None):
+            seg_map_path = osp.join(
+                self.data_prefix['seg'],
+                img_info['file_name'].replace('jpg', 'png'))
+        else:
+            seg_map_path = None
         data_info['img_path'] = img_path
         data_info['img_id'] = img_info['img_id']
-        data_info['seg_path'] = seg_map_path
+        data_info['seg_map_path'] = seg_map_path
         data_info['height'] = img_info['height']
         data_info['width'] = img_info['width']
 
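
A small usage sketch of the new branch in parse_data_info (hypothetical values): when the dataset's `data_prefix` carries a `seg` entry, the stuff-thing map path is joined against it, otherwise `seg_map_path` stays `None`.

import os.path as osp

# Hypothetical inputs mirroring the change above.
data_prefix = dict(img='train2017/', seg='stuffthingmaps/train2017/')
file_name = '000000000009.jpg'

if data_prefix.get('seg', None):
    seg_map_path = osp.join(data_prefix['seg'], file_name.replace('jpg', 'png'))
else:
    seg_map_path = None
print(seg_map_path)  # stuffthingmaps/train2017/000000000009.png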