From 1f35b0f7f0918e77601dcb91643a46d25dd35c29 Mon Sep 17 00:00:00 2001 From: zytx121 Date: Mon, 11 Jul 2022 11:12:00 +0000 Subject: [PATCH] Refactor GFL & LD & LAD model --- ...ffe_fpn_gn-head_mstrain_640-800_2x_coco.py | 11 +- ...ffe_fpn_gn-head_mstrain_640-800_2x_coco.py | 11 +- ...x4d_fpn_gn-head_mstrain_640-800_2x_coco.py | 11 +- configs/gfl/gfl_r50_fpn_1x_coco.py | 11 +- configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py | 36 +- .../ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py | 35 +- configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py | 10 +- configs/paa/paa_r50_fpn_mstrain_3x_coco.py | 9 +- configs/tood/tood_r50_fpn_mstrain_2x_coco.py | 9 +- mmdet/models/dense_heads/gfl_head.py | 442 +++++++++--------- mmdet/models/dense_heads/lad_head.py | 91 ++-- mmdet/models/dense_heads/ld_head.py | 154 +++--- mmdet/models/detectors/gfl.py | 41 +- mmdet/models/detectors/kd_one_stage.py | 24 +- .../test_dense_heads/test_gfl_head.py | 89 ++++ .../test_dense_heads/test_lad_head.py | 17 +- .../test_dense_heads/test_ld_head.py | 149 ++++++ .../test_detectors/test_kd_single_stage.py | 42 +- 18 files changed, 753 insertions(+), 439 deletions(-) create mode 100644 tests/test_models/test_dense_heads/test_gfl_head.py create mode 100644 tests/test_models/test_dense_heads/test_ld_head.py diff --git a/configs/fcos/fcos_r101_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py b/configs/fcos/fcos_r101_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py index a8831aa4a12..fd1fb765a3c 100644 --- a/configs/fcos/fcos_r101_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py +++ b/configs/fcos/fcos_r101_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py @@ -10,19 +10,22 @@ # dataset settings train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( type='RandomChoiceResize', scale=[(1333, 640), (1333, 800)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) # training schedule for 2x -train_cfg = dict(max_epochs=24) +max_epochs = 24 +train_cfg = dict(max_epochs=max_epochs) # learning rate param_scheduler = [ @@ -30,7 +33,7 @@ dict( type='MultiStepLR', begin=0, - end=24, + end=max_epochs, by_epoch=True, milestones=[16, 22], gamma=0.1) diff --git a/configs/fcos/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py b/configs/fcos/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py index 2021c933af8..1384a3889de 100644 --- a/configs/fcos/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py +++ b/configs/fcos/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py @@ -2,19 +2,22 @@ # dataset settings train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( type='RandomChoiceResize', scale=[(1333, 640), (1333, 800)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) # training schedule for 2x -train_cfg = dict(max_epochs=24) +max_epochs = 24 +train_cfg = dict(max_epochs=max_epochs) # learning rate param_scheduler = [ @@ -22,7 +25,7 @@ dict( type='MultiStepLR', begin=0, - end=24, + end=max_epochs, by_epoch=True, milestones=[16, 22], gamma=0.1) diff --git 
a/configs/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco.py b/configs/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco.py index fc07f606615..5cf9ff39eb0 100644 --- a/configs/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco.py +++ b/configs/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco.py @@ -24,19 +24,22 @@ # dataset settings train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( type='RandomChoiceResize', scale=[(1333, 640), (1333, 800)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) # training schedule for 2x -train_cfg = dict(max_epochs=24) +max_epochs = 24 +train_cfg = dict(max_epochs=max_epochs) # learning rate param_scheduler = [ @@ -44,7 +47,7 @@ dict( type='MultiStepLR', begin=0, - end=24, + end=max_epochs, by_epoch=True, milestones=[16, 22], gamma=0.1) diff --git a/configs/gfl/gfl_r50_fpn_1x_coco.py b/configs/gfl/gfl_r50_fpn_1x_coco.py index cfd4b02391a..902382552d5 100644 --- a/configs/gfl/gfl_r50_fpn_1x_coco.py +++ b/configs/gfl/gfl_r50_fpn_1x_coco.py @@ -4,6 +4,12 @@ ] model = dict( type='GFL', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), backbone=dict( type='ResNet', depth=50, @@ -53,5 +59,8 @@ score_thr=0.05, nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) + # optimizer -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py b/configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py index b8be6014575..cb1137e01df 100644 --- a/configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py +++ b/configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py @@ -1,22 +1,30 @@ _base_ = './gfl_r50_fpn_1x_coco.py' +max_epochs = 24 + # learning policy -lr_config = dict(step=[16, 22]) -runner = dict(type='EpochBasedRunner', max_epochs=24) +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[16, 22], + gamma=0.1) +] +train_cfg = dict(max_epochs=max_epochs) + # multi-scale training -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( - type='Resize', - img_scale=[(1333, 480), (1333, 800)], - multiscale_mode='range', + type='RandomResize', scale=[(1333, 480), (1333, 800)], keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') ] -data = dict(train=dict(pipeline=train_pipeline)) +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py b/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py index 1cbdb4cf5a5..0dec5d6c1fd 100644 
--- a/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py +++ b/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py @@ -22,23 +22,30 @@ add_extra_convs='on_output', num_outs=5)) -lr_config = dict(step=[16, 22]) -runner = dict(type='EpochBasedRunner', max_epochs=24) +max_epochs = 24 +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[16, 22], + gamma=0.1) +] +train_cfg = dict(max_epochs=max_epochs) + # multi-scale training -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( - type='Resize', - img_scale=[(1333, 480), (1333, 800)], - multiscale_mode='range', + type='RandomResize', scale=[(1333, 480), (1333, 800)], keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') ] -data = dict(train=dict(pipeline=train_pipeline)) +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py b/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py index 18dce814be9..db09721af61 100644 --- a/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py +++ b/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py @@ -5,6 +5,12 @@ teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa model = dict( type='KnowledgeDistillationSingleStageDetector', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), teacher_config='configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py', teacher_ckpt=teacher_ckpt, backbone=dict( @@ -59,4 +65,6 @@ nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/paa/paa_r50_fpn_mstrain_3x_coco.py b/configs/paa/paa_r50_fpn_mstrain_3x_coco.py index 1a088387b00..803ceeca0ec 100644 --- a/configs/paa/paa_r50_fpn_mstrain_3x_coco.py +++ b/configs/paa/paa_r50_fpn_mstrain_3x_coco.py @@ -18,12 +18,13 @@ train_cfg = dict(max_epochs=max_epochs) train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( - type='RandomResize', - scale=[(1333, 640), (1333, 800)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + type='RandomResize', scale=[(1333, 640), (1333, 800)], + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] diff --git a/configs/tood/tood_r50_fpn_mstrain_2x_coco.py b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py index 1f7bebb656d..93d1d47521d 100644 --- a/configs/tood/tood_r50_fpn_mstrain_2x_coco.py +++ b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py @@ -19,12 +19,13 @@ # multi-scale training train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + 
file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True), dict( - type='RandomResize', - scale=[(1333, 480), (1333, 800)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + type='RandomResize', scale=[(1333, 480), (1333, 800)], + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py index 7510a87c33d..84e02acce7c 100644 --- a/mmdet/models/dense_heads/gfl_head.py +++ b/mmdet/models/dense_heads/gfl_head.py @@ -1,39 +1,43 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple + import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, Scale -from mmcv.runner import force_fp32 +from mmengine.config import ConfigDict +from mmengine.data import InstanceData +from torch import Tensor -from mmdet.core import (anchor_inside_flags, bbox_overlaps, build_assigner, - build_sampler, images_to_levels, multi_apply, +from mmdet.core import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList, PseudoSampler, anchor_inside_flags, + bbox_overlaps, images_to_levels, multi_apply, reduce_mean, unmap) from mmdet.core.utils import filter_scores_and_topk -from mmdet.registry import MODELS -from ..builder import build_loss +from mmdet.registry import MODELS, TASK_UTILS from .anchor_head import AnchorHead class Integral(nn.Module): """A fixed layer for calculating integral result from distribution. - This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + This layer calculates the target location by :math: ``sum{P(y_i) * y_i}``, P(y_i) denotes the softmax vector that represents the discrete distribution y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} Args: - reg_max (int): The maximal value of the discrete set. Default: 16. You - may want to reset it according to your new dataset or related + reg_max (int): The maximal value of the discrete set. Defaults to 16. + You may want to reset it according to your new dataset or related settings. """ - def __init__(self, reg_max=16): - super(Integral, self).__init__() + def __init__(self, reg_max: int = 16) -> None: + super().__init__() self.reg_max = reg_max self.register_buffer('project', torch.linspace(0, self.reg_max, self.reg_max + 1)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Forward feature from the regression head to get integral result of bounding box location. @@ -68,17 +72,20 @@ class GFLHead(AnchorHead): category. in_channels (int): Number of channels in the input feature map. stacked_convs (int): Number of conv layers in cls and reg tower. - Default: 4. - conv_cfg (dict): dictionary to construct and config conv layer. - Default: None. - norm_cfg (dict): dictionary to construct and config norm layer. - Default: dict(type='GN', num_groups=32, requires_grad=True). - loss_qfl (dict): Config of Quality Focal Loss (QFL). - bbox_coder (dict): Config of bbox coder. Defaults - 'DistancePointBBoxCoder'. - reg_max (int): Max value of integral set :math: `{0, ..., reg_max}` - in QFL setting. Default: 16. - init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to 4. + conv_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct + and config conv layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and + config norm layer. Default: dict(type='GN', num_groups=32, + requires_grad=True). 
+        loss_dfl (:obj:`ConfigDict` or dict): Config of Distribution Focal
+            Loss (DFL).
+        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
+            to 'DistancePointBBoxCoder'.
+        reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
+            in QFL setting. Defaults to 16.
+        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
+            list[:obj:`ConfigDict`]): Initialization config dict.

     Example:
         >>> self = GFLHead(11, 7)
         >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
@@ -87,15 +94,17 @@ class GFLHead(AnchorHead):
     """

     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 stacked_convs=4,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
-                 loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
-                 bbox_coder=dict(type='DistancePointBBoxCoder'),
-                 reg_max=16,
-                 init_cfg=dict(
+                 num_classes: int,
+                 in_channels: int,
+                 stacked_convs: int = 4,
+                 conv_cfg: OptConfigType = None,
+                 norm_cfg: ConfigType = dict(
+                     type='GN', num_groups=32, requires_grad=True),
+                 loss_dfl: ConfigType = dict(
+                     type='DistributionFocalLoss', loss_weight=0.25),
+                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
+                 reg_max: int = 16,
+                 init_cfg: MultiConfig = dict(
                      type='Normal',
                      layer='Conv2d',
                      std=0.01,
@@ -104,31 +113,32 @@ def __init__(self,
                          name='gfl_cls',
                          std=0.01,
                          bias_prob=0.01)),
-                 **kwargs):
+                 **kwargs) -> None:
         self.stacked_convs = stacked_convs
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.reg_max = reg_max
-        super(GFLHead, self).__init__(
-            num_classes,
-            in_channels,
+        super().__init__(
+            num_classes=num_classes,
+            in_channels=in_channels,
             bbox_coder=bbox_coder,
             init_cfg=init_cfg,
             **kwargs)

-        self.sampling = False
         if self.train_cfg:
-            self.assigner = build_assigner(self.train_cfg.assigner)
-            # SSD sampling=False so use PseudoSampler
-            sampler_cfg = dict(type='PseudoSampler')
-            self.sampler = build_sampler(sampler_cfg, context=self)
+            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
+            if self.train_cfg.get('sampler', None) is not None:
+                self.sampler = TASK_UTILS.build(
+                    self.train_cfg['sampler'], default_args=dict(context=self))
+            else:
+                self.sampler = PseudoSampler(context=self)

         self.integral = Integral(self.reg_max)
-        self.loss_dfl = build_loss(loss_dfl)
+        self.loss_dfl = MODELS.build(loss_dfl)

-    def _init_layers(self):
+    def _init_layers(self) -> None:
         """Initialize layers of the head."""
-        self.relu = nn.ReLU(inplace=True)
+        self.relu = nn.ReLU()
         self.cls_convs = nn.ModuleList()
         self.reg_convs = nn.ModuleList()
         for i in range(self.stacked_convs):
@@ -159,25 +169,26 @@ def _init_layers(self):
         self.scales = nn.ModuleList(
             [Scale(1.0) for _ in self.prior_generator.strides])

-    def forward(self, feats):
+    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
         """Forward features from the upstream network.

         Args:
-            feats (tuple[Tensor]): Features from the upstream network, each is
+            x (tuple[Tensor]): Features from the upstream network, each is
                 a 4D-tensor.

         Returns:
             tuple: Usually a tuple of classification scores and bbox prediction
-                cls_scores (list[Tensor]): Classification and quality (IoU)
-                    joint scores for all scale levels, each is a 4D-tensor,
-                    the channel number is num_classes.
-                bbox_preds (list[Tensor]): Box distribution logits for all
-                    scale levels, each is a 4D-tensor, the channel number is
-                    4*(n+1), n is max value of integral set.
+
+            - cls_scores (list[Tensor]): Classification and quality (IoU)
+              joint scores for all scale levels, each is a 4D-tensor,
+              the channel number is num_classes.
+ - bbox_preds (list[Tensor]): Box distribution logits for all + scale levels, each is a 4D-tensor, the channel number is + 4*(n+1), n is max value of integral set. """ - return multi_apply(self.forward_single, feats, self.scales) + return multi_apply(self.forward_single, x, self.scales) - def forward_single(self, x, scale): + def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]: """Forward feature of a single scale level. Args: @@ -187,11 +198,12 @@ def forward_single(self, x, scale): Returns: tuple: - cls_score (Tensor): Cls and quality joint scores for a single - scale level the channel number is num_classes. - bbox_pred (Tensor): Box distribution logits for a single scale - level, the channel number is 4*(n+1), n is max value of - integral set. + + - cls_score (Tensor): Cls and quality joint scores for a single + scale level the channel number is num_classes. + - bbox_pred (Tensor): Box distribution logits for a single scale + level, the channel number is 4*(n+1), n is max value of + integral set. """ cls_feat = x reg_feat = x @@ -203,22 +215,25 @@ def forward_single(self, x, scale): bbox_pred = scale(self.gfl_reg(reg_feat)).float() return cls_score, bbox_pred - def anchor_center(self, anchors): + def anchor_center(self, anchors: Tensor) -> Tensor: """Get anchor centers from anchors. Args: - anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + anchors (Tensor): Anchor list with shape (N, 4), ``xyxy`` format. Returns: - Tensor: Anchor centers with shape (N, 2), "xy" format. + Tensor: Anchor centers with shape (N, 2), ``xy`` format. """ anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2 anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2 return torch.stack([anchors_cx, anchors_cy], dim=-1) - def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, - bbox_targets, stride, num_total_samples): - """Compute loss of a single scale level. + def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + stride: Tuple[int], avg_factor: int) -> dict: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. Args: anchors (Tensor): Box reference for each scale level with shape @@ -234,9 +249,12 @@ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, (N, num_total_anchors) bbox_targets (Tensor): BBox regression targets of each anchor weight shape (N, num_total_anchors, 4). - stride (tuple): Stride in this scale level. - num_total_samples (int): Number of positive samples that is - reduced over all GPUs. + stride (Tuple[int]): Stride in this scale level. + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. Returns: dict[str, Tensor]: A dictionary of loss components. @@ -300,19 +318,19 @@ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, loss_cls = self.loss_cls( cls_score, (labels, score), weight=label_weights, - avg_factor=num_total_samples) + avg_factor=avg_factor) return loss_cls, loss_bbox, loss_dfl, weight_targets.sum() - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def loss(self, - cls_scores, - bbox_preds, - gt_bboxes, - gt_labels, - img_metas, - gt_bboxes_ignore=None): - """Compute losses of the head. 
+ def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. Args: cls_scores (list[Tensor]): Cls and quality scores for each scale @@ -320,13 +338,15 @@ def loss(self, bbox_preds (list[Tensor]): Box distribution logits for each scale level with shape (N, 4*(n+1), H, W), n is max value of integral set. - gt_bboxes (list[Tensor]): Ground truth bboxes for each image with - shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. - gt_labels (list[Tensor]): class indices corresponding to each box - img_metas (list[dict]): Meta information of each image, e.g., + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. - gt_bboxes_ignore (list[Tensor] | None): specify which bounding - boxes can be ignored when computing the loss. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. Returns: dict[str, Tensor]: A dictionary of loss components. @@ -337,31 +357,24 @@ def loss(self, device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( - featmap_sizes, img_metas, device=device) - label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + featmap_sizes, batch_img_metas, device=device) cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, - gt_bboxes, - img_metas, - gt_bboxes_ignore_list=gt_bboxes_ignore, - gt_labels_list=gt_labels, - label_channels=label_channels) - if cls_reg_targets is None: - return None + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) (anchor_list, labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets + bbox_weights_list, avg_factor) = cls_reg_targets - num_total_samples = reduce_mean( - torch.tensor(num_total_pos, dtype=torch.float, - device=device)).item() - num_total_samples = max(num_total_samples, 1.0) + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() losses_cls, losses_bbox, losses_dfl,\ avg_factor = multi_apply( - self.loss_single, + self.loss_by_feat_single, anchor_list, cls_scores, bbox_preds, @@ -369,7 +382,7 @@ def loss(self, label_weights_list, bbox_targets_list, self.prior_generator.strides, - num_total_samples=num_total_samples) + avg_factor=avg_factor) avg_factor = sum(avg_factor) avg_factor = reduce_mean(avg_factor).clamp_(min=1).item() @@ -378,17 +391,18 @@ def loss(self, return dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl) - def _get_bboxes_single(self, - cls_score_list, - bbox_pred_list, - score_factor_list, - mlvl_priors, - img_meta, - cfg, - rescale=False, - with_nms=True, - **kwargs): - """Transform outputs of a single image into bbox predictions. 
+ def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True, + **kwargs) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. Args: cls_score_list (list[Tensor]): Box scores from all scale @@ -403,26 +417,26 @@ def _get_bboxes_single(self, the priors of a single level in feature pyramid, has shape (num_priors, 4). img_meta (dict): Image meta info. - cfg (mmcv.Config): Test / postprocessing configuration, + cfg (:obj: `ConfigDict`): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. - Default: False. + Defaults to False. with_nms (bool): If True, do nms before return boxes. - Default: True. + Defaults to True. Returns: tuple[Tensor]: Results of detected bboxes and labels. If with_nms - is False and mlvl_score_factor is None, return mlvl_bboxes and - mlvl_scores, else return mlvl_bboxes, mlvl_scores and - mlvl_score_factor. Usually with_nms is False is used for aug - test. If with_nms is True, then return the following format - - - det_bboxes (Tensor): Predicted bboxes with shape \ - [num_bboxes, 5], where the first 4 columns are bounding \ - box positions (tl_x, tl_y, br_x, br_y) and the 5-th \ - column are scores between 0 and 1. - - det_labels (Tensor): Predicted labels of the corresponding \ - box with shape [num_bboxes]. + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape + [num_bboxes, 5], where the first 4 columns are bounding + box positions (tl_x, tl_y, br_x, br_y) and the 5-th + column are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding + box with shape [num_bboxes]. """ cfg = self.test_cfg if cfg is None else cfg img_shape = img_meta['img_shape'] @@ -462,31 +476,33 @@ def _get_bboxes_single(self, mlvl_scores.append(scores) mlvl_labels.append(labels) + results = InstanceData() + results.bboxes = torch.cat(mlvl_bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + return self._bbox_post_process( - mlvl_scores, - mlvl_labels, - mlvl_bboxes, - img_meta['scale_factor'], - cfg, + results=results, + cfg=cfg, rescale=rescale, - with_nms=with_nms) + with_nms=with_nms, + img_meta=img_meta, + **kwargs) def get_targets(self, - anchor_list, - valid_flag_list, - gt_bboxes_list, - img_metas, - gt_bboxes_ignore_list=None, - gt_labels_list=None, - label_channels=1, - unmap_outputs=True): + anchor_list: List[Tensor], + valid_flag_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs=True) -> tuple: """Get targets for GFL head. This method is almost the same as `AnchorHead.get_targets()`. Besides returning the targets as the parent method does, it also returns the anchors as the first element of the returned tuple. 
""" - num_imgs = len(img_metas) + num_imgs = len(batch_img_metas) assert len(anchor_list) == len(valid_flag_list) == num_imgs # anchor number of multi levels @@ -500,28 +516,25 @@ def get_targets(self, valid_flag_list[i] = torch.cat(valid_flag_list[i]) # compute targets for each image - if gt_bboxes_ignore_list is None: - gt_bboxes_ignore_list = [None for _ in range(num_imgs)] - if gt_labels_list is None: - gt_labels_list = [None for _ in range(num_imgs)] + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs (all_anchors, all_labels, all_label_weights, all_bbox_targets, - all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply( - self._get_target_single, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, anchor_list, valid_flag_list, num_level_anchors_list, - gt_bboxes_list, - gt_bboxes_ignore_list, - gt_labels_list, - img_metas, - label_channels=label_channels, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, unmap_outputs=unmap_outputs) - # no valid anchors - if any([labels is None for labels in all_labels]): - return None - # sampled anchors of all images - num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) - num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) # split targets to a list w.r.t. multiple levels anchors_list = images_to_levels(all_anchors, num_level_anchors) labels_list = images_to_levels(all_labels, num_level_anchors) @@ -532,19 +545,16 @@ def get_targets(self, bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) return (anchors_list, labels_list, label_weights_list, - bbox_targets_list, bbox_weights_list, num_total_pos, - num_total_neg) - - def _get_target_single(self, - flat_anchors, - valid_flags, - num_level_anchors, - gt_bboxes, - gt_bboxes_ignore, - gt_labels, - img_meta, - label_channels=1, - unmap_outputs=True): + bbox_targets_list, bbox_weights_list, avg_factor) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: """Compute regression, classification targets for anchors in a single image. @@ -554,50 +564,60 @@ def _get_target_single(self, valid_flags (Tensor): Multi level valid flags of the image, which are concatenated into a single tensor of shape (num_anchors,). - num_level_anchors Tensor): Number of anchors of each scale level. - gt_bboxes (Tensor): Ground truth bboxes of the image, - shape (num_gts, 4). - gt_bboxes_ignore (Tensor): Ground truth bboxes to be - ignored, shape (num_ignored_gts, 4). - gt_labels (Tensor): Ground truth labels of each box, - shape (num_gts,). - img_meta (dict): Meta info of the image. - label_channels (int): Channel of label. + num_level_anchors (list[int]): Number of anchors of each scale + level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. 
+ gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. unmap_outputs (bool): Whether to map outputs back to the original - set of anchors. + set of anchors. Defaults to True. Returns: tuple: N is the number of total anchors in the image. - anchors (Tensor): All anchors in the image with shape (N, 4). - labels (Tensor): Labels of all anchors in the image with shape - (N,). - label_weights (Tensor): Label weights of all anchor in the - image with shape (N,). - bbox_targets (Tensor): BBox targets of all anchors in the - image with shape (N, 4). - bbox_weights (Tensor): BBox weights of all anchors in the - image with shape (N, 4). - pos_inds (Tensor): Indices of positive anchor with shape - (num_pos,). - neg_inds (Tensor): Indices of negative anchor with shape - (num_neg,). + + - anchors (Tensor): All anchors in the image with shape (N, 4). + - labels (Tensor): Labels of all anchors in the image with + shape (N,). + - label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + - bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + - bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4). + - pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). + - neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + - sampling_result (:obj:`SamplingResult`): Sampling results. """ inside_flags = anchor_inside_flags(flat_anchors, valid_flags, img_meta['img_shape'][:2], - self.train_cfg.allowed_border) + self.train_cfg['allowed_border']) if not inside_flags.any(): - return (None, ) * 7 + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') # assign gt and sample anchors anchors = flat_anchors[inside_flags, :] - num_level_anchors_inside = self.get_num_level_anchors_inside( num_level_anchors, inside_flags) - assign_result = self.assigner.assign(anchors, num_level_anchors_inside, - gt_bboxes, gt_bboxes_ignore, - gt_labels) - - sampling_result = self.sampler.sample(assign_result, anchors, - gt_bboxes) + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign( + pred_instances=pred_instances, + num_level_priors=num_level_anchors_inside, + gt_instances=gt_instances, + gt_instances_ignore=gt_instances_ignore) + + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) num_valid_anchors = anchors.shape[0] bbox_targets = torch.zeros_like(anchors) @@ -613,17 +633,12 @@ def _get_target_single(self, pos_bbox_targets = sampling_result.pos_gt_bboxes bbox_targets[pos_inds, :] = pos_bbox_targets bbox_weights[pos_inds, :] = 1.0 - if gt_labels is None: - # Only rpn gives gt_labels as None - # Foreground is the first class - labels[pos_inds] = 0 - else: - labels[pos_inds] = gt_labels[ - sampling_result.pos_assigned_gt_inds] - if self.train_cfg.pos_weight <= 0: + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: label_weights[pos_inds] = 1.0 else: - label_weights[pos_inds] = self.train_cfg.pos_weight + label_weights[pos_inds] = self.train_cfg['pos_weight'] if len(neg_inds) > 0: label_weights[neg_inds] = 1.0 @@ -639,9 +654,12 @@ def _get_target_single(self, bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) return (anchors, labels, label_weights, bbox_targets, bbox_weights, - pos_inds, neg_inds) + pos_inds, neg_inds, sampling_result) + + def get_num_level_anchors_inside(self, num_level_anchors: List[int], + inside_flags: Tensor) -> List[int]: + """Get the number of valid anchors in every level.""" - def get_num_level_anchors_inside(self, num_level_anchors, inside_flags): split_inside_flags = torch.split(inside_flags, num_level_anchors) num_level_anchors_inside = [ int(flags.sum()) for flags in split_inside_flags diff --git a/mmdet/models/dense_heads/lad_head.py b/mmdet/models/dense_heads/lad_head.py index f76b22df8db..b6e78e8ad09 100644 --- a/mmdet/models/dense_heads/lad_head.py +++ b/mmdet/models/dense_heads/lad_head.py @@ -2,12 +2,12 @@ from typing import List, Optional import torch -from mmcv.runner import force_fp32 -from mmengine.data import InstanceData from torch import Tensor -from mmdet.core import bbox_overlaps, multi_apply +from mmdet.core import (InstanceList, OptInstanceList, SampleList, + bbox_overlaps, multi_apply) from mmdet.registry import MODELS +from ..utils.misc import unpack_gt_instances from .paa_head import PAAHead, levels_to_images @@ -16,16 +16,14 @@ class LADHead(PAAHead): """Label Assignment Head from the paper: `Improving Object Detection by Label Assignment Distillation `_""" - @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds')) def get_label_assignment( - self, - cls_scores: List[Tensor], - bbox_preds: List[Tensor], - iou_preds: List[Tensor], - batch_gt_instances: List[InstanceData], - batch_img_metas: List[dict], - batch_gt_instances_ignore: Optional[List[InstanceData]] = None - ) -> tuple: + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + iou_preds: List[Tensor], + batch_gt_instances: InstanceList, + 
batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> tuple: """Get label assignment (from teacher). Args: @@ -48,18 +46,18 @@ def get_label_assignment( Returns: tuple: Returns a tuple containing label assignment variables. - - labels (Tensor): Labels of all anchors, each with - shape (num_anchors,). - - labels_weight (Tensor): Label weights of all anchor. - each with shape (num_anchors,). - - bboxes_target (Tensor): BBox targets of all anchors. - each with shape (num_anchors, 4). - - bboxes_weight (Tensor): BBox weights of all anchors. - each with shape (num_anchors, 4). - - pos_inds_flatten (Tensor): Contains all index of positive - sample in all anchor. - - pos_anchors (Tensor): Positive anchors. - - num_pos (int): Number of positive anchors. + - labels (Tensor): Labels of all anchors, each with + shape (num_anchors,). + - labels_weight (Tensor): Label weights of all anchor. + each with shape (num_anchors,). + - bboxes_target (Tensor): BBox targets of all anchors. + each with shape (num_anchors, 4). + - bboxes_weight (Tensor): BBox weights of all anchors. + each with shape (num_anchors, 4). + - pos_inds_flatten (Tensor): Contains all index of positive + sample in all anchor. + - pos_anchors (Tensor): Positive anchors. + - num_pos (int): Number of positive anchors. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] @@ -122,14 +120,8 @@ def get_label_assignment( pos_anchors, num_pos) return label_assignment_results - def forward_train( - self, - x: List[Tensor], - label_assignment_results: tuple, - batch_gt_instances: List[InstanceData], - batch_img_metas: List[dict], - batch_gt_instances_ignore: Optional[List[InstanceData]] = None, - **kwargs): + def loss(self, x: List[Tensor], label_assignment_results: tuple, + batch_data_samples: SampleList) -> dict: """Forward train with the available label assignment (student receives from teacher). @@ -137,36 +129,33 @@ def forward_train( x (list[Tensor]): Features from FPN. label_assignment_results (tuple): As the outputs defined in the function `self.get_label_assignment`. - batch_gt_instances (list[:obj:`InstanceData`]): Batch of - gt_instance. It usually includes ``bboxes`` and ``labels`` - attributes. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): - Batch of gt_instances_ignore. It includes ``bboxes`` attribute - data that is ignored during training and testing. - Defaults to None. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Returns: losses: (dict[str, Tensor]): A dictionary of loss components. 
""" + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + outs = self(x) loss_inputs = outs + (batch_gt_instances, batch_img_metas) - losses = self.loss( + losses = self.loss_by_feat( *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore, label_assignment_results=label_assignment_results) return losses - @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds')) - def loss(self, - cls_scores: List[Tensor], - bbox_preds: List[Tensor], - iou_preds: List[Tensor], - batch_gt_instances: List[InstanceData], - batch_img_metas: List[dict], - batch_gt_instances_ignore: Optional[List[InstanceData]] = None, - label_assignment_results: Optional[tuple] = None) -> dict: + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + iou_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + label_assignment_results: Optional[tuple] = None) -> dict: """Compute losses of the head. Args: diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py index 3943e90baf4..270bab746cf 100644 --- a/mmdet/models/dense_heads/ld_head.py +++ b/mmdet/models/dense_heads/ld_head.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + import torch -from mmcv.runner import force_fp32 +from torch import Tensor -from mmdet.core import bbox_overlaps, multi_apply, reduce_mean +from mmdet.core import (ConfigType, InstanceList, OptInstanceList, SampleList, + bbox_overlaps, multi_apply, reduce_mean) from mmdet.registry import MODELS -from ..builder import build_loss +from ..utils.misc import unpack_gt_instances from .gfl_head import GFLHead @@ -20,25 +23,30 @@ class LDHead(GFLHead): num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. - loss_ld (dict): Config of Localization Distillation Loss (LD), - T is the temperature for distillation. + loss_ld (:obj:`ConfigDict` or dict): Config of Localization + Distillation Loss (LD), T is the temperature for distillation. """ def __init__(self, - num_classes, - in_channels, - loss_ld=dict( + num_classes: int, + in_channels: int, + loss_ld: ConfigType = dict( type='LocalizationDistillationLoss', loss_weight=0.25, T=10), - **kwargs): + **kwargs) -> dict: - super(LDHead, self).__init__(num_classes, in_channels, **kwargs) - self.loss_ld = build_loss(loss_ld) + super().__init__( + num_classes=num_classes, in_channels=in_channels, **kwargs) + self.loss_ld = MODELS.build(loss_ld) - def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, - bbox_targets, stride, soft_targets, num_total_samples): - """Compute loss of a single scale level. + def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + stride: Tuple[int], soft_targets: Tensor, + avg_factor: int): + """Calculate the loss of a single scale level based on the features + extracted by the detection head. Args: anchors (Tensor): Box reference for each scale level with shape @@ -55,8 +63,12 @@ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, bbox_targets (Tensor): BBox regression targets of each anchor weight shape (N, num_total_anchors, 4). stride (tuple): Stride in this scale level. 
- num_total_samples (int): Number of positive samples that is - reduced over all GPUs. + soft_targets (Tensor): Soft BBox regression targets. + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. Returns: dict[tuple, Tensor]: Loss components and weight targets. @@ -136,32 +148,19 @@ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, loss_cls = self.loss_cls( cls_score, (labels, score), weight=label_weights, - avg_factor=num_total_samples) + avg_factor=avg_factor) return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum() - def forward_train(self, - x, - out_teacher, - img_metas, - gt_bboxes, - gt_labels=None, - gt_bboxes_ignore=None, - proposal_cfg=None, - **kwargs): + def loss(self, x: List[Tensor], out_teacher: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: """ Args: x (list[Tensor]): Features from FPN. - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - gt_bboxes (Tensor): Ground truth bboxes of the image, - shape (num_gts, 4). - gt_labels (Tensor): Ground truth labels of each box, - shape (num_gts,). - gt_bboxes_ignore (Tensor): Ground truth bboxes to be - ignored, shape (num_ignored_gts, 4). - proposal_cfg (mmcv.Config): Test / postprocessing configuration, - if None, test_cfg would be used + out_teacher (tuple[Tensor]): The output of teacher. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Returns: tuple[dict, list]: The loss components and proposals of each image. @@ -169,28 +168,27 @@ def forward_train(self, - losses (dict[str, Tensor]): A dictionary of loss components. - proposal_list (list[Tensor]): Proposals of each image. """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + outs = self(x) - soft_target = out_teacher[1] - if gt_labels is None: - loss_inputs = outs + (gt_bboxes, soft_target, img_metas) - else: - loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas) - losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) - if proposal_cfg is None: - return losses - else: - proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) - return losses, proposal_list - - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def loss(self, - cls_scores, - bbox_preds, - gt_bboxes, - gt_labels, - soft_target, - img_metas, - gt_bboxes_ignore=None): + soft_targets = out_teacher[1] + loss_inputs = outs + (batch_gt_instances, batch_img_metas, + soft_targets) + losses = self.loss_by_feat( + *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore) + + return losses + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + soft_targets: List[Tensor], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: """Compute losses of the head. Args: @@ -199,13 +197,16 @@ def loss(self, bbox_preds (list[Tensor]): Box distribution logits for each scale level with shape (N, 4*(n+1), H, W), n is max value of integral set. - gt_bboxes (list[Tensor]): Ground truth bboxes for each image with - shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. 
- gt_labels (list[Tensor]): class indices corresponding to each box - img_metas (list[dict]): Meta information of each image, e.g., + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + soft_targets (list[Tensor]): Soft BBox regression targets. + batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. - gt_bboxes_ignore (list[Tensor] | None): specify which bounding - boxes can be ignored when computing the loss. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. Returns: dict[str, Tensor]: A dictionary of loss components. @@ -216,31 +217,24 @@ def loss(self, device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( - featmap_sizes, img_metas, device=device) - label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + featmap_sizes, batch_img_metas, device=device) cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, - gt_bboxes, - img_metas, - gt_bboxes_ignore_list=gt_bboxes_ignore, - gt_labels_list=gt_labels, - label_channels=label_channels) - if cls_reg_targets is None: - return None + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) (anchor_list, labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets + bbox_weights_list, avg_factor) = cls_reg_targets - num_total_samples = reduce_mean( - torch.tensor(num_total_pos, dtype=torch.float, - device=device)).item() - num_total_samples = max(num_total_samples, 1.0) + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() losses_cls, losses_bbox, losses_dfl, losses_ld, \ avg_factor = multi_apply( - self.loss_single, + self.loss_by_feat_single, anchor_list, cls_scores, bbox_preds, @@ -248,8 +242,8 @@ def loss(self, label_weights_list, bbox_targets_list, self.prior_generator.strides, - soft_target, - num_total_samples=num_total_samples) + soft_targets, + avg_factor=avg_factor) avg_factor = sum(avg_factor) + 1e-6 avg_factor = reduce_mean(avg_factor).item() diff --git a/mmdet/models/detectors/gfl.py b/mmdet/models/detectors/gfl.py index c39cada4c6b..56b8cdc1125 100644 --- a/mmdet/models/detectors/gfl.py +++ b/mmdet/models/detectors/gfl.py @@ -1,18 +1,41 @@ # Copyright (c) OpenMMLab. All rights reserved. +from mmdet.core import ConfigType, OptConfigType, OptMultiConfig from mmdet.registry import MODELS from .single_stage import SingleStageDetector @MODELS.register_module() class GFL(SingleStageDetector): + """Implementation of `GFL `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of GFL. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of GFL. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. 
+ """ def __init__(self, - backbone, - neck, - bbox_head, - train_cfg=None, - test_cfg=None, - pretrained=None, - init_cfg=None): - super(GFL, self).__init__(backbone, neck, bbox_head, train_cfg, - test_cfg, pretrained, init_cfg) + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/kd_one_stage.py b/mmdet/models/detectors/kd_one_stage.py index 2da5e889bc4..a1e47a69aa0 100644 --- a/mmdet/models/detectors/kd_one_stage.py +++ b/mmdet/models/detectors/kd_one_stage.py @@ -19,10 +19,23 @@ class KnowledgeDistillationSingleStageDetector(SingleStageDetector): `_. Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. teacher_config (:obj:`ConfigDict` | dict | str | Path): Config file path or the config object of teacher model. teacher_ckpt (str, optional): Checkpoint path of teacher model. If left as None, the model will not load any weights. + Defaults to True. + eval_teacher (bool): Set the train mode for teacher. + Defaults to True. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of ATSS. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of ATSS. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. """ def __init__( @@ -35,7 +48,7 @@ def __init__( eval_teacher: bool = True, train_cfg: OptConfigType = None, test_cfg: OptConfigType = None, - preprocess_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, ) -> None: super().__init__( backbone=backbone, @@ -43,7 +56,7 @@ def __init__( bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg, - preprocess_cfg=preprocess_cfg) + data_preprocessor=data_preprocessor) self.eval_teacher = eval_teacher # Build teacher model if isinstance(teacher_config, (str, Path)): @@ -53,8 +66,8 @@ def __init__( load_checkpoint( self.teacher_model, teacher_ckpt, map_location='cpu') - def forward_train(self, batch_inputs: Tensor, - batch_data_samples: SampleList, **kwargs) -> dict: + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: """ Args: batch_inputs (Tensor): Input images of shape (N, C, H, W). @@ -70,8 +83,7 @@ def forward_train(self, batch_inputs: Tensor, with torch.no_grad(): teacher_x = self.teacher_model.extract_feat(batch_inputs) out_teacher = self.teacher_model.bbox_head(teacher_x) - losses = self.bbox_head.forward_train(x, out_teacher, - batch_data_samples) + losses = self.bbox_head.loss(x, out_teacher, batch_data_samples) return losses def cuda(self, device: Optional[str] = None) -> nn.Module: diff --git a/tests/test_models/test_dense_heads/test_gfl_head.py b/tests/test_models/test_dense_heads/test_gfl_head.py new file mode 100644 index 00000000000..873023fb4af --- /dev/null +++ b/tests/test_models/test_dense_heads/test_gfl_head.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch +from mmengine import Config +from mmengine.data import InstanceData + +from mmdet import * # noqa +from mmdet.models.dense_heads import GFLHead + + +class TestGFLHead(TestCase): + + def test_gfl_head_loss(self): + """Tests gfl head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'pad_shape': (s, s, 3), + 'scale_factor': 1 + }] + train_cfg = Config( + dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False)) + gfl_head = GFLHead( + num_classes=4, + in_channels=1, + stacked_convs=1, + train_cfg=train_cfg, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)) + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + cls_scores, bbox_preds = gfl_head.forward(feat) + + # Test that empty ground truth encourages the network to predict + # background + gt_instances = InstanceData() + gt_instances.bboxes = torch.empty((0, 4)) + gt_instances.labels = torch.LongTensor([]) + + empty_gt_losses = gfl_head.loss_by_feat(cls_scores, bbox_preds, + [gt_instances], img_metas) + # When there is no truth, the cls loss should be nonzero but there + # should be no box loss. + empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_loss = sum(empty_gt_losses['loss_bbox']) + empty_dfl_loss = sum(empty_gt_losses['loss_dfl']) + self.assertGreater(empty_cls_loss.item(), 0, + 'cls loss should be non-zero') + self.assertEqual( + empty_box_loss.item(), 0, + 'there should be no box loss when there are no true boxes') + self.assertEqual( + empty_dfl_loss.item(), 0, + 'there should be no dfl loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero + # for random inputs + gt_instances = InstanceData() + gt_instances.bboxes = torch.Tensor( + [[23.6667, 23.8757, 238.6326, 151.8874]]) + gt_instances.labels = torch.LongTensor([2]) + one_gt_losses = gfl_head.loss_by_feat(cls_scores, bbox_preds, + [gt_instances], img_metas) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_loss = sum(one_gt_losses['loss_bbox']) + onegt_dfl_loss = sum(one_gt_losses['loss_dfl']) + self.assertGreater(onegt_cls_loss.item(), 0, + 'cls loss should be non-zero') + self.assertGreater(onegt_box_loss.item(), 0, + 'box loss should be non-zero') + self.assertGreater(onegt_dfl_loss.item(), 0, + 'dfl loss should be non-zero') diff --git a/tests/test_models/test_dense_heads/test_lad_head.py b/tests/test_models/test_dense_heads/test_lad_head.py index 018813b65da..3d314317a61 100644 --- a/tests/test_models/test_dense_heads/test_lad_head.py +++ b/tests/test_models/test_dense_heads/test_lad_head.py @@ -37,7 +37,8 @@ def score_samples(self, loss): s = 256 img_metas = [{ 'img_shape': (s, s, 3), - 'scale_factor': 1, + 'pad_shape': (s, s, 3), + 'scale_factor': 1 }] train_cfg = Config( dict( @@ -90,9 +91,9 @@ def score_samples(self, loss): batch_gt_instances_ignore) outs = teacher_model(feat) - empty_gt_losses = lad.loss(*outs, [gt_instances], img_metas, - batch_gt_instances_ignore, - label_assignment_results) + empty_gt_losses = lad.loss_by_feat(*outs, [gt_instances], img_metas, + batch_gt_instances_ignore, + label_assignment_results) # When there is no truth, the cls loss 
should be nonzero but there # should be no box loss. empty_cls_loss = empty_gt_losses['loss_cls'] @@ -118,9 +119,9 @@ def score_samples(self, loss): label_assignment_results = teacher_model.get_label_assignment( *outs_teacher, [gt_instances], img_metas, batch_gt_instances_ignore) - one_gt_losses = lad.loss(*outs, [gt_instances], img_metas, - batch_gt_instances_ignore, - label_assignment_results) + one_gt_losses = lad.loss_by_feat(*outs, [gt_instances], img_metas, + batch_gt_instances_ignore, + label_assignment_results) onegt_cls_loss = one_gt_losses['loss_cls'] onegt_box_loss = one_gt_losses['loss_bbox'] onegt_iou_loss = one_gt_losses['loss_iou'] @@ -163,5 +164,5 @@ def score_samples(self, loss): nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) rescale = False - lad.get_results( + lad.predict_by_feat( cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale) diff --git a/tests/test_models/test_dense_heads/test_ld_head.py b/tests/test_models/test_dense_heads/test_ld_head.py new file mode 100644 index 00000000000..bccaf8eb7b2 --- /dev/null +++ b/tests/test_models/test_dense_heads/test_ld_head.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine import Config +from mmengine.data import InstanceData + +from mmdet import * # noqa +from mmdet.models.dense_heads import GFLHead, LDHead + + +class TestLDHead(TestCase): + + def test_ld_head_loss(self): + """Tests ld head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'pad_shape': (s, s, 3), + 'scale_factor': 1 + }] + train_cfg = Config( + dict( + assigner=dict(type='ATSSAssigner', topk=9, ignore_iof_thr=0.1), + allowed_border=-1, + pos_weight=-1, + debug=False)) + + ld_head = LDHead( + num_classes=4, + in_channels=1, + train_cfg=train_cfg, + loss_ld=dict( + type='KnowledgeDistillationKLDivLoss', loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128])) + + teacher_model = GFLHead( + num_classes=4, + in_channels=1, + train_cfg=train_cfg, + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128])) + + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + cls_scores, bbox_preds = ld_head.forward(feat) + rand_soft_target = teacher_model.forward(feat)[1] + + # Test that empty ground truth encourages the network to predict + # background + + gt_instances = InstanceData() + gt_instances.bboxes = torch.empty((0, 4)) + gt_instances.labels = torch.LongTensor([]) + batch_gt_instances_ignore = None + + empty_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds, + [gt_instances], img_metas, + rand_soft_target, + batch_gt_instances_ignore) + + # When there is no truth, the cls loss should be nonzero, ld loss + # should be non-negative but there should be no box loss. 
diff --git a/tests/test_models/test_dense_heads/test_ld_head.py b/tests/test_models/test_dense_heads/test_ld_head.py
new file mode 100644
index 00000000000..bccaf8eb7b2
--- /dev/null
+++ b/tests/test_models/test_dense_heads/test_ld_head.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from mmengine import Config
+from mmengine.data import InstanceData
+
+from mmdet import *  # noqa
+from mmdet.models.dense_heads import GFLHead, LDHead
+
+
+class TestLDHead(TestCase):
+
+    def test_ld_head_loss(self):
+        """Tests ld head loss when truth is empty and non-empty."""
+        s = 256
+        img_metas = [{
+            'img_shape': (s, s, 3),
+            'pad_shape': (s, s, 3),
+            'scale_factor': 1
+        }]
+        train_cfg = Config(
+            dict(
+                assigner=dict(type='ATSSAssigner', topk=9, ignore_iof_thr=0.1),
+                allowed_border=-1,
+                pos_weight=-1,
+                debug=False))
+
+        ld_head = LDHead(
+            num_classes=4,
+            in_channels=1,
+            train_cfg=train_cfg,
+            loss_ld=dict(
+                type='KnowledgeDistillationKLDivLoss', loss_weight=1.0),
+            loss_cls=dict(
+                type='QualityFocalLoss',
+                use_sigmoid=True,
+                beta=2.0,
+                loss_weight=1.0),
+            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+            anchor_generator=dict(
+                type='AnchorGenerator',
+                ratios=[1.0],
+                octave_base_scale=8,
+                scales_per_octave=1,
+                strides=[8, 16, 32, 64, 128]))
+
+        teacher_model = GFLHead(
+            num_classes=4,
+            in_channels=1,
+            train_cfg=train_cfg,
+            loss_cls=dict(
+                type='QualityFocalLoss',
+                use_sigmoid=True,
+                beta=2.0,
+                loss_weight=1.0),
+            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+            anchor_generator=dict(
+                type='AnchorGenerator',
+                ratios=[1.0],
+                octave_base_scale=8,
+                scales_per_octave=1,
+                strides=[8, 16, 32, 64, 128]))
+
+        feat = [
+            torch.rand(1, 1, s // feat_size, s // feat_size)
+            for feat_size in [4, 8, 16, 32, 64]
+        ]
+        cls_scores, bbox_preds = ld_head.forward(feat)
+        rand_soft_target = teacher_model.forward(feat)[1]
+
+        # Test that empty ground truth encourages the network to predict
+        # background
+        gt_instances = InstanceData()
+        gt_instances.bboxes = torch.empty((0, 4))
+        gt_instances.labels = torch.LongTensor([])
+        batch_gt_instances_ignore = None
+
+        empty_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
+                                               [gt_instances], img_metas,
+                                               rand_soft_target,
+                                               batch_gt_instances_ignore)
+
+        # When there is no truth, the cls loss should be nonzero, the ld loss
+        # should be non-negative, but there should be no box loss.
+        empty_cls_loss = sum(empty_gt_losses['loss_cls'])
+        empty_box_loss = sum(empty_gt_losses['loss_bbox'])
+        empty_ld_loss = sum(empty_gt_losses['loss_ld'])
+        self.assertGreater(empty_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertEqual(
+            empty_box_loss.item(), 0,
+            'there should be no box loss when there are no true boxes')
+        self.assertGreaterEqual(empty_ld_loss.item(), 0,
+                                'ld loss should be non-negative')
+
+        # When truth is non-empty then both cls and box loss should be nonzero
+        # for random inputs
+        gt_instances = InstanceData()
+        gt_instances.bboxes = torch.Tensor(
+            [[23.6667, 23.8757, 238.6326, 151.8874]])
+        gt_instances.labels = torch.LongTensor([2])
+        batch_gt_instances_ignore = None
+
+        one_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
+                                             [gt_instances], img_metas,
+                                             rand_soft_target,
+                                             batch_gt_instances_ignore)
+        onegt_cls_loss = sum(one_gt_losses['loss_cls'])
+        onegt_box_loss = sum(one_gt_losses['loss_bbox'])
+
+        self.assertGreater(onegt_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertGreater(onegt_box_loss.item(), 0,
+                           'box loss should be non-zero')
+
+        batch_gt_instances_ignore = gt_instances
+
+        # When truth is non-empty but fully ignored then the cls loss should
+        # be nonzero, but there should be no box loss.
+        ignore_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
+                                                [gt_instances], img_metas,
+                                                rand_soft_target,
+                                                batch_gt_instances_ignore)
+        ignore_cls_loss = sum(ignore_gt_losses['loss_cls'])
+        ignore_box_loss = sum(ignore_gt_losses['loss_bbox'])
+
+        self.assertGreater(ignore_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertEqual(ignore_box_loss.item(), 0,
+                         'gt bbox ignored loss should be zero')
+
+        # When truth is non-empty and the ignore boxes do not cover it, the
+        # cls loss should be nonzero and the box loss should be non-negative
+        # for random inputs
+        batch_gt_instances_ignore = InstanceData()
+        batch_gt_instances_ignore.bboxes = torch.randn(1, 4)
+
+        not_ignore_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
+                                                    [gt_instances], img_metas,
+                                                    rand_soft_target,
+                                                    batch_gt_instances_ignore)
+        not_ignore_cls_loss = sum(not_ignore_gt_losses['loss_cls'])
+        not_ignore_box_loss = sum(not_ignore_gt_losses['loss_bbox'])
+
+        self.assertGreater(not_ignore_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertGreaterEqual(not_ignore_box_loss.item(), 0,
+                                'box loss should be non-negative')
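Note: `loss_ld` is fed by `rand_soft_target`, the teacher head's bbox-distribution logits, and the test only checks non-negativity because a KL divergence is zero when student and teacher distributions coincide. As a rough sketch of what `KnowledgeDistillationKLDivLoss` computes per location (simplified from mmdet's kd loss; the default temperature T=10 and the exact reduction are assumptions to verify against mmdet/models/losses, not guaranteed by this patch):

    import torch
    import torch.nn.functional as F

    def kd_kl_div(student_logits, teacher_logits, T=10.0):
        # KL(teacher || student) on temperature-softened distributions;
        # the T^2 factor keeps gradients comparable across temperatures.
        target = F.softmax(teacher_logits / T, dim=1).detach()
        return F.kl_div(
            F.log_softmax(student_logits / T, dim=1), target,
            reduction='none').mean(1) * (T * T)

    student = torch.randn(8, 17)  # e.g. 17 = reg_max + 1 distribution bins
    teacher = torch.randn(8, 17)
    per_location_loss = kd_kl_div(student, teacher)  # shape (8,)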
diff --git a/tests/test_models/test_detectors/test_kd_single_stage.py b/tests/test_models/test_detectors/test_kd_single_stage.py
index 5b78179f4df..27a2f2fd09d 100644
--- a/tests/test_models/test_detectors/test_kd_single_stage.py
+++ b/tests/test_models/test_detectors/test_kd_single_stage.py
@@ -8,24 +8,26 @@
 from mmdet import *  # noqa
 from mmdet.core import DetDataSample
 from mmdet.testing import demo_mm_inputs, get_detector_cfg
+from mmdet.utils import register_all_modules
 
 
 class TestKDSingleStageDetector(TestCase):
 
-    # TODO: waiting ``ld/ld_r18_gflv1_r101_fpn_coco_1x.py`` ready
-    @parameterized.expand(['lad/lad_r101_paa_r50_fpn_coco_1x.py'])
+    def setUp(self):
+        register_all_modules()
+
+    @parameterized.expand(['ld/ld_r18_gflv1_r101_fpn_coco_1x.py'])
     def test_init(self, cfg_file):
         model = get_detector_cfg(cfg_file)
         model.backbone.init_cfg = None
 
         from mmdet.models import build_detector
         detector = build_detector(model)
-        assert detector.backbone
-        assert detector.neck
-        assert detector.bbox_head
-        assert detector.device.type == 'cpu'
+        self.assertTrue(detector.backbone)
+        self.assertTrue(detector.neck)
+        self.assertTrue(detector.bbox_head)
 
-    @parameterized.expand([('lad/lad_r101_paa_r50_fpn_coco_1x.py', ('cpu',
+    @parameterized.expand([('ld/ld_r18_gflv1_r101_fpn_coco_1x.py', ('cpu',
                                                                     'cuda'))])
     def test_single_stage_forward_train(self, cfg_file, devices):
         model = get_detector_cfg(cfg_file)
@@ -42,21 +44,15 @@ def test_single_stage_forward_train(self, cfg_file, devices):
                     return unittest.skip('test requires GPU and torch+cuda')
                 detector = detector.cuda()
-                assert detector.device.type == device
-
             packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
+            batch_inputs, data_samples = detector.data_preprocessor(
+                packed_inputs, True)
 
             # Test forward train
-            losses = detector.forward(packed_inputs, return_loss=True)
-            assert isinstance(losses, dict)
+            losses = detector.forward(batch_inputs, data_samples, mode='loss')
+            self.assertIsInstance(losses, dict)
 
-            # Test forward_dummy
-            batch = torch.ones((1, 3, 64, 64)).to(device=device)
-            out = detector.forward_dummy(batch)
-            assert isinstance(out, tuple)
-            assert len(out) == 3
-
-    @parameterized.expand([('lad/lad_r101_paa_r50_fpn_coco_1x.py', ('cpu',
+    @parameterized.expand([('ld/ld_r18_gflv1_r101_fpn_coco_1x.py', ('cpu',
                                                                     'cuda'))])
     def test_single_stage_forward_test(self, cfg_file, devices):
         model = get_detector_cfg(cfg_file)
@@ -73,14 +69,14 @@ def test_single_stage_forward_test(self, cfg_file, devices):
                     return unittest.skip('test requires GPU and torch+cuda')
                 detector = detector.cuda()
-                assert detector.device.type == device
-
             packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
+            batch_inputs, data_samples = detector.data_preprocessor(
+                packed_inputs, False)
 
             # Test forward test
             detector.eval()
             with torch.no_grad():
                 batch_results = detector.forward(
-                    packed_inputs, return_loss=False)
-                assert len(batch_results) == 2
-                assert isinstance(batch_results[0], DetDataSample)
+                    batch_inputs, data_samples, mode='predict')
+                self.assertEqual(len(batch_results), 2)
+                self.assertIsInstance(batch_results[0], DetDataSample)
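Note: the detector tests above capture the 2.x to 3.x calling convention in one place: `return_loss` and `forward_dummy` are gone, raw batches from `demo_mm_inputs` must first pass through `data_preprocessor` (True for training-style preprocessing, False for test-style), and `forward` dispatches on `mode`. Side by side, exactly as the updated tests exercise it:

    # Training path: preprocess, then ask for losses.
    batch_inputs, data_samples = detector.data_preprocessor(
        packed_inputs, True)
    losses = detector.forward(batch_inputs, data_samples, mode='loss')

    # Inference path: eval mode, no grad, ask for predictions.
    detector.eval()
    with torch.no_grad():
        batch_inputs, data_samples = detector.data_preprocessor(
            packed_inputs, False)
        batch_results = detector.forward(
            batch_inputs, data_samples, mode='predict')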