diff --git a/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py index 8f8dfb8ef03..0b68f806bf3 100644 --- a/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py +++ b/configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py @@ -3,6 +3,10 @@ ] model = dict( type='DeformableDETR', + num_query=300, + num_feature_levels=4, + with_box_refine=False, + as_two_stage=False, data_preprocessor=dict( type='DetDataPreprocessor', mean=[123.675, 116.28, 103.53], @@ -27,50 +31,31 @@ act_cfg=None, norm_cfg=dict(type='GN', num_groups=32), num_outs=4), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))), + decoder=dict( # DeformableDetrTransformerDecoder + num_layers=6, + return_intermediate=True, + layer_cfg=dict( # DeformableDetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1), + cross_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)), + post_norm_cfg=None), + positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5), bbox_head=dict( type='DeformableDETRHead', - num_query=300, num_classes=80, - in_channels=2048, sync_cls_avg_factor=True, - as_two_stage=False, - transformer=dict( - type='DeformableDetrTransformer', - encoder=dict( - type='DetrTransformerEncoder', - num_layers=6, - transformerlayers=dict( - type='BaseTransformerLayer', - attn_cfgs=dict( - type='MultiScaleDeformableAttention', embed_dims=256), - feedforward_channels=1024, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'ffn', 'norm'))), - decoder=dict( - type='DeformableDetrTransformerDecoder', - num_layers=6, - return_intermediate=True, - transformerlayers=dict( - type='DetrTransformerDecoderLayer', - attn_cfgs=[ - dict( - type='MultiheadAttention', - embed_dims=256, - num_heads=8, - dropout=0.1), - dict( - type='MultiScaleDeformableAttention', - embed_dims=256) - ], - feedforward_channels=1024, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'cross_attn', 'norm', - 'ffn', 'norm')))), - positional_encoding=dict( - type='SinePositionalEncoding', - num_feats=128, - normalize=True, - offset=-0.5), loss_cls=dict( type='FocalLoss', use_sigmoid=True, diff --git a/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py index 8c31edb65cd..b968674f4a9 100644 --- a/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py +++ b/configs/deformable_detr/deformable-detr_refine_r50_16xb2-50e_coco.py @@ -1,2 +1,2 @@ _base_ = 'deformable-detr_r50_16xb2-50e_coco.py' -model = dict(bbox_head=dict(with_box_refine=True)) +model = dict(with_box_refine=True) diff --git a/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py b/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py index 466e8d5c0f5..8286189d4b9 100644 --- a/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py +++ b/configs/deformable_detr/deformable-detr_refine_twostage_r50_16xb2-50e_coco.py @@ -1,2 +1,2 @@ _base_ = 'deformable-detr_refine_r50_16xb2-50e_coco.py' -model = dict(bbox_head=dict(as_two_stage=True)) +model = dict(as_two_stage=True) diff --git a/configs/detr/detr_r18_8xb2-500e_coco.py b/configs/detr/detr_r18_8xb2-500e_coco.py index 8b5f108dc4b..305b9d6fee8 100644 --- a/configs/detr/detr_r18_8xb2-500e_coco.py +++ b/configs/detr/detr_r18_8xb2-500e_coco.py @@ -4,4 +4,4 @@ backbone=dict( depth=18, init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')), - bbox_head=dict(in_channels=512)) + neck=dict(in_channels=[512])) diff --git a/configs/detr/detr_r50_8xb2-150e_coco.py b/configs/detr/detr_r50_8xb2-150e_coco.py index 8c2ad57568a..ec010397789 100644 --- a/configs/detr/detr_r50_8xb2-150e_coco.py +++ b/configs/detr/detr_r50_8xb2-150e_coco.py @@ -3,6 +3,7 @@ ] model = dict( type='DETR', + num_query=100, data_preprocessor=dict( type='DetDataPreprocessor', mean=[123.675, 116.28, 103.53], @@ -19,45 +20,50 @@ norm_eval=True, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='ChannelMapper', + in_channels=[2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=None, + num_outs=1), + encoder=dict( # DetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type='ReLU', inplace=True)))), + decoder=dict( # DetrTransformerDecoder + num_layers=6, + layer_cfg=dict( # DetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type='ReLU', inplace=True))), + return_intermediate=True), + positional_encoding_cfg=dict(num_feats=128, normalize=True), bbox_head=dict( type='DETRHead', num_classes=80, - in_channels=2048, - transformer=dict( - type='Transformer', - encoder=dict( - type='DetrTransformerEncoder', - num_layers=6, - transformerlayers=dict( - type='BaseTransformerLayer', - attn_cfgs=[ - dict( - type='MultiheadAttention', - embed_dims=256, - num_heads=8, - dropout=0.1) - ], - feedforward_channels=2048, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'ffn', 'norm'))), - decoder=dict( - type='DetrTransformerDecoder', - return_intermediate=True, - num_layers=6, - transformerlayers=dict( - type='DetrTransformerDecoderLayer', - attn_cfgs=dict( - type='MultiheadAttention', - embed_dims=256, - num_heads=8, - dropout=0.1), - feedforward_channels=2048, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'cross_attn', 'norm', - 'ffn', 'norm')), - )), - positional_encoding=dict( - type='SinePositionalEncoding', num_feats=128, normalize=True), + embed_dims=256, loss_cls=dict( type='CrossEntropyLoss', bg_cls_weight=0.1, diff --git a/mmdet/models/dense_heads/deformable_detr_head.py b/mmdet/models/dense_heads/deformable_detr_head.py index 39b36a8fad5..1165124053f 100644 --- a/mmdet/models/dense_heads/deformable_detr_head.py +++ b/mmdet/models/dense_heads/deformable_detr_head.py @@ -4,22 +4,21 @@ import torch import torch.nn as nn -import torch.nn.functional as F from mmcv.cnn import Linear from mmengine.model import bias_init_with_prob, constant_init from torch import Tensor from mmdet.registry import MODELS -from mmdet.utils import InstanceList, OptConfigType, OptInstanceList +from mmdet.structures import SampleList +from mmdet.utils import InstanceList, OptInstanceList from ..layers import inverse_sigmoid -from ..utils import multi_apply from .detr_head import DETRHead @MODELS.register_module() class DeformableDETRHead(DETRHead): - """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to- - End Object Detection. + r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for + End-to-End Object Detection. Code is modified from the `official github repo `_. @@ -28,30 +27,28 @@ class DeformableDETRHead(DETRHead): `_ . Args: - with_box_refine (bool): Whether to refine the reference points - in the decoder. Defaults to False. - as_two_stage (bool) : Whether to generate the proposal from - the outputs of encoder. - transformer (obj:`ConfigDict`): ConfigDict is used for building - the Encoder and Decoder. + share_pred_layer (bool): Whether to share parameters for all the + prediction layers. Defaults to `False`. + num_pred_layer (int): The number of the prediction layers. + Defaults to 6. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. """ def __init__(self, *args, - with_box_refine: bool = False, + share_pred_layer: bool = False, + num_pred_layer: int = 6, as_two_stage: bool = False, - transformer: OptConfigType = None, **kwargs) -> None: - self.with_box_refine = with_box_refine + self.share_pred_layer = share_pred_layer + self.num_pred_layer = num_pred_layer self.as_two_stage = as_two_stage - if self.as_two_stage: - transformer['as_two_stage'] = self.as_two_stage - super().__init__(*args, transformer=transformer, **kwargs) + super().__init__(*args, **kwargs) def _init_layers(self) -> None: """Initialize classification branch and regression branch of head.""" - fc_cls = Linear(self.embed_dims, self.cls_out_channels) reg_branch = [] for _ in range(self.num_reg_fcs): @@ -60,31 +57,20 @@ def _init_layers(self) -> None: reg_branch.append(Linear(self.embed_dims, 4)) reg_branch = nn.Sequential(*reg_branch) - def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - # last reg_branch is used to generate proposal from - # encode feature map when as_two_stage is True. - num_pred = (self.transformer.decoder.num_layers + 1) if \ - self.as_two_stage else self.transformer.decoder.num_layers - - if self.with_box_refine: - self.cls_branches = _get_clones(fc_cls, num_pred) - self.reg_branches = _get_clones(reg_branch, num_pred) - else: - + if self.share_pred_layer: self.cls_branches = nn.ModuleList( - [fc_cls for _ in range(num_pred)]) + [fc_cls for _ in range(self.num_pred_layer)]) self.reg_branches = nn.ModuleList( - [reg_branch for _ in range(num_pred)]) - - if not self.as_two_stage: - self.query_embedding = nn.Embedding(self.num_query, - self.embed_dims * 2) + [reg_branch for _ in range(self.num_pred_layer)]) + else: + self.cls_branches = nn.ModuleList( + [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList([ + copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer) + ]) def init_weights(self) -> None: - """Initialize weights of the DeformDETR head.""" - self.transformer.init_weights() + """Initialize weights of the Deformable DETR head.""" if self.loss_cls.use_sigmoid: bias_init = bias_init_with_prob(0.01) for m in self.cls_branches: @@ -96,120 +82,135 @@ def init_weights(self) -> None: for m in self.reg_branches: nn.init.constant_(m[-1].bias.data[2:], 0.0) - def forward(self, x: Tuple[Tensor], - batch_img_metas: List[dict]) -> Tuple[Tensor, ...]: + def forward(self, hidden_states: Tensor, + references: List[Tensor]) -> Tuple[Tensor]: """Forward function. Args: - x (tuple[Tensor]): Features from the upstream network, each is - a 4D-tensor. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_query, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, num_query, + 4) when `as_two_stage` of the detector is `True`, otherwise + (bs, num_query, 2). Each `inter_reference` has shape + (bs, num_query, 4) when `with_box_refine` of the detector is + `True`, otherwise (bs, num_query, 2). Returns: - tuple[Tensor]: - - - all_cls_scores (Tensor): Outputs from the classification head, - shape [nb_dec, bs, num_query, cls_out_channels]. - - cls_out_channels should includes background. - - all_bbox_preds (Tensor): Sigmoid outputs from the regression - head with normalized coordinate format (cx, cy, w, h). - Shape [nb_dec, bs, num_query, 4]. - - enc_outputs_class (Tensor): The score of each point on encode - feature map, has shape (N, h*w, num_class). Only when - as_two_stage is True it would be returned, otherwise `None` - would be returned. - - enc_outputs_coord (Tensor): The proposal generate from the - encode feature map, has shape (N, h*w, 4). Only when - as_two_stage is True it would be returned, otherwise `None` - would be returned. + tuple[Tensor]: results of head containing the following tensor. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_query, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_query, 4). """ - - batch_size = x[0].size(0) - input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] - img_masks = x[0].new_ones((batch_size, input_img_h, input_img_w)) - for img_id in range(batch_size): - img_h, img_w = batch_img_metas[img_id]['img_shape'] - img_masks[img_id, :img_h, :img_w] = 0 - - mlvl_masks = [] - mlvl_positional_encodings = [] - for feat in x: - mlvl_masks.append( - F.interpolate(img_masks[None], - size=feat.shape[-2:]).to(torch.bool).squeeze(0)) - mlvl_positional_encodings.append( - self.positional_encoding(mlvl_masks[-1])) - - query_embeds = None - if not self.as_two_stage: - query_embeds = self.query_embedding.weight - hs, init_reference, inter_references, \ - enc_outputs_class, enc_outputs_coord = self.transformer( - x, - mlvl_masks, - query_embeds, - mlvl_positional_encodings, - reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 - cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501 - ) - hs = hs.permute(0, 2, 1, 3) - outputs_classes = [] - outputs_coords = [] - - for lvl in range(hs.shape[0]): - if lvl == 0: - reference = init_reference - else: - reference = inter_references[lvl - 1] - reference = inverse_sigmoid(reference) - outputs_class = self.cls_branches[lvl](hs[lvl]) - tmp = self.reg_branches[lvl](hs[lvl]) + # (num_decoder_layers, bs, num_query, dim) + hidden_states = hidden_states.permute(0, 2, 1, 3) + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + # NOTE The last reference will not be used. + hidden_state = hidden_states[layer_id] + outputs_class = self.cls_branches[layer_id](hidden_state) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) if reference.shape[-1] == 4: - tmp += reference + # When `layer` is 0 and `as_two_stage` of the detector + # is `True`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `True`. + tmp_reg_preds += reference else: + # When `layer` is 0 and `as_two_stage` of the detector + # is `False`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `False`. assert reference.shape[-1] == 2 - tmp[..., :2] += reference - outputs_coord = tmp.sigmoid() - outputs_classes.append(outputs_class) - outputs_coords.append(outputs_coord) + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) - outputs_classes = torch.stack(outputs_classes) - outputs_coords = torch.stack(outputs_coords) - if self.as_two_stage: - return outputs_classes, outputs_coords, \ - enc_outputs_class, \ - enc_outputs_coord.sigmoid() - else: - return outputs_classes, outputs_coords + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def loss(self, hidden_states: Tensor, references: List[Tensor], + enc_outputs_class: Tensor, enc_outputs_coord: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_query, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, num_query, + 4) when `as_two_stage` of the detector is `True`, otherwise + (bs, num_query, 2). Each `inter_reference` has shape + (bs, num_query, 4) when `with_box_refine` of the detector is + `True`, otherwise (bs, num_query, 2). + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat, cls_out_channels). + Only when `as_two_stage` is `True` it would be returned, + otherwise `None` would be returned. + enc_outputs_coord (Tensor): The proposal generate from the + encode feature map, has shape (bs, num_feat, 4). Only when + `as_two_stage` is `True` it would be returned, otherwise + `None` would be returned. + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, + batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses def loss_by_feat( self, - all_cls_scores: Tensor, - all_bbox_preds: Tensor, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, enc_cls_scores: Tensor, enc_bbox_preds: Tensor, batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None ) -> Dict[str, Tensor]: - """"Loss function. + """Loss function. Args: - all_cls_scores (Tensor): Classification score of all - decoder layers, has shape - [nb_dec, bs, num_query, cls_out_channels]. - all_bbox_preds (Tensor): Sigmoid regression - outputs of all decode layers. Each is a 4D-tensor with - normalized coordinate format (cx, cy, w, h) and shape - [nb_dec, bs, num_query, 4]. - enc_cls_scores (Tensor): Classification scores of - points on encode feature map , has shape - (N, h*w, num_classes). Only be passed when as_two_stage is - True, otherwise is None. - enc_bbox_preds (Tensor): Regression results of each points - on the encode feature map, has shape (N, h*w, 4). Only be - passed when as_two_stage is True, otherwise is None. + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, num_query, + cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decode + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_query, 4). + enc_cls_scores (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat, cls_out_channels). + Only when `as_two_stage` is `True` it would be returned, + otherwise `None` would be returned. + enc_bbox_preds (Tensor): The proposal generate from the + encode feature map, has shape (bs, num_feat, 4). Only when + `as_two_stage` is `True` it would be returned, otherwise + `None` would be returned. batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. @@ -223,87 +224,93 @@ def loss_by_feat( Returns: dict[str, Tensor]: A dictionary of loss components. """ - assert batch_gt_instances_ignore is None, \ - f'{self.__class__.__name__} only supports ' \ - f'for gt_bboxes_ignore setting to None.' + loss_dict = super().loss_by_feat(all_layers_cls_scores, + all_layers_bbox_preds, + batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) - num_dec_layers = len(all_cls_scores) - batch_gt_instances_list = [ - batch_gt_instances for _ in range(num_dec_layers) - ] - batch_img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] - - losses_cls, losses_bbox, losses_iou = multi_apply( - self.loss_by_feat_single, all_cls_scores, all_bbox_preds, - batch_gt_instances_list, batch_img_metas_list) - - loss_dict = dict() # loss of proposal generated from encode feature map. if enc_cls_scores is not None: - for i in range(len(batch_img_metas)): - batch_gt_instances[i].labels = torch.zeros_like( - batch_gt_instances[i].labels) + proposal_gt_instances = copy.deepcopy(batch_gt_instances) + for i in range(len(proposal_gt_instances)): + proposal_gt_instances[i].labels = torch.zeros_like( + proposal_gt_instances[i].labels) enc_loss_cls, enc_losses_bbox, enc_losses_iou = \ - self.loss_by_feat_single(enc_cls_scores, enc_bbox_preds, - batch_gt_instances, batch_img_metas) + self.loss_by_feat_single( + enc_cls_scores, enc_bbox_preds, + batch_gt_instances=proposal_gt_instances, + batch_img_metas=batch_img_metas) loss_dict['enc_loss_cls'] = enc_loss_cls loss_dict['enc_loss_bbox'] = enc_losses_bbox loss_dict['enc_loss_iou'] = enc_losses_iou - - # loss from the last decoder layer - loss_dict['loss_cls'] = losses_cls[-1] - loss_dict['loss_bbox'] = losses_bbox[-1] - loss_dict['loss_iou'] = losses_iou[-1] - # loss from other decoder layers - num_dec_layer = 0 - for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1], - losses_bbox[:-1], - losses_iou[:-1]): - loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i - loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i - loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i - num_dec_layer += 1 return loss_dict + def predict(self, + hidden_states: Tensor, + references: List[Tensor], + batch_data_samples: SampleList, + rescale: bool = True, + **kwargs) -> InstanceList: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_query, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, num_query, + 4) when `as_two_stage` of the detector is `True`, otherwise + (bs, num_query, 2). Each `inter_reference` has shape + (bs, num_query, 4) when `with_box_refine` of the detector is + `True`, otherwise (bs, num_query, 2). + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + outs = self(hidden_states, references) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + def predict_by_feat(self, - all_cls_scores: Tensor, - all_bbox_preds: Tensor, - enc_cls_scores: Tensor, - enc_bbox_preds: Tensor, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, batch_img_metas: List[Dict], rescale: bool = False) -> InstanceList: """Transform a batch of output features extracted from the head into bbox results. Args: - all_cls_scores (Tensor): Classification score of all - decoder layers, has shape - [nb_dec, bs, num_query, cls_out_channels]. - all_bbox_preds (Tensor): Sigmoid regression - outputs of all decode layers. Each is a 4D-tensor with - normalized coordinate format (cx, cy, w, h) and shape - [nb_dec, bs, num_query, 4]. - enc_cls_scores (Tensor): Classification scores of - points on encode feature map , has shape - (N, h*w, num_classes). Only be passed when as_two_stage is - True, otherwise is None. - enc_bbox_preds (Tensor): Regression results of each points - on the encode feature map, has shape (N, h*w, 4). Only be - passed when as_two_stage is True, otherwise is None. + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, num_query, + cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decode + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and shape (num_decoder_layers, bs, + num_query, 4). batch_img_metas (list[dict]): Meta information of each image. - rescale (bool, optional): If True, return boxes in original - image space. Default False. + rescale (bool, optional): If `True`, return boxes in original + image space. Default `False`. Returns: - list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \ - The first item is an (n, 5) tensor, where the first 4 columns \ - are bounding box positions (tl_x, tl_y, br_x, br_y) and the \ - 5-th column is a score between 0 and 1. The second item is a \ - (n,) tensor where each item is the predicted class label of \ - the corresponding box. + list[obj:`InstanceData`]: Detection results of each image + after the post process. """ - cls_scores = all_cls_scores[-1] - bbox_preds = all_bbox_preds[-1] + cls_scores = all_layers_cls_scores[-1] + bbox_preds = all_layers_bbox_preds[-1] result_list = [] for img_id in range(len(batch_img_metas)): diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py index acf82a4c7bd..aafc3a3dbf0 100644 --- a/mmdet/models/dense_heads/detr_head.py +++ b/mmdet/models/dense_heads/detr_head.py @@ -1,55 +1,49 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import Conv2d, Linear, build_activation_layer -from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding +from mmcv.cnn import Linear +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule from mmengine.structures import InstanceData from torch import Tensor from mmdet.registry import MODELS, TASK_UTILS from mmdet.structures import SampleList from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh -from mmdet.utils import (ConfigType, InstanceList, OptConfigType, - OptInstanceList, OptMultiConfig, reduce_mean) +from mmdet.utils import (ConfigType, InstanceList, OptInstanceList, + OptMultiConfig, reduce_mean) from ..utils import multi_apply -from .anchor_free_head import AnchorFreeHead @MODELS.register_module() -class DETRHead(AnchorFreeHead): - """Implements the DETR transformer head. +class DETRHead(BaseModule): + r"""Head of DETR. DETR:End-to-End Object Detection with Transformers. - See `paper: End-to-End Object Detection with Transformers - `_ for details. + More details can be found in the `paper + `_ . Args: num_classes (int): Number of categories excluding the background. - in_channels (int): Number of channels in the input feature map. - num_query (int): Number of query in Transformer. Defaults to 100. - num_reg_fcs (int): Number of fully-connected layers used in - `FFN`, which is then used for the regression head. - Defaults to 2. - transformer (:obj:`ConfigDict` or dict, optional): Config for - transformer. Defaults to None. - sync_cls_avg_factor (bool): Whether to sync the avg_factor of all - ranks. Defaults to False. - positional_encoding (:obj:`ConfigDict` or dict): Config for position - encoding. + embed_dims (int): The dims of Transformer embedding. + num_reg_fcs (int): Number of fully-connected layers used in `FFN`, + which is then used for the regression head. Defaults to 2. + sync_cls_avg_factor (bool): Whether to sync the `avg_factor` of + all ranks. Default to `False`. loss_cls (:obj:`ConfigDict` or dict): Config of the classification loss. Defaults to `CrossEntropyLoss`. - loss_bbox (:obj:`ConfigDict` or dict): Config of the regression loss. - Defaults to `L1Loss`. + loss_bbox (:obj:`ConfigDict` or dict): Config of the regression bbox + loss. Defaults to `L1Loss`. loss_iou (:obj:`ConfigDict` or dict): Config of the regression iou loss. Defaults to `GIoULoss`. - tran_cfg (:obj:`ConfigDict` or dict): Training config of transformer + train_cfg (:obj:`ConfigDict` or dict): Training config of transformer head. test_cfg (:obj:`ConfigDict` or dict): Testing config of transformer head. - init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ - dict], optional): Initialization config dict. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. """ _version = 2 @@ -57,13 +51,9 @@ class DETRHead(AnchorFreeHead): def __init__( self, num_classes: int, - in_channels: int, - num_query: int = 100, + embed_dims: int = 256, num_reg_fcs: int = 2, - transformer: OptConfigType = None, sync_cls_avg_factor: bool = False, - positional_encoding: ConfigType = dict( - type='SinePositionalEncoding', num_feats=128, normalize=True), loss_cls: ConfigType = dict( type='CrossEntropyLoss', bg_cls_weight=0.1, @@ -82,10 +72,7 @@ def __init__( ])), test_cfg: ConfigType = dict(max_per_img=100), init_cfg: OptMultiConfig = None) -> None: - # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, - # since it brings inconvenience when the initialization of - # `AnchorFreeHead` is called. - super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) self.bg_cls_weight = 0 self.sync_cls_avg_factor = sync_cls_avg_factor class_weight = loss_cls.get('class_weight', None) @@ -108,15 +95,14 @@ def __init__( self.bg_cls_weight = bg_cls_weight if train_cfg: - assert 'assigner' in train_cfg, 'assigner should be provided '\ - 'when train_cfg is set.' + assert 'assigner' in train_cfg, 'assigner should be provided ' \ + 'when train_cfg is set.' assigner = train_cfg['assigner'] self.assigner = TASK_UTILS.build(assigner) if train_cfg.get('sampler', None) is not None: raise RuntimeError('DETR do not build sampler.') - self.num_query = num_query self.num_classes = num_classes - self.in_channels = in_channels + self.embed_dims = embed_dims self.num_reg_fcs = num_reg_fcs self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -128,153 +114,84 @@ def __init__( self.cls_out_channels = num_classes else: self.cls_out_channels = num_classes + 1 - self.act_cfg = transformer.get('act_cfg', - dict(type='ReLU', inplace=True)) - self.activate = build_activation_layer(self.act_cfg) - self.positional_encoding = build_positional_encoding( - positional_encoding) - self.transformer = MODELS.build(transformer) - self.embed_dims = self.transformer.embed_dims - assert 'num_feats' in positional_encoding - num_feats = positional_encoding['num_feats'] - assert num_feats * 2 == self.embed_dims, 'embed_dims should' \ - f' be exactly 2 times of num_feats. Found {self.embed_dims}' \ - f' and {num_feats}.' + self._init_layers() def _init_layers(self) -> None: """Initialize layers of the transformer head.""" - self.input_proj = Conv2d( - self.in_channels, self.embed_dims, kernel_size=1) + # cls branch self.fc_cls = Linear(self.embed_dims, self.cls_out_channels) + # reg branch + self.activate = nn.ReLU() self.reg_ffn = FFN( self.embed_dims, self.embed_dims, self.num_reg_fcs, - self.act_cfg, + dict(type='ReLU', inplace=True), dropout=0.0, add_residual=False) + # NOTE the activations of reg_branch here is the same as + # those in transformer, but they are actually different + # in DAB DETR (prelu in transformer and relu in reg_branch) self.fc_reg = Linear(self.embed_dims, 4) - self.query_embedding = nn.Embedding(self.num_query, self.embed_dims) - - def init_weights(self) -> None: - """Initialize weights of the transformer head.""" - # The initialization for transformer is important - self.transformer.init_weights() - - def _load_from_state_dict(self, state_dict: dict, prefix: str, - local_metadata: dict, strict: bool, - missing_keys: Union[List[str], str], - unexpected_keys: Union[List[str], str], - error_msgs: Union[List[str], str]) -> None: - """load checkpoints.""" - # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, - # since `AnchorFreeHead._load_from_state_dict` should not be - # called here. Invoking the default `Module._load_from_state_dict` - # is enough. - - # Names of some parameters in has been changed. - version = local_metadata.get('version', None) - if (version is None or version < 2) and self.__class__ is DETRHead: - convert_dict = { - '.self_attn.': '.attentions.0.', - '.ffn.': '.ffns.0.', - '.multihead_attn.': '.attentions.1.', - '.decoder.norm.': '.decoder.post_norm.' - } - state_dict_keys = list(state_dict.keys()) - for k in state_dict_keys: - for ori_key, convert_key in convert_dict.items(): - if ori_key in k: - convert_key = k.replace(ori_key, convert_key) - state_dict[convert_key] = state_dict[k] - del state_dict[k] - - super(AnchorFreeHead, self)._load_from_state_dict( - state_dict=state_dict, - prefix=prefix, - local_metadata=local_metadata, - strict=strict, - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - error_msgs=error_msgs) - - def forward( - self, x: Tuple[Tensor], - batch_img_metas: List[dict]) -> Tuple[List[Tensor], List[Tensor]]: - """Forward function. - Args: - x (tuple[Tensor]): Features from the upstream network, each is - a 4D-tensor. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. + # Note function _load_from_state_dict is deleted without + # supporting refactor-DETR in mmdetection2.0 + def forward(self, hidden_states: Tensor) -> Tuple[Tensor]: + """"Forward function. + + Args: + hidden_states (Tensor): Features from transformer decoder. If + `return_intermediate_dec` in detr.py is True output has shape + (num_hidden_states, bs, num_query, dim), else has shape (1, + bs, num_query, dim) which only contains the last layer outputs. Returns: - tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels. - - - all_cls_scores_list (list[Tensor]): Classification scores \ - for each scale level. Each is a 4D-tensor with shape \ - [nb_dec, bs, num_query, cls_out_channels]. Note \ - `cls_out_channels` should includes background. - - all_bbox_preds_list (list[Tensor]): Sigmoid regression \ - outputs for each scale level. Each is a 4D-tensor with \ - normalized coordinate format (cx, cy, w, h) and shape \ - [nb_dec, bs, num_query, 4]. + tuple[Tensor]: results of head containing the following tensor. + + - layers_cls_scores (Tensor): Outputs from the classification head, + shape (num_hidden_states, bs, num_query, cls_out_channels). Note + cls_out_channels should include background. + - layers_bbox_preds (Tensor): Sigmoid outputs from the regression + head with normalized coordinate format (cx, cy, w, h), has shape + (num_hidden_states, bs, num_query, 4). """ - num_levels = len(x) - batch_img_metas_list = [batch_img_metas for _ in range(num_levels)] - return multi_apply(self.forward_single, x, batch_img_metas_list) + layers_cls_scores = self.fc_cls(hidden_states) + layers_bbox_preds = self.fc_reg( + self.activate(self.reg_ffn(hidden_states))).sigmoid() + return layers_cls_scores, layers_bbox_preds - def forward_single(self, x: Tensor, - batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]: - """"Forward function for a single feature level. + def loss(self, hidden_states: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. Args: - x (Tensor): Input feature from backbone's single stage, shape - [bs, c, h, w]. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. + hidden_states (Tensor): Feature from the transformer decoder, has + shape (num_decoder_layers, bs, num_query, cls_out_channels) or + (num_decoder_layers, num_query, bs, cls_out_channels). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. Returns: - tuple[Tensor]: - - - all_cls_scores (Tensor): Outputs from the classification head, \ - shape [nb_dec, bs, num_query, cls_out_channels]. Note \ - cls_out_channels should includes background. - - all_bbox_preds (Tensor): Sigmoid outputs from the regression \ - head with normalized coordinate format (cx, cy, w, h). \ - Shape [nb_dec, bs, num_query, 4]. + dict: A dictionary of loss components. """ - # construct binary masks which used for the transformer. - # NOTE following the official DETR repo, non-zero values representing - # ignored positions, while zero values means valid positions. - batch_size = x.size(0) - input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] - masks = x.new_ones((batch_size, input_img_h, input_img_w)) - for img_id in range(batch_size): - img_h, img_w, = batch_img_metas[img_id]['img_shape'] - masks[img_id, :img_h, :img_w] = 0 - - x = self.input_proj(x) - # interpolate masks to have the same spatial shape with x - masks = F.interpolate( - masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) - # position encoding - pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w] - # outs_dec: [nb_dec, bs, num_query, embed_dim] - outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight, - pos_embed) - - all_cls_scores = self.fc_cls(outs_dec) - all_bbox_preds = self.fc_reg(self.activate( - self.reg_ffn(outs_dec))).sigmoid() - return all_cls_scores, all_bbox_preds + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses def loss_by_feat( self, - all_cls_scores_list: List[Tensor], - all_bbox_preds_list: List[Tensor], + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None @@ -285,13 +202,13 @@ def loss_by_feat( losses by default. Args: - all_cls_scores_list (list[Tensor]): Classification outputs - for each feature level. Each is a 4D-tensor with shape - [nb_dec, bs, num_query, cls_out_channels]. - all_bbox_preds_list (list[Tensor]): Sigmoid regression - outputs for each feature level. Each is a 4D-tensor with + all_layers_cls_scores (Tensor): Classification outputs + of each decoder layers. Each is a 4D-tensor, has shape + (num_decoder_layers, bs, num_query, cls_out_channels). + all_layers_bbox_preds (Tensor): Sigmoid regression + outputs of each decoder layers. Each is a 4D-tensor with normalized coordinate format (cx, cy, w, h) and shape - [nb_dec, bs, num_query, 4]. + (num_decoder_layers, bs, num_query, 4). batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. @@ -305,21 +222,16 @@ def loss_by_feat( Returns: dict[str, Tensor]: A dictionary of loss components. """ - # NOTE defaultly only the outputs from the last feature scale is used. - all_cls_scores = all_cls_scores_list[-1] - all_bbox_preds = all_bbox_preds_list[-1] assert batch_gt_instances_ignore is None, \ - 'Only supports for batch_gt_instances_ignore setting to None.' - - num_dec_layers = len(all_cls_scores) - batch_gt_instances_list = [ - batch_gt_instances for _ in range(num_dec_layers) - ] - batch_img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] + f'{self.__class__.__name__} only supports ' \ + 'for batch_gt_instances_ignore setting to None.' losses_cls, losses_bbox, losses_iou = multi_apply( - self.loss_by_feat_single, all_cls_scores, all_bbox_preds, - batch_gt_instances_list, batch_img_metas_list) + self.loss_by_feat_single, + all_layers_cls_scores, + all_layers_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas) loss_dict = dict() # loss from the last decoder layer @@ -328,9 +240,8 @@ def loss_by_feat( loss_dict['loss_iou'] = losses_iou[-1] # loss from other decoder layers num_dec_layer = 0 - for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1], - losses_bbox[:-1], - losses_iou[:-1]): + for loss_cls_i, loss_bbox_i, loss_iou_i in \ + zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]): loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i @@ -345,10 +256,10 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, Args: cls_scores (Tensor): Box score logits from a single decoder layer - for all images. Shape [bs, num_query, cls_out_channels]. + for all images, has shape (bs, num_query, cls_out_channels). bbox_preds (Tensor): Sigmoid outputs from a single decoder layer for all images, with normalized coordinate (cx, cy, w, h) and - shape [bs, num_query, 4]. + shape (bs, num_query, 4). batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. @@ -356,8 +267,8 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, image size, scaling factor, etc. Returns: - Tupe[Tensor]: A tuple includes loss_cls, loss_box and - loss_iou. + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. """ num_imgs = cls_scores.size(0) cls_scores_list = [cls_scores[i] for i in range(num_imgs)] @@ -425,7 +336,7 @@ def get_targets(self, cls_scores_list: List[Tensor], Args: cls_scores_list (list[Tensor]): Box score logits from a single - decoder layer for each image with shape [num_query, + decoder layer for each image, has shape [num_query, cls_out_channels]. bbox_preds_list (list[Tensor]): Sigmoid outputs from a single decoder layer for each image, with normalized coordinate @@ -529,49 +440,18 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds) - # over-write because img_metas are needed as inputs for bbox_head. - def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict: - """Perform forward propagation and loss calculation of the detection - head on the features of the upstream network. - - Args: - x (tuple[Tensor]): Features from the upstream network, each is - a 4D-tensor. - batch_data_samples (List[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - - Returns: - dict: A dictionary of loss components. - """ - batch_gt_instances = [] - batch_img_metas = [] - for data_sample in batch_data_samples: - batch_img_metas.append(data_sample.metainfo) - batch_gt_instances.append(data_sample.gt_instances) - - outs = self(x, batch_img_metas) - loss_inputs = outs + (batch_gt_instances, batch_img_metas) - losses = self.loss_by_feat(*loss_inputs) - return losses - - def loss_and_predict(self, - x: Tuple[Tensor], - batch_data_samples: SampleList, - proposal_cfg: Optional[ConfigType] = None) \ - -> Tuple[dict, InstanceList]: + def loss_and_predict( + self, hidden_states: Tuple[Tensor], + batch_data_samples: SampleList) -> Tuple[dict, InstanceList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. Over-write because img_metas are needed as inputs for bbox_head. Args: - x (tuple[Tensor]): Features from FPN. + hidden_states (tuple[Tensor]): Features from FPN. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains the meta information of each image and corresponding annotations. - proposal_cfg (ConfigDict, optional): Test / postprocessing - configuration, if None, test_cfg would be used. - Defaults to None. Returns: tuple: the return value is a tuple contains: @@ -586,7 +466,7 @@ def loss_and_predict(self, batch_img_metas.append(data_sample.metainfo) batch_gt_instances.append(data_sample.gt_instances) - outs = self(x, batch_img_metas) + outs = self(hidden_states) loss_inputs = outs + (batch_gt_instances, batch_img_metas) losses = self.loss_by_feat(*loss_inputs) @@ -595,7 +475,7 @@ def loss_and_predict(self, return losses, predictions def predict(self, - x: Tuple[Tensor], + hidden_states: Tuple[Tensor], batch_data_samples: SampleList, rescale: bool = True) -> InstanceList: """Perform forward propagation of the detection head and predict @@ -603,7 +483,7 @@ def predict(self, because img_metas are needed as inputs for bbox_head. Args: - x (tuple[Tensor]): Multi-level features from the + hidden_states (tuple[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. batch_data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as @@ -619,30 +499,32 @@ def predict(self, data_samples.metainfo for data_samples in batch_data_samples ] - outs = self(x, batch_img_metas) + last_layer_hidden_state = hidden_states[-1].unsqueeze(0) + outs = self(last_layer_hidden_state) predictions = self.predict_by_feat( *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions def predict_by_feat(self, - all_cls_scores_list: List[Tensor], - all_bbox_preds_list: List[Tensor], + layer_cls_scores: Tensor, + layer_bbox_preds: Tensor, batch_img_metas: List[dict], rescale: bool = True) -> InstanceList: """Transform network outputs for a batch into bbox predictions. Args: - all_cls_scores_list (list[Tensor]): Classification outputs - for each feature level. Each is a 4D-tensor with shape - [nb_dec, bs, num_query, cls_out_channels]. - all_bbox_preds_list (list[Tensor]): Sigmoid regression - outputs for each feature level. Each is a 4D-tensor with - normalized coordinate format (cx, cy, w, h) and shape - [nb_dec, bs, num_query, 4]. + layer_cls_scores (Tensor): Classification outputs of the last or + all decoder layer. Each is a 4D-tensor, has shape + (num_decoder_layers, bs, num_query, cls_out_channels). + layer_bbox_preds (Tensor): Sigmoid regression outputs of the last + or all decoder layer. Each is a 4D-tensor with normalized + coordinate format (cx, cy, w, h) and shape + (num_decoder_layers, bs, num_query, 4). batch_img_metas (list[dict]): Meta information of each image. - rescale (bool, optional): If True, return boxes in original - image space. Defaults to True. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. Returns: list[:obj:`InstanceData`]: Object detection results of each image @@ -655,10 +537,10 @@ def predict_by_feat(self, - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - # NOTE defaultly only using outputs from the last feature level, + # NOTE only using outputs from the last feature level, # and only the outputs from the last decoder layer is used. - cls_scores = all_cls_scores_list[-1][-1] - bbox_preds = all_bbox_preds_list[-1][-1] + cls_scores = layer_cls_scores[-1] + bbox_preds = layer_bbox_preds[-1] result_list = [] for img_id in range(len(batch_img_metas)): @@ -700,8 +582,8 @@ def _predict_by_feat_single(self, - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - assert len(cls_score) == len(bbox_pred) - max_per_img = self.test_cfg.get('max_per_img', self.num_query) + assert len(cls_score) == len(bbox_pred) # num_query + max_per_img = self.test_cfg.get('max_per_img', len(cls_score)) img_shape = img_meta['img_shape'] # exclude background if self.loss_cls.use_sigmoid: diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 6df092a025a..df43ac7914c 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -3,6 +3,7 @@ from .autoassign import AutoAssign from .base import BaseDetector from .boxinst import BoxInst +from .base_detr import DetectionTransformer from .cascade_rcnn import CascadeRCNN from .centernet import CenterNet from .condinst import CondInst @@ -62,5 +63,6 @@ 'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX', 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD', 'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher', - 'RTMDet', 'Detectron2Wrapper', 'RTMDet', 'CrowdDet', 'CondInst', 'BoxInst' + 'RTMDet', 'Detectron2Wrapper', 'CrowdDet', 'CondInst', 'BoxInst', + 'DetectionTransformer' ] diff --git a/mmdet/models/detectors/base_detr.py b/mmdet/models/detectors/base_detr.py new file mode 100644 index 00000000000..7a2a0227417 --- /dev/null +++ b/mmdet/models/detectors/base_detr.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Tuple, Union + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .base import BaseDetector + + +@MODELS.register_module() +class DetectionTransformer(BaseDetector, metaclass=ABCMeta): + r"""Base class for Detection Transformer. + + Detection Transformer uses an encoder to process output features of neck, + then several queries interactive with the output features of encoder and + do the regression and classification with bounding box head. + + Args: + backbone (:obj:`ConfigDict` or dict): Config of the backbone. + neck (:obj:`ConfigDict` or dict, optional): Config of the neck. + Defaults to None. + encoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer encoder. Defaults to None. + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. + positional_encoding_cfg (:obj:`ConfigDict` or dict, optional): Config + of the positional encoding module. Defaults to None. + bbox_head (:obj:`ConfigDict` or dict, optional): Config for the + bounding box head module. Defaults to None. + num_query (int, optional): Number of decoder query in Transformer. + Defaults to 100. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + the bounding box head module. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + the bounding box head module. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + positional_encoding_cfg: OptConfigType = None, + bbox_head: OptConfigType = None, + num_query: int = 100, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + # process args + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.encoder = encoder + self.decoder = decoder + self.positional_encoding_cfg = positional_encoding_cfg + self.num_query = num_query + + # init model layers + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self.bbox_head = MODELS.build(bbox_head) + self._init_layers() + + @abstractmethod + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + pass + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (bs, dim, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the input images. + Each DetDataSample usually contain 'pred_instances'. And the + `pred_instances` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[Tensor]: A tuple of features from ``bbox_head`` forward. + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + results = self.bbox_head.forward(**head_inputs_dict) + return results + + def forward_transformer(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward process of Transformer, which includes four steps: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We + summarized the parameters flow of the existing DETR-like detector, + which can be illustrated as follow: + + .. code:: text + + img_feats & batch_data_samples + | + V + +-----------------+ + | pre_transformer | + +-----------------+ + | | + | V + | +-----------------+ + | | forward_encoder | + | +-----------------+ + | | + | V + | +---------------+ + | | pre_decoder | + | +---------------+ + | | | + V V | + +-----------------+ | + | forward_decoder | | + +-----------------+ | + | | + V V + head_inputs_dict + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each + feature map has shape (bs, dim, H, W). + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of bbox_head function inputs, which always + includes the `hidden_states` of the decoder output and may contain + `references` including the initial and intermediate references. + """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) + + tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W). + + Returns: + tuple[Tensor]: Tuple of feature maps from neck. Each feature map + has shape (bs, dim, H, W). + """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + @abstractmethod + def pre_transformer( + self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]: + """Process image features before feeding them to the transformer. + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each + feature map has shape (bs, dim, H, W). + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict, dict]: The first dict contains the inputs of encoder + and the second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + 'feat_pos', and other algorithm-specific arguments. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask', and + other algorithm-specific arguments. + """ + pass + + @abstractmethod + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, **kwargs) -> Dict: + """Forward with Transformer encoder. + + Args: + feat (Tensor): Sequential features, has shape (num_feat, bs, dim). + feat_mask (Tensor): ByteTensor, the padding mask of the features, + has shape (num_feat, bs). + feat_pos (Tensor): The positional embeddings of the features, has + shape (num_feat, bs, dim). + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output and other algorithm-specific + arguments. + """ + pass + + @abstractmethod + def pre_decoder(self, memory: Tensor, **kwargs) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`, and `reference_points`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + + Returns: + tuple[dict, dict]: The first dict contains the inputs of decoder + and the second dict contains the inputs of the bbox_head function. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory', and other algorithm-specific arguments. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which is usually empty, or includes + `enc_outputs_class` and `enc_outputs_class` when the detector + support 'two stage' or 'query selection' strategies. + """ + pass + + @abstractmethod + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + **kwargs) -> Dict: + """Forward with Transformer decoder. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (num_query, bs, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (num_query, bs, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output, `references` including + the initial and intermediate reference_points, and other + algorithm-specific arguments. + """ + pass diff --git a/mmdet/models/detectors/deformable_detr.py b/mmdet/models/detectors/deformable_detr.py index 7fbbbb86ad7..44e6cf5e758 100644 --- a/mmdet/models/detectors/deformable_detr.py +++ b/mmdet/models/detectors/deformable_detr.py @@ -1,12 +1,541 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention +from mmengine.model import xavier_init +from torch import Tensor, nn +from torch.nn.init import normal_ + from mmdet.registry import MODELS -from .detr import DETR +from mmdet.structures import OptSampleList +from mmdet.utils import OptConfigType +from ..layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerEncoder, SinePositionalEncoding) +from .base_detr import DetectionTransformer @MODELS.register_module() -class DeformableDETR(DETR): +class DeformableDETR(DetectionTransformer): r"""Implementation of `Deformable DETR: Deformable Transformers for - End-to-End Object Detection `_""" + End-to-End Object Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. + bbox_head (:obj:`ConfigDict` or dict, optional): Config for the + bounding box head module. Defaults to None. + with_box_refine (bool, optional): Whether to refine the references + in the decoder. Defaults to `False`. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. + num_feature_levels (int, optional): Number of feature levels. + Defaults to 4. + """ + + def __init__(self, + *args, + decoder: OptConfigType = None, + bbox_head: OptConfigType = None, + with_box_refine: bool = False, + as_two_stage: bool = False, + num_feature_levels: int = 4, + **kwargs) -> None: + self.with_box_refine = with_box_refine + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + + if bbox_head is not None: + assert 'share_pred_layer' not in bbox_head and \ + 'num_pred_layer' not in bbox_head and \ + 'as_two_stage' not in bbox_head, \ + 'The two keyword args `share_pred_layer`, `num_pred_layer`, ' \ + 'and `as_two_stage are set in `detector.__init__()`, users ' \ + 'should not set them in `bbox_head` config.' + # The last prediction layer is used to generate proposal + # from encode feature map when `as_two_stage` is `True`. + # And all the prediction layers should share parameters + # when `with_box_refine` is `True`. + bbox_head['share_pred_layer'] = not with_box_refine + bbox_head['num_pred_layer'] = (decoder['num_layers'] + 1) \ + if self.as_two_stage else decoder['num_layers'] + bbox_head['as_two_stage'] = as_two_stage + + super().__init__(*args, decoder=decoder, bbox_head=bbox_head, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding_cfg) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder) + self.decoder = DeformableDetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + if not self.as_two_stage: + self.query_embedding = nn.Embedding(self.num_query, + self.embed_dims * 2) + # NOTE The query_embedding will be split into query and query_pos + # in self.pre_decoder, hence, the embed_dims are doubled. + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + 'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans_fc = nn.Linear(self.embed_dims * 2, + self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points_fc = nn.Linear(self.embed_dims, 2) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if self.as_two_stage: + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + nn.init.xavier_uniform_(self.pos_trans_fc.weight) + else: + xavier_init( + self.reference_points_fc, distribution='uniform', bias=0.) + normal_(self.level_embed) + + def pre_transformer( + self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict]: + """Process image features before feeding them to the transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + mlvl_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + and 'feat_pos'. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask'. + """ + batch_size = mlvl_feats[0].size(0) + + # construct binary masks for the transformer. + assert batch_data_samples is not None + batch_input_shape = batch_data_samples[0].batch_input_shape + img_shape_list = [sample.img_shape for sample in batch_data_samples] + input_img_h, input_img_w = batch_input_shape + masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_shape_list[img_id] + masks[img_id, :img_h, :img_w] = 0 + # NOTE following the official DETR repo, non-zero values representing + # ignored positions, while zero values means valid positions. + + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + batch_size, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) # (bs, h_lvl*w_lvl, dim) + pos_embed = pos_embed.flatten(2).transpose(1, 2) # as above + mask = mask.flatten(1) # (bs, h_lvl*w_lvl) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + + # (bs, num_feat), where num_feat = sum_lvl(h_lvl*w_lvl) + mask_flatten = torch.cat(mask_flatten, 1) + # (bs, num_feat, dim) + feat_flatten = torch.cat(feat_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + # (num_feat, bs, dim) + feat_flatten = feat_flatten.permute(1, 0, 2) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) + + spatial_shapes = torch.as_tensor( # (num_level, 2) + spatial_shapes, + dtype=torch.long, + device=feat_flatten.device) + level_start_index = torch.cat(( + spatial_shapes.new_zeros( # (num_level) + (1, )), + spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( # (bs, num_level, 2) + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + encoder_inputs_dict = dict( + feat=feat_flatten, + feat_mask=mask_flatten, + feat_pos=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + decoder_inputs_dict = dict( + memory_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + return encoder_inputs_dict, decoder_inputs_dict + + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor) -> Dict: + """Forward with Transformer encoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + feat (Tensor): Sequential features, has shape (num_feat, bs, dim). + feat_mask (Tensor): ByteTensor, the padding mask of the features, + has shape (num_feat, bs). + feat_pos (Tensor): The positional embeddings of the features, has + shape (num_feat, bs, dim). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. + """ + memory = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) # (num_feat, bs, dim) + memory = memory.permute(1, 0, 2) # (bs, num_feat, dim) + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes) + return encoder_outputs_dict + + def pre_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`, and `reference_points`. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat). It will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + It will only be used when `as_two_stage` is `True`. + + Returns: + tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory', and `reference_points`. The reference_points of + decoder input here are 4D boxes when `as_two_stage` is `True`, + otherwise 2D points, although it has `points` in its name. + The reference_points in encoder is always 2D points. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which includes `enc_outputs_class` and + `enc_outputs_class`. They are both `None` when 'as_two_stage' + is `False`. + """ + batch_size, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + # We only use the first channel in enc_outputs_class as foreground, + # the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa + # This follows the official implementation of Deformable DETR. + topk_proposals = torch.topk( + enc_outputs_class[..., 0], self.num_query, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + pos_trans_out = self.pos_trans_fc( + self.get_proposal_pos_embed(topk_coords_unact)) + pos_trans_out = self.pos_trans_norm(pos_trans_out) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_embed = self.query_embedding.weight + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1) + query = query.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = self.reference_points_fc(query_pos).sigmoid() + + query = query.permute(1, 0, 2) # (num_query, bs, dim) + memory = memory.permute(1, 0, 2) # (num_feat, bs, dim) + query_pos = query_pos.permute(1, 0, 2) # (num_query, bs, dim) + + decoder_inputs_dict = dict( + query=query, + query_pos=query_pos, + memory=memory, + reference_points=reference_points) + head_inputs_dict = dict( + enc_outputs_class=enc_outputs_class if self.as_two_stage else None, + enc_outputs_coord=enc_outputs_coord_unact.sigmoid() + if self.as_two_stage else None) + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, reference_points: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (num_query, bs, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (num_query, bs, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat). + reference_points (Tensor): The initial reference, has shape + (bs, num_query, 4) when `as_two_stage` is `True`, + otherwise has shape (bs, num_query, 2). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. + """ + inter_states, inter_references = self.decoder( + query=query, + value=memory, + query_pos=query_pos, + key_padding_mask=memory_mask, # for cross_attn + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=self.bbox_head.reg_branches + if self.with_box_refine else None) + references = [reference_points, *inter_references] + decoder_outputs_dict = dict( + hidden_states=inter_states, references=references) + return decoder_outputs_dict + + @staticmethod + def get_valid_ratio(mask: Tensor) -> Tensor: + """Get the valid radios of feature map in a level. + + .. code:: text + + |---> valid_H <---| + ---+-----------------+-----+--- + A | | | A + | | | | | + | | | | | + valid_W | | | | + | | | | W + | | | | | + V | | | | + ---+-----------------+ | | + | | V + +-----------------------+--- + |---------> H <---------| + + The valid_ratios are defined as: + r_h = valid_H / H, r_w = valid_W / W + They are the factors to re-normalize the relative coordinates of the + image to the relative coordinates of the current level feature map. + + Args: + mask (Tensor): Binary mask of a feature map, has shape (bs, H, W). + + Returns: + Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2). + """ + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def gen_encoder_output_proposals( + self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]: + """Generate proposals from encoded memory. The function will only be + used when `as_two_stage` is `True`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + + Returns: + tuple: A tuple of transformed memory and proposals. + + - output_memory (Tensor): The transformed memory for obtaining + top-k proposals, has shape (bs, num_feat, dim). + - output_proposals (Tensor): The inverse-normalized proposal, has + shape (batch_size, num_keys, 4). + """ + + num_feat = memory.size(0) + proposals = [] + _cur = 0 # start index in the sequence of the current level + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view( + num_feat, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W - 1, W, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W, valid_H], 1).view(num_feat, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(num_feat, -1, -1, -1) + + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(num_feat, -1, 4) + proposals.append(proposal) + _cur += (H * W) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & + (output_proposals < 0.99)).all( + -1, keepdim=True) + # inverse_sigmoid + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.memory_trans_fc(output_memory) + output_memory = self.memory_trans_norm(output_memory) + # [bs, sum(hw), 2] + return output_memory, output_proposals + + @staticmethod + def get_proposal_pos_embed(proposals: Tensor, + num_pos_feats: int = 128, + temperature: int = 10000) -> Tensor: + """Get the position embedding of the proposal. + + Args: + proposals (Tensor): Not normalized proposals, has shape + (bs, num_query, 4). + num_pos_feats (int, optional): The feature dimension for each + position along x, y, w, and h-axis. Note the final returned + dimension for each position is 4 times of num_pos_feats. + Default to 128. + temperature (int, optional): The temperature used for scaling the + position embedding. Defaults to 10000. - def __init__(self, *args, **kwargs) -> None: - super(DETR, self).__init__(*args, **kwargs) + Returns: + Tensor: The position embedding of proposal, has shape + (bs, num_query, num_pos_feats * 4) + """ + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py index 979b9bc829f..48079be22f2 100644 --- a/mmdet/models/detectors/detr.py +++ b/mmdet/models/detectors/detr.py @@ -1,52 +1,212 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Tuple +from typing import Dict, Tuple -from torch import Tensor +import torch +import torch.nn.functional as F +from torch import Tensor, nn from mmdet.registry import MODELS -from mmdet.structures import SampleList -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig -from .single_stage import SingleStageDetector +from mmdet.structures import OptSampleList +from ..layers import (DetrTransformerDecoder, DetrTransformerEncoder, + SinePositionalEncoding) +from .base_detr import DetectionTransformer @MODELS.register_module() -class DETR(SingleStageDetector): - r"""Implementation of `DETR: End-to-End Object Detection with - Transformers `_""" - - def __init__(self, - backbone: ConfigType, - bbox_head: ConfigType, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=None, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - def _forward(self, batch_inputs: Tensor, - batch_data_samples: SampleList) -> Tuple[List[Tensor]]: - """Network forward process. Usually includes backbone, neck and head - forward without any post-processing. - - Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (list[:obj:`DetDataSample`]): The batch - data samples. It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. +class DETR(DetectionTransformer): + r"""Implementation of `DETR: End-to-End Object Detection with Transformers. + + `_. + + Code is modified from the `official github repo + `_. + """ + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding_cfg) + self.encoder = DetrTransformerEncoder(**self.encoder) + self.decoder = DetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + # NOTE The embed_dims is typically passed from the inside out. + # For example in DETR, The embed_dims is passed as + # self_attn -> the first encoder layer -> encoder -> detector. + self.query_embedding = nn.Embedding(self.num_query, self.embed_dims) + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + 'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def pre_transformer( + self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]: + """Prepare the inputs of the Transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + img_feats (Tuple[Tensor]): Tuple of features output from the neck, + has shape (bs, c, h, w). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such as + `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict, dict]: The first dict contains the inputs of encoder + and the second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + and 'feat_pos'. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask', + and 'memory_pos'. + """ + + feat = img_feats[-1] # NOTE img_feats contains only one feature. + batch_size, feat_dim, _, _ = feat.shape + # construct binary masks which for the transformer. + assert batch_data_samples is not None + batch_input_shape = batch_data_samples[0].batch_input_shape + img_shape_list = [sample.img_shape for sample in batch_data_samples] + + input_img_h, input_img_w = batch_input_shape + masks = feat.new_ones((batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_shape_list[img_id] + masks[img_id, :img_h, :img_w] = 0 + # NOTE following the official DETR repo, non-zero values represent + # ignored positions, while zero values mean valid positions. + + masks = F.interpolate( + masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1) + # [batch_size, embed_dim, h, w] + pos_embed = self.positional_encoding(masks) + + # use `view` instead of `flatten` for dynamically exporting to ONNX + # [bs, c, h, w] -> [h*w, bs, c] + feat = feat.view(batch_size, feat_dim, -1).permute(2, 0, 1) + pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(2, 0, 1) + # [bs, h, w] -> [bs, h*w] + masks = masks.view(batch_size, -1) + + # prepare transformer_inputs_dict + encoder_inputs_dict = dict( + feat=feat, feat_mask=masks, feat_pos=pos_embed) + decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed) + return encoder_inputs_dict, decoder_inputs_dict + + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor) -> Dict: + """Forward with Transformer encoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + feat (Tensor): Sequential features, has shape (num_feat, bs, dim). + feat_mask (Tensor): ByteTensor, the padding mask of the features, + has shape (num_feat, bs). + feat_pos (Tensor): The positional embeddings of the features, has + shape (num_feat, bs, dim). + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. + """ + memory = self.encoder( + query=feat, query_pos=feat_pos, + key_padding_mask=feat_mask) # for self_attn + encoder_outputs_dict = dict(memory=memory) + return encoder_outputs_dict + + def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + + Returns: + tuple[dict, dict]: The first dict contains the inputs of decoder + and the second dict contains the inputs of the bbox_head function. + + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory'. + - head_inputs_dict (dict): The keyword args dictionary of the + bbox_head functions, which is usually empty, or includes + `enc_outputs_class` and `enc_outputs_class` when the detector + support 'two stage' or 'query selection' strategies. + """ + + batch_size = memory.size(1) + query_pos = self.query_embedding.weight + # (num_query, dim) -> (num_query, bs, dim) + query_pos = query_pos.unsqueeze(1).repeat(1, batch_size, 1) + query = torch.zeros_like(query_pos) + + decoder_inputs_dict = dict( + query_pos=query_pos, query=query, memory=memory) + head_inputs_dict = dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, memory_pos: Tensor) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (num_query, bs, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (num_query, bs, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (num_feat, bs, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat). + memory_pos (Tensor): The positional embeddings of memory, has + shape (num_feat, bs, dim). Returns: - tuple[list]: A tuple of features from ``bbox_head`` forward. + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output. """ - x = self.extract_feat(batch_inputs) - batch_img_metas = [ - data_samples.metainfo for data_samples in batch_data_samples - ] - results = self.bbox_head.forward(x, batch_img_metas) - return results + # (num_decoder_layers, num_query, bs, dim) + hidden_states = self.decoder( + query=query, + key=memory, + value=memory, + query_pos=query_pos, + key_pos=memory_pos, + key_padding_mask=memory_mask) # for cross_attn + hidden_states = hidden_states.transpose(1, 2) + head_inputs_dict = dict(hidden_states=hidden_states) + return head_inputs_dict diff --git a/mmdet/models/layers/__init__.py b/mmdet/models/layers/__init__.py index 98f8e843075..e36f9b4b96a 100644 --- a/mmdet/models/layers/__init__.py +++ b/mmdet/models/layers/__init__.py @@ -15,18 +15,27 @@ SinePositionalEncoding) from .res_layer import ResLayer, SimplifiedBasicBlock from .se_layer import ChannelAttention, DyReLU, SELayer -from .transformer import (DetrTransformerDecoder, DetrTransformerDecoderLayer, - DynamicConv, PatchEmbed, PatchMerging, Transformer, +from .transformer import (MLP, DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer, + DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer, + DynamicConv, PatchEmbed, PatchMerging, inverse_sigmoid, nchw_to_nlc, nlc_to_nchw) __all__ = [ 'fast_nms', 'multiclass_nms', 'mask_matrix_nms', 'DropBlock', 'PixelDecoder', 'TransformerEncoderPixelDecoder', - 'MSDeformAttnPixelDecoder', 'ResLayer', 'DetrTransformerDecoderLayer', - 'DetrTransformerDecoder', 'Transformer', 'PatchMerging', + 'MSDeformAttnPixelDecoder', 'ResLayer', 'PatchMerging', 'SinePositionalEncoding', 'LearnedPositionalEncoding', 'DynamicConv', 'SimplifiedBasicBlock', 'NormedLinear', 'NormedConv2d', 'InvertedResidual', 'SELayer', 'ConvUpsample', 'CSPLayer', 'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'DyReLU', - 'ExpMomentumEMA', 'inverse_sigmoid', 'ChannelAttention', 'SiLU' + 'ExpMomentumEMA', 'inverse_sigmoid', 'ChannelAttention', 'SiLU', 'MLP', + 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer', + 'DetrTransformerEncoder', 'DetrTransformerDecoder', + 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer' ] diff --git a/mmdet/models/layers/msdeformattn_pixel_decoder.py b/mmdet/models/layers/msdeformattn_pixel_decoder.py index 953f873f400..12ea14d7efc 100644 --- a/mmdet/models/layers/msdeformattn_pixel_decoder.py +++ b/mmdet/models/layers/msdeformattn_pixel_decoder.py @@ -5,7 +5,8 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import Conv2d, ConvModule -from mmcv.cnn.bricks.transformer import (build_positional_encoding, +from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention, + build_positional_encoding, build_transformer_layer_sequence) from mmengine.model import (BaseModule, ModuleList, caffe2_xavier_init, normal_init, xavier_init) @@ -14,7 +15,6 @@ from mmdet.registry import MODELS from mmdet.utils import ConfigType, OptMultiConfig from ..task_modules.prior_generators import MlvlPointGenerator -from .transformer import MultiScaleDeformableAttention @MODELS.register_module() diff --git a/mmdet/models/layers/transformer.py b/mmdet/models/layers/transformer.py deleted file mode 100644 index 19c3e62f289..00000000000 --- a/mmdet/models/layers/transformer.py +++ /dev/null @@ -1,1164 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math -import warnings -from typing import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer -from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, - TransformerLayerSequence, - build_transformer_layer_sequence) -from mmengine.model import BaseModule, xavier_init -from mmengine.utils import to_2tuple -from torch.nn.init import normal_ - -from mmdet.registry import MODELS - -try: - from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention - -except ImportError: - warnings.warn( - '`MultiScaleDeformableAttention` in MMCV has been moved to ' - '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV') - from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention - - -def nlc_to_nchw(x, hw_shape): - """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. - - Args: - x (Tensor): The input tensor of shape [N, L, C] before conversion. - hw_shape (Sequence[int]): The height and width of output feature map. - - Returns: - Tensor: The output tensor of shape [N, C, H, W] after conversion. - """ - H, W = hw_shape - assert len(x.shape) == 3 - B, L, C = x.shape - assert L == H * W, 'The seq_len does not match H, W' - return x.transpose(1, 2).reshape(B, C, H, W).contiguous() - - -def nchw_to_nlc(x): - """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. - - Args: - x (Tensor): The input tensor of shape [N, C, H, W] before conversion. - - Returns: - Tensor: The output tensor of shape [N, L, C] after conversion. - """ - assert len(x.shape) == 4 - return x.flatten(2).transpose(1, 2).contiguous() - - -class AdaptivePadding(nn.Module): - """Applies padding to input (if needed) so that input can get fully covered - by filter you specified. It support two modes "same" and "corner". The - "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around - input. The "corner" mode would pad zero to bottom right. - - Args: - kernel_size (int | tuple): Size of the kernel: - stride (int | tuple): Stride of the filter. Default: 1: - dilation (int | tuple): Spacing between kernel elements. - Default: 1 - padding (str): Support "same" and "corner", "corner" mode - would pad zero to bottom right, and "same" mode would - pad zero around input. Default: "corner". - Example: - >>> kernel_size = 16 - >>> stride = 16 - >>> dilation = 1 - >>> input = torch.rand(1, 1, 15, 17) - >>> adap_pad = AdaptivePadding( - >>> kernel_size=kernel_size, - >>> stride=stride, - >>> dilation=dilation, - >>> padding="corner") - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - >>> input = torch.rand(1, 1, 16, 17) - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - """ - - def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): - - super(AdaptivePadding, self).__init__() - - assert padding in ('same', 'corner') - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - padding = to_2tuple(padding) - dilation = to_2tuple(dilation) - - self.padding = padding - self.kernel_size = kernel_size - self.stride = stride - self.dilation = dilation - - def get_pad_shape(self, input_shape): - input_h, input_w = input_shape - kernel_h, kernel_w = self.kernel_size - stride_h, stride_w = self.stride - output_h = math.ceil(input_h / stride_h) - output_w = math.ceil(input_w / stride_w) - pad_h = max((output_h - 1) * stride_h + - (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) - pad_w = max((output_w - 1) * stride_w + - (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) - return pad_h, pad_w - - def forward(self, x): - pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) - if pad_h > 0 or pad_w > 0: - if self.padding == 'corner': - x = F.pad(x, [0, pad_w, 0, pad_h]) - elif self.padding == 'same': - x = F.pad(x, [ - pad_w // 2, pad_w - pad_w // 2, pad_h // 2, - pad_h - pad_h // 2 - ]) - return x - - -class PatchEmbed(BaseModule): - """Image to Patch Embedding. - - We use a conv layer to implement PatchEmbed. - - Args: - in_channels (int): The num of input channels. Default: 3 - embed_dims (int): The dimensions of embedding. Default: 768 - conv_type (str): The config dict for embedding - conv layer type selection. Default: "Conv2d. - kernel_size (int): The kernel_size of embedding conv. Default: 16. - stride (int): The slide stride of embedding conv. - Default: None (Would be set as `kernel_size`). - padding (int | tuple | string ): The padding length of - embedding conv. When it is a string, it means the mode - of adaptive padding, support "same" and "corner" now. - Default: "corner". - dilation (int): The dilation rate of embedding conv. Default: 1. - bias (bool): Bias of embed conv. Default: True. - norm_cfg (dict, optional): Config dict for normalization layer. - Default: None. - input_size (int | tuple | None): The size of input, which will be - used to calculate the out size. Only work when `dynamic_size` - is False. Default: None. - init_cfg (`mmengine.ConfigDict`, optional): The Config for - initialization. Default: None. - """ - - def __init__( - self, - in_channels=3, - embed_dims=768, - conv_type='Conv2d', - kernel_size=16, - stride=16, - padding='corner', - dilation=1, - bias=True, - norm_cfg=None, - input_size=None, - init_cfg=None, - ): - super(PatchEmbed, self).__init__(init_cfg=init_cfg) - - self.embed_dims = embed_dims - if stride is None: - stride = kernel_size - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - if isinstance(padding, str): - self.adap_padding = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) - # disable the padding of conv - padding = 0 - else: - self.adap_padding = None - padding = to_2tuple(padding) - - self.projection = build_conv_layer( - dict(type=conv_type), - in_channels=in_channels, - out_channels=embed_dims, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias) - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, embed_dims)[1] - else: - self.norm = None - - if input_size: - input_size = to_2tuple(input_size) - # `init_out_size` would be used outside to - # calculate the num_patches - # when `use_abs_pos_embed` outside - self.init_input_size = input_size - if self.adap_padding: - pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) - input_h, input_w = input_size - input_h = input_h + pad_h - input_w = input_w + pad_w - input_size = (input_h, input_w) - - # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - h_out = (input_size[0] + 2 * padding[0] - dilation[0] * - (kernel_size[0] - 1) - 1) // stride[0] + 1 - w_out = (input_size[1] + 2 * padding[1] - dilation[1] * - (kernel_size[1] - 1) - 1) // stride[1] + 1 - self.init_out_size = (h_out, w_out) - else: - self.init_input_size = None - self.init_out_size = None - - def forward(self, x): - """ - Args: - x (Tensor): Has shape (B, C, H, W). In most case, C is 3. - - Returns: - tuple: Contains merged results and its spatial shape. - - - x (Tensor): Has shape (B, out_h * out_w, embed_dims) - - out_size (tuple[int]): Spatial shape of x, arrange as - (out_h, out_w). - """ - - if self.adap_padding: - x = self.adap_padding(x) - - x = self.projection(x) - out_size = (x.shape[2], x.shape[3]) - x = x.flatten(2).transpose(1, 2) - if self.norm is not None: - x = self.norm(x) - return x, out_size - - -class PatchMerging(BaseModule): - """Merge patch feature map. - - This layer groups feature map by kernel_size, and applies norm and linear - layers to the grouped feature map. Our implementation uses `nn.Unfold` to - merge patch, which is about 25% faster than original implementation. - Instead, we need to modify pretrained models for compatibility. - - Args: - in_channels (int): The num of input channels. - to gets fully covered by filter and stride you specified.. - Default: True. - out_channels (int): The num of output channels. - kernel_size (int | tuple, optional): the kernel size in the unfold - layer. Defaults to 2. - stride (int | tuple, optional): the stride of the sliding blocks in the - unfold layer. Default: None. (Would be set as `kernel_size`) - padding (int | tuple | string ): The padding length of - embedding conv. When it is a string, it means the mode - of adaptive padding, support "same" and "corner" now. - Default: "corner". - dilation (int | tuple, optional): dilation parameter in the unfold - layer. Default: 1. - bias (bool, optional): Whether to add bias in linear layer or not. - Defaults: False. - norm_cfg (dict, optional): Config dict for normalization layer. - Default: dict(type='LN'). - init_cfg (dict, optional): The extra config for initialization. - Default: None. - """ - - def __init__(self, - in_channels, - out_channels, - kernel_size=2, - stride=None, - padding='corner', - dilation=1, - bias=False, - norm_cfg=dict(type='LN'), - init_cfg=None): - super().__init__(init_cfg=init_cfg) - self.in_channels = in_channels - self.out_channels = out_channels - if stride: - stride = stride - else: - stride = kernel_size - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - if isinstance(padding, str): - self.adap_padding = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) - # disable the padding of unfold - padding = 0 - else: - self.adap_padding = None - - padding = to_2tuple(padding) - self.sampler = nn.Unfold( - kernel_size=kernel_size, - dilation=dilation, - padding=padding, - stride=stride) - - sample_dim = kernel_size[0] * kernel_size[1] * in_channels - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, sample_dim)[1] - else: - self.norm = None - - self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) - - def forward(self, x, input_size): - """ - Args: - x (Tensor): Has shape (B, H*W, C_in). - input_size (tuple[int]): The spatial shape of x, arrange as (H, W). - Default: None. - - Returns: - tuple: Contains merged results and its spatial shape. - - - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) - - out_size (tuple[int]): Spatial shape of x, arrange as - (Merged_H, Merged_W). - """ - B, L, C = x.shape - assert isinstance(input_size, Sequence), f'Expect ' \ - f'input_size is ' \ - f'`Sequence` ' \ - f'but get {input_size}' - - H, W = input_size - assert L == H * W, 'input feature has wrong size' - - x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W - # Use nn.Unfold to merge patch. About 25% faster than original method, - # but need to modify pretrained model for compatibility - - if self.adap_padding: - x = self.adap_padding(x) - H, W = x.shape[-2:] - - x = self.sampler(x) - # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) - - out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * - (self.sampler.kernel_size[0] - 1) - - 1) // self.sampler.stride[0] + 1 - out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * - (self.sampler.kernel_size[1] - 1) - - 1) // self.sampler.stride[1] + 1 - - output_size = (out_h, out_w) - x = x.transpose(1, 2) # B, H/2*W/2, 4*C - x = self.norm(x) if self.norm else x - x = self.reduction(x) - return x, output_size - - -def inverse_sigmoid(x, eps=1e-5): - """Inverse function of sigmoid. - - Args: - x (Tensor): The tensor to do the - inverse. - eps (float): EPS avoid numerical - overflow. Defaults 1e-5. - Returns: - Tensor: The x has passed the inverse - function of sigmoid, has same - shape with input. - """ - x = x.clamp(min=0, max=1) - x1 = x.clamp(min=eps) - x2 = (1 - x).clamp(min=eps) - return torch.log(x1 / x2) - - -@MODELS.register_module() -class DetrTransformerDecoderLayer(BaseTransformerLayer): - """Implements decoder layer in DETR transformer. - - Args: - attn_cfgs (list[`mmengine.ConfigDict`] | list[dict] | dict )): - Configs for self_attention or cross_attention, the order - should be consistent with it in `operation_order`. If it is - a dict, it would be expand to the number of attention in - `operation_order`. - feedforward_channels (int): The hidden dimension for FFNs. - ffn_dropout (float): Probability of an element to be zeroed - in ffn. Default 0.0. - operation_order (tuple[str]): The execution order of operation - in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). - Default:None - act_cfg (dict): The activation config for FFNs. Default: `LN` - norm_cfg (dict): Config dict for normalization layer. - Default: `LN`. - ffn_num_fcs (int): The number of fully-connected layers in FFNs. - Default:2. - """ - - def __init__(self, - attn_cfgs, - feedforward_channels, - ffn_dropout=0.0, - operation_order=None, - act_cfg=dict(type='ReLU', inplace=True), - norm_cfg=dict(type='LN'), - ffn_num_fcs=2, - **kwargs): - super(DetrTransformerDecoderLayer, self).__init__( - attn_cfgs=attn_cfgs, - feedforward_channels=feedforward_channels, - ffn_dropout=ffn_dropout, - operation_order=operation_order, - act_cfg=act_cfg, - norm_cfg=norm_cfg, - ffn_num_fcs=ffn_num_fcs, - **kwargs) - assert len(operation_order) == 6 - assert set(operation_order) == set( - ['self_attn', 'norm', 'cross_attn', 'ffn']) - - -@MODELS.register_module() -class DetrTransformerEncoder(TransformerLayerSequence): - """TransformerEncoder of DETR. - - Args: - post_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. Only used when `self.pre_norm` is `True` - """ - - def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs): - super(DetrTransformerEncoder, self).__init__(*args, **kwargs) - if post_norm_cfg is not None: - self.post_norm = build_norm_layer( - post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None - else: - assert not self.pre_norm, f'Use prenorm in ' \ - f'{self.__class__.__name__},' \ - f'Please specify post_norm_cfg' - self.post_norm = None - - def forward(self, *args, **kwargs): - """Forward function for `TransformerCoder`. - - Returns: - Tensor: forwarded results with shape [num_query, bs, embed_dims]. - """ - x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) - if self.post_norm is not None: - x = self.post_norm(x) - return x - - -@MODELS.register_module() -class DetrTransformerDecoder(TransformerLayerSequence): - """Implements the decoder in DETR transformer. - - Args: - return_intermediate (bool): Whether to return intermediate outputs. - post_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. - """ - - def __init__(self, - *args, - post_norm_cfg=dict(type='LN'), - return_intermediate=False, - **kwargs): - - super(DetrTransformerDecoder, self).__init__(*args, **kwargs) - self.return_intermediate = return_intermediate - if post_norm_cfg is not None: - self.post_norm = build_norm_layer(post_norm_cfg, - self.embed_dims)[1] - else: - self.post_norm = None - - def forward(self, query, *args, **kwargs): - """Forward function for `TransformerDecoder`. - - Args: - query (Tensor): Input query with shape - `(num_query, bs, embed_dims)`. - - Returns: - Tensor: Results with shape [1, num_query, bs, embed_dims] when - return_intermediate is `False`, otherwise it has shape - [num_layers, num_query, bs, embed_dims]. - """ - if not self.return_intermediate: - x = super().forward(query, *args, **kwargs) - if self.post_norm: - x = self.post_norm(x)[None] - return x - - intermediate = [] - for layer in self.layers: - query = layer(query, *args, **kwargs) - if self.return_intermediate: - if self.post_norm is not None: - intermediate.append(self.post_norm(query)) - else: - intermediate.append(query) - return torch.stack(intermediate) - - -@MODELS.register_module() -class Transformer(BaseModule): - """Implements the DETR transformer. - - Following the official DETR implementation, this module copy-paste - from torch.nn.Transformer with modifications: - - * positional encodings are passed in MultiheadAttention - * extra LN at the end of encoder is removed - * decoder returns a stack of activations from all decoding layers - - See `paper: End-to-End Object Detection with Transformers - `_ for details. - - Args: - encoder (`mmengine.ConfigDict` | Dict): Config of - TransformerEncoder. Defaults to None. - decoder ((`mmengine.ConfigDict` | Dict)): Config of - TransformerDecoder. Defaults to None - init_cfg (obj:`mmegine.ConfigDict`): The Config for initialization. - Defaults to None. - """ - - def __init__(self, encoder=None, decoder=None, init_cfg=None): - super(Transformer, self).__init__(init_cfg=init_cfg) - self.encoder = build_transformer_layer_sequence(encoder) - self.decoder = build_transformer_layer_sequence(decoder) - self.embed_dims = self.encoder.embed_dims - - def init_weights(self): - # follow the official DETR to init parameters - for m in self.modules(): - if hasattr(m, 'weight') and m.weight.dim() > 1: - xavier_init(m, distribution='uniform') - self._is_init = True - - def forward(self, x, mask, query_embed, pos_embed): - """Forward function for `Transformer`. - - Args: - x (Tensor): Input query with shape [bs, c, h, w] where - c = embed_dims. - mask (Tensor): The key_padding_mask used for encoder and decoder, - with shape [bs, h, w]. - query_embed (Tensor): The query embedding for decoder, with shape - [num_query, c]. - pos_embed (Tensor): The positional encoding for encoder and - decoder, with the same shape as `x`. - - Returns: - tuple[Tensor]: results of decoder containing the following tensor. - - - out_dec: Output from decoder. If return_intermediate_dec \ - is True output has shape [num_dec_layers, bs, - num_query, embed_dims], else has shape [1, bs, \ - num_query, embed_dims]. - - memory: Output results from encoder, with shape \ - [bs, embed_dims, h, w]. - """ - bs, c, h, w = x.shape - # use `view` instead of `flatten` for dynamically exporting to ONNX - x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] - pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) - query_embed = query_embed.unsqueeze(1).repeat( - 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] - mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] - memory = self.encoder( - query=x, - key=None, - value=None, - query_pos=pos_embed, - query_key_padding_mask=mask) - target = torch.zeros_like(query_embed) - # out_dec: [num_layers, num_query, bs, dim] - out_dec = self.decoder( - query=target, - key=memory, - value=memory, - key_pos=pos_embed, - query_pos=query_embed, - key_padding_mask=mask) - out_dec = out_dec.transpose(1, 2) - memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) - return out_dec, memory - - -@MODELS.register_module() -class DeformableDetrTransformerDecoder(TransformerLayerSequence): - """Implements the decoder in DETR transformer. - - Args: - return_intermediate (bool): Whether to return intermediate outputs. - coder_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. - """ - - def __init__(self, *args, return_intermediate=False, **kwargs): - - super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) - self.return_intermediate = return_intermediate - - def forward(self, - query, - *args, - reference_points=None, - valid_ratios=None, - reg_branches=None, - **kwargs): - """Forward function for `TransformerDecoder`. - - Args: - query (Tensor): Input query with shape - `(num_query, bs, embed_dims)`. - reference_points (Tensor): The reference - points of offset. has shape - (bs, num_query, 4) when as_two_stage, - otherwise has shape ((bs, num_query, 2). - valid_ratios (Tensor): The radios of valid - points on the feature map, has shape - (bs, num_levels, 2) - reg_branch: (obj:`nn.ModuleList`): Used for - refining the regression results. Only would - be passed when with_box_refine is True, - otherwise would be passed a `None`. - - Returns: - Tensor: Results with shape [1, num_query, bs, embed_dims] when - return_intermediate is `False`, otherwise it has shape - [num_layers, num_query, bs, embed_dims]. - """ - output = query - intermediate = [] - intermediate_reference_points = [] - for lid, layer in enumerate(self.layers): - if reference_points.shape[-1] == 4: - reference_points_input = reference_points[:, :, None] * \ - torch.cat([valid_ratios, valid_ratios], -1)[:, None] - else: - assert reference_points.shape[-1] == 2 - reference_points_input = reference_points[:, :, None] * \ - valid_ratios[:, None] - output = layer( - output, - *args, - reference_points=reference_points_input, - **kwargs) - output = output.permute(1, 0, 2) - - if reg_branches is not None: - tmp = reg_branches[lid](output) - if reference_points.shape[-1] == 4: - new_reference_points = tmp + inverse_sigmoid( - reference_points) - new_reference_points = new_reference_points.sigmoid() - else: - assert reference_points.shape[-1] == 2 - new_reference_points = tmp - new_reference_points[..., :2] = tmp[ - ..., :2] + inverse_sigmoid(reference_points) - new_reference_points = new_reference_points.sigmoid() - reference_points = new_reference_points.detach() - - output = output.permute(1, 0, 2) - if self.return_intermediate: - intermediate.append(output) - intermediate_reference_points.append(reference_points) - - if self.return_intermediate: - return torch.stack(intermediate), torch.stack( - intermediate_reference_points) - - return output, reference_points - - -@MODELS.register_module() -class DeformableDetrTransformer(Transformer): - """Implements the DeformableDETR transformer. - - Args: - as_two_stage (bool): Generate query from encoder features. - Default: False. - num_feature_levels (int): Number of feature maps from FPN: - Default: 4. - two_stage_num_proposals (int): Number of proposals when set - `as_two_stage` as True. Default: 300. - """ - - def __init__(self, - as_two_stage=False, - num_feature_levels=4, - two_stage_num_proposals=300, - **kwargs): - super(DeformableDetrTransformer, self).__init__(**kwargs) - self.as_two_stage = as_two_stage - self.num_feature_levels = num_feature_levels - self.two_stage_num_proposals = two_stage_num_proposals - self.embed_dims = self.encoder.embed_dims - self.init_layers() - - def init_layers(self): - """Initialize layers of the DeformableDetrTransformer.""" - self.level_embeds = nn.Parameter( - torch.Tensor(self.num_feature_levels, self.embed_dims)) - - if self.as_two_stage: - self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) - self.enc_output_norm = nn.LayerNorm(self.embed_dims) - self.pos_trans = nn.Linear(self.embed_dims * 2, - self.embed_dims * 2) - self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) - else: - self.reference_points = nn.Linear(self.embed_dims, 2) - - def init_weights(self): - """Initialize the transformer weights.""" - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - for m in self.modules(): - if isinstance(m, MultiScaleDeformableAttention): - m.init_weights() - if not self.as_two_stage: - xavier_init(self.reference_points, distribution='uniform', bias=0.) - normal_(self.level_embeds) - - def gen_encoder_output_proposals(self, memory, memory_padding_mask, - spatial_shapes): - """Generate proposals from encoded memory. - - Args: - memory (Tensor) : The output of encoder, - has shape (bs, num_key, embed_dim). num_key is - equal the number of points on feature map from - all level. - memory_padding_mask (Tensor): Padding mask for memory. - has shape (bs, num_key). - spatial_shapes (Tensor): The shape of all feature maps. - has shape (num_level, 2). - - Returns: - tuple: A tuple of feature map and bbox prediction. - - - output_memory (Tensor): The input of decoder, \ - has shape (bs, num_key, embed_dim). num_key is \ - equal the number of points on feature map from \ - all levels. - - output_proposals (Tensor): The normalized proposal \ - after a inverse sigmoid, has shape \ - (bs, num_keys, 4). - """ - - N, S, C = memory.shape - proposals = [] - _cur = 0 - for lvl, (H, W) in enumerate(spatial_shapes): - mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view( - N, H, W, 1) - valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) - valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) - - grid_y, grid_x = torch.meshgrid( - torch.linspace( - 0, H - 1, H, dtype=torch.float32, device=memory.device), - torch.linspace( - 0, W - 1, W, dtype=torch.float32, device=memory.device)) - grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) - - scale = torch.cat([valid_W.unsqueeze(-1), - valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) - grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale - wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) - proposal = torch.cat((grid, wh), -1).view(N, -1, 4) - proposals.append(proposal) - _cur += (H * W) - output_proposals = torch.cat(proposals, 1) - output_proposals_valid = ((output_proposals > 0.01) & - (output_proposals < 0.99)).all( - -1, keepdim=True) - output_proposals = torch.log(output_proposals / (1 - output_proposals)) - output_proposals = output_proposals.masked_fill( - memory_padding_mask.unsqueeze(-1), float('inf')) - output_proposals = output_proposals.masked_fill( - ~output_proposals_valid, float('inf')) - - output_memory = memory - output_memory = output_memory.masked_fill( - memory_padding_mask.unsqueeze(-1), float(0)) - output_memory = output_memory.masked_fill(~output_proposals_valid, - float(0)) - output_memory = self.enc_output_norm(self.enc_output(output_memory)) - return output_memory, output_proposals - - @staticmethod - def get_reference_points(spatial_shapes, valid_ratios, device): - """Get the reference points used in decoder. - - Args: - spatial_shapes (Tensor): The shape of all - feature maps, has shape (num_level, 2). - valid_ratios (Tensor): The radios of valid - points on the feature map, has shape - (bs, num_levels, 2) - device (obj:`device`): The device where - reference_points should be. - - Returns: - Tensor: reference points used in decoder, has \ - shape (bs, num_keys, num_levels, 2). - """ - reference_points_list = [] - for lvl, (H, W) in enumerate(spatial_shapes): - # TODO check this 0.5 - ref_y, ref_x = torch.meshgrid( - torch.linspace( - 0.5, H - 0.5, H, dtype=torch.float32, device=device), - torch.linspace( - 0.5, W - 0.5, W, dtype=torch.float32, device=device)) - ref_y = ref_y.reshape(-1)[None] / ( - valid_ratios[:, None, lvl, 1] * H) - ref_x = ref_x.reshape(-1)[None] / ( - valid_ratios[:, None, lvl, 0] * W) - ref = torch.stack((ref_x, ref_y), -1) - reference_points_list.append(ref) - reference_points = torch.cat(reference_points_list, 1) - reference_points = reference_points[:, :, None] * valid_ratios[:, None] - return reference_points - - def get_valid_ratio(self, mask): - """Get the valid radios of feature maps of all level.""" - _, H, W = mask.shape - valid_H = torch.sum(~mask[:, :, 0], 1) - valid_W = torch.sum(~mask[:, 0, :], 1) - valid_ratio_h = valid_H.float() / H - valid_ratio_w = valid_W.float() / W - valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) - return valid_ratio - - def get_proposal_pos_embed(self, - proposals, - num_pos_feats=128, - temperature=10000): - """Get the position embedding of proposal.""" - scale = 2 * math.pi - dim_t = torch.arange( - num_pos_feats, dtype=torch.float32, device=proposals.device) - dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) - # N, L, 4 - proposals = proposals.sigmoid() * scale - # N, L, 4, 128 - pos = proposals[:, :, :, None] / dim_t - # N, L, 4, 64, 2 - pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), - dim=4).flatten(2) - return pos - - def forward(self, - mlvl_feats, - mlvl_masks, - query_embed, - mlvl_pos_embeds, - reg_branches=None, - cls_branches=None, - **kwargs): - """Forward function for `Transformer`. - - Args: - mlvl_feats (list(Tensor)): Input queries from - different level. Each element has shape - [bs, embed_dims, h, w]. - mlvl_masks (list(Tensor)): The key_padding_mask from - different level used for encoder and decoder, - each element has shape [bs, h, w]. - query_embed (Tensor): The query embedding for decoder, - with shape [num_query, c]. - mlvl_pos_embeds (list(Tensor)): The positional encoding - of feats from different level, has the shape - [bs, embed_dims, h, w]. - reg_branches (obj:`nn.ModuleList`): Regression heads for - feature maps from each decoder layer. Only would - be passed when - `with_box_refine` is True. Default to None. - cls_branches (obj:`nn.ModuleList`): Classification heads - for feature maps from each decoder layer. Only would - be passed when `as_two_stage` - is True. Default to None. - - - Returns: - tuple[Tensor]: results of decoder containing the following tensor. - - - inter_states: Outputs from decoder. If - return_intermediate_dec is True output has shape \ - (num_dec_layers, bs, num_query, embed_dims), else has \ - shape (1, bs, num_query, embed_dims). - - init_reference_out: The initial value of reference \ - points, has shape (bs, num_queries, 4). - - inter_references_out: The internal value of reference \ - points in decoder, has shape \ - (num_dec_layers, bs,num_query, embed_dims) - - enc_outputs_class: The classification score of \ - proposals generated from \ - encoder's feature maps, has shape \ - (batch, h*w, num_classes). \ - Only would be returned when `as_two_stage` is True, \ - otherwise None. - - enc_outputs_coord_unact: The regression results \ - generated from encoder's feature maps., has shape \ - (batch, h*w, 4). Only would \ - be returned when `as_two_stage` is True, \ - otherwise None. - """ - assert self.as_two_stage or query_embed is not None - - feat_flatten = [] - mask_flatten = [] - lvl_pos_embed_flatten = [] - spatial_shapes = [] - for lvl, (feat, mask, pos_embed) in enumerate( - zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): - bs, c, h, w = feat.shape - spatial_shape = (h, w) - spatial_shapes.append(spatial_shape) - feat = feat.flatten(2).transpose(1, 2) - mask = mask.flatten(1) - pos_embed = pos_embed.flatten(2).transpose(1, 2) - lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) - lvl_pos_embed_flatten.append(lvl_pos_embed) - feat_flatten.append(feat) - mask_flatten.append(mask) - feat_flatten = torch.cat(feat_flatten, 1) - mask_flatten = torch.cat(mask_flatten, 1) - lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) - spatial_shapes = torch.as_tensor( - spatial_shapes, dtype=torch.long, device=feat_flatten.device) - level_start_index = torch.cat((spatial_shapes.new_zeros( - (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack( - [self.get_valid_ratio(m) for m in mlvl_masks], 1) - - reference_points = \ - self.get_reference_points(spatial_shapes, - valid_ratios, - device=feat.device) - - feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) - lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( - 1, 0, 2) # (H*W, bs, embed_dims) - memory = self.encoder( - query=feat_flatten, - key=None, - value=None, - query_pos=lvl_pos_embed_flatten, - query_key_padding_mask=mask_flatten, - spatial_shapes=spatial_shapes, - reference_points=reference_points, - level_start_index=level_start_index, - valid_ratios=valid_ratios, - **kwargs) - - memory = memory.permute(1, 0, 2) - bs, _, c = memory.shape - if self.as_two_stage: - output_memory, output_proposals = \ - self.gen_encoder_output_proposals( - memory, mask_flatten, spatial_shapes) - enc_outputs_class = cls_branches[self.decoder.num_layers]( - output_memory) - enc_outputs_coord_unact = \ - reg_branches[ - self.decoder.num_layers](output_memory) + output_proposals - - topk = self.two_stage_num_proposals - # We only use the first channel in enc_outputs_class as foreground, - # the other (num_classes - 1) channels are actually not used. - # Its targets are set to be 0s, which indicates the first - # class (foreground) because we use [0, num_classes - 1] to - # indicate class labels, background class is indicated by - # num_classes (similar convention in RPN). - # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa - # This follows the official implementation of Deformable DETR. - topk_proposals = torch.topk( - enc_outputs_class[..., 0], topk, dim=1)[1] - topk_coords_unact = torch.gather( - enc_outputs_coord_unact, 1, - topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) - topk_coords_unact = topk_coords_unact.detach() - reference_points = topk_coords_unact.sigmoid() - init_reference_out = reference_points - pos_trans_out = self.pos_trans_norm( - self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) - query_pos, query = torch.split(pos_trans_out, c, dim=2) - else: - query_pos, query = torch.split(query_embed, c, dim=1) - query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) - query = query.unsqueeze(0).expand(bs, -1, -1) - reference_points = self.reference_points(query_pos).sigmoid() - init_reference_out = reference_points - - # decoder - query = query.permute(1, 0, 2) - memory = memory.permute(1, 0, 2) - query_pos = query_pos.permute(1, 0, 2) - inter_states, inter_references = self.decoder( - query=query, - key=None, - value=memory, - query_pos=query_pos, - key_padding_mask=mask_flatten, - reference_points=reference_points, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - valid_ratios=valid_ratios, - reg_branches=reg_branches, - **kwargs) - - inter_references_out = inter_references - if self.as_two_stage: - return inter_states, init_reference_out,\ - inter_references_out, enc_outputs_class,\ - enc_outputs_coord_unact - return inter_states, init_reference_out, \ - inter_references_out, None, None - - -@MODELS.register_module() -class DynamicConv(BaseModule): - """Implements Dynamic Convolution. - - This module generate parameters for each sample and - use bmm to implement 1*1 convolution. Code is modified - from the `official github repo `_ . - - Args: - in_channels (int): The input feature channel. - Defaults to 256. - feat_channels (int): The inner feature channel. - Defaults to 64. - out_channels (int, optional): The output feature channel. - When not specified, it will be set to `in_channels` - by default - input_feat_shape (int): The shape of input feature. - Defaults to 7. - with_proj (bool): Project two-dimentional feature to - one-dimentional feature. Default to True. - act_cfg (dict): The activation config for DynamicConv. - norm_cfg (dict): Config dict for normalization layer. Default - layer normalization. - init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization. - Default: None. - """ - - def __init__(self, - in_channels=256, - feat_channels=64, - out_channels=None, - input_feat_shape=7, - with_proj=True, - act_cfg=dict(type='ReLU', inplace=True), - norm_cfg=dict(type='LN'), - init_cfg=None): - super(DynamicConv, self).__init__(init_cfg) - self.in_channels = in_channels - self.feat_channels = feat_channels - self.out_channels_raw = out_channels - self.input_feat_shape = input_feat_shape - self.with_proj = with_proj - self.act_cfg = act_cfg - self.norm_cfg = norm_cfg - self.out_channels = out_channels if out_channels else in_channels - - self.num_params_in = self.in_channels * self.feat_channels - self.num_params_out = self.out_channels * self.feat_channels - self.dynamic_layer = nn.Linear( - self.in_channels, self.num_params_in + self.num_params_out) - - self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] - self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] - - self.activation = build_activation_layer(act_cfg) - - num_output = self.out_channels * input_feat_shape**2 - if self.with_proj: - self.fc_layer = nn.Linear(num_output, self.out_channels) - self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] - - def forward(self, param_feature, input_feature): - """Forward function for `DynamicConv`. - - Args: - param_feature (Tensor): The feature can be used - to generate the parameter, has shape - (num_all_proposals, in_channels). - input_feature (Tensor): Feature that - interact with parameters, has shape - (num_all_proposals, in_channels, H, W). - - Returns: - Tensor: The output feature has shape - (num_all_proposals, out_channels). - """ - input_feature = input_feature.flatten(2).permute(2, 0, 1) - - input_feature = input_feature.permute(1, 0, 2) - parameters = self.dynamic_layer(param_feature) - - param_in = parameters[:, :self.num_params_in].view( - -1, self.in_channels, self.feat_channels) - param_out = parameters[:, -self.num_params_out:].view( - -1, self.feat_channels, self.out_channels) - - # input_feature has shape (num_all_proposals, H*W, in_channels) - # param_in has shape (num_all_proposals, in_channels, feat_channels) - # feature has shape (num_all_proposals, H*W, feat_channels) - features = torch.bmm(input_feature, param_in) - features = self.norm_in(features) - features = self.activation(features) - - # param_out has shape (batch_size, feat_channels, out_channels) - features = torch.bmm(features, param_out) - features = self.norm_out(features) - features = self.activation(features) - - if self.with_proj: - features = features.flatten(1) - features = self.fc_layer(features) - features = self.fc_norm(features) - features = self.activation(features) - - return features diff --git a/mmdet/models/layers/transformer/__init__.py b/mmdet/models/layers/transformer/__init__.py new file mode 100644 index 00000000000..e230f422c6f --- /dev/null +++ b/mmdet/models/layers/transformer/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .deformable_detr_transformer import ( + DeformableDetrTransformerDecoder, DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, DeformableDetrTransformerEncoderLayer) +from .detr_transformer import (DetrTransformerDecoder, + DetrTransformerDecoderLayer, + DetrTransformerEncoder, + DetrTransformerEncoderLayer) +from .utils import (MLP, AdaptivePadding, DynamicConv, PatchEmbed, + PatchMerging, inverse_sigmoid, nchw_to_nlc, nlc_to_nchw) + +__all__ = [ + 'nlc_to_nchw', 'nchw_to_nlc', 'AdaptivePadding', 'PatchEmbed', + 'PatchMerging', 'inverse_sigmoid', 'DynamicConv', 'MLP', + 'DetrTransformerEncoder', 'DetrTransformerDecoder', + 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer', + 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer' +] diff --git a/mmdet/models/layers/transformer/deformable_detr_transformer.py b/mmdet/models/layers/transformer/deformable_detr_transformer.py new file mode 100644 index 00000000000..5dde4cc193b --- /dev/null +++ b/mmdet/models/layers/transformer/deformable_detr_transformer.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import ModuleList +from torch import Tensor, nn + +from .detr_transformer import (DetrTransformerDecoder, + DetrTransformerDecoderLayer, + DetrTransformerEncoder, + DetrTransformerEncoderLayer) +from .utils import inverse_sigmoid + + +class DeformableDetrTransformerEncoder(DetrTransformerEncoder): + """Transformer encoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + **kwargs) -> Tensor: + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (num_query, bs, dim). + query_pos (Tensor): The positional encoding for query, has shape + (num_query, bs, dim). If not None, it will be added to the + `query` before forward function. Defaults to None. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (num_query, bs). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + Tensor: Output queries of Transformer encoder, which is also + called 'encoder output embeddings' or 'memory', has shape + (num_query, bs, dim) + """ + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + for layer in self.layers: + query = layer( + query=query, + query_pos=query_pos, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points, + **kwargs) + return query + + @staticmethod + def get_encoder_reference_points( + spatial_shapes: Tensor, valid_ratios: Tensor, + device: Union[torch.device, str]) -> Tensor: + """Get the reference points used in encoder. + + Args: + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + device (obj:`device` or str): The device acquired by the + `reference_points`. + + Returns: + Tensor: Reference points used in decoder, has shape (bs, length, + num_levels, 2). + """ + + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + # [bs, sum(hw), num_level, 2] + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + +class DeformableDetrTransformerDecoder(DetrTransformerDecoder): + """Transformer Decoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + 'DeformableDetrTransformerDecoder') + + def forward(self, + query: Tensor, + query_pos: Tensor, + value: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + reg_branches: Optional[nn.Module] = None, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input queries, has shape (num_query, bs, dim). + query_pos (Tensor): The input positional query, has shape + (num_query, bs, dim). It will be added to `query` before + forward function. + value (Tensor): The input values, has shape (num_value, bs, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (num_value, bs). + reference_points (Tensor): The initial reference, has shape + (bs, num_query, 4) when `as_two_stage` is `True`, + otherwise has shape (bs, num_query, 2). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. Only would be passed when + `with_box_refine` is `True`, otherwise would be `None`. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_query, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_query, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_query, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_query, 4). + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[layer_id](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer): + """Encoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, ffn, and norms.""" + self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + +class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Decoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) diff --git a/mmdet/models/layers/transformer/detr_transformer.py b/mmdet/models/layers/transformer/detr_transformer.py new file mode 100644 index 00000000000..d70282da4b9 --- /dev/null +++ b/mmdet/models/layers/transformer/detr_transformer.py @@ -0,0 +1,359 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine import ConfigDict +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmdet.utils import ConfigType, OptConfigType + + +class DetrTransformerEncoder(BaseModule): + """Encoder of DETR. + + Args: + num_layers (int): Number of encoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + self.num_layers = num_layers + self.layer_cfg = layer_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of encoder. + + Args: + query (Tensor): Input queries of encoder, has shape + (num_query, bs, dim). + query_pos (Tensor): The positional embeddings of the queries, has + shape (num_query, bs, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (num_query, bs). + + Returns: + Tensor: Has shape (bs, num_query, dim) if `batch_first` is `True`, + otherwise (num_query, bs, dim). + """ + for layer in self.layers: + query = layer(query, query_pos, key_padding_mask, **kwargs) + return query + + +class DetrTransformerDecoder(BaseModule): + """Decoder of DETR. + + Args: + num_layers (int): Number of decoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the + post normalization layer. Defaults to `LN`. + return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`, + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + post_norm_cfg: OptConfigType = dict(type='LN'), + return_intermediate: bool = True, + init_cfg: Union[dict, ConfigDict] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.post_norm_cfg = post_norm_cfg + self.return_intermediate = return_intermediate + self._init_layers() + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + self.post_norm = build_norm_layer(self.post_norm_cfg, + self.embed_dims)[1] + + def forward(self, query: Tensor, key: Tensor, value: Tensor, + query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor, + **kwargs) -> Tensor: + """Forward function of decoder + Args: + query (Tensor): The input query, has shape (num_query, bs, dim) + if `batch_first` is `False`, else (bs, num_query, dim). + key (Tensor): The input key, has shape (num_key, bs, dim) if + `batch_first` is `False`, else (bs, num_key, dim). If + `None`, the `query` will be used. Defaults to `None`. + value (Tensor): The input value with the same shape as `key`. + If `None`, the `key` will be used. Defaults to `None`. + query_pos (Tensor): The positional encoding for `query`, with the + same shape as `query`. If not `None`, it will be added to + `query` before forward function. Defaults to `None`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. If not `None`, it will be added to + `key` before forward function. If `None`, and `query_pos` + has the same shape as `key`, then `query_pos` will be used + as `key_pos`. Defaults to `None`. + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (num_value, bs). + + Returns: + Tensor: When `batch_first` is `False`. The forwarded results + will have shape (num_decoder_layers, num_query, bs, dim) if + `return_intermediate` is `True` else (num_query, bs, dim). + + When `batch_first` is `True`. The forwarded results will + have shape (num_decoder_layers, bs, num_query, dim) if + `return_intermediate` is `True` else (bs, num_query, dim). + """ + intermediate = [] + for layer in self.layers: + query = layer( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + **kwargs) + if self.return_intermediate: + intermediate.append(self.post_norm(query)) + + if self.return_intermediate: + return torch.stack(intermediate) + + return query + + +class DetrTransformerEncoderLayer(BaseModule): + """Implements encoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + batch_first (bool, optional): If `True`, the output will have shape + (bs, h*w, dim), otherwise (h*w, bs, dim). Defaults to False. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True)), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None, + batch_first: bool = False) -> None: + + super().__init__(init_cfg=init_cfg) + if 'batch_first' in self_attn_cfg: + assert batch_first == self_attn_cfg['batch_first'] + else: + self_attn_cfg['batch_first'] = batch_first + self.batch_first = batch_first + self.self_attn_cfg = self_attn_cfg + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of an encoder layer. + + Args: + query (Tensor): The input query, has shape (num_query, bs, dim) + if `batch_first` is `False`, else (bs, num_query, dim). + query_pos (Tensor): The positional encoding for query, with + the same shape as `query`. If not None, it will + be added to `query` before forward function. Defaults to None. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor. has shape (num_query, bs). + Returns: + Tensor: forwarded results, has shape (num_query, bs, dim) if + `self.batch_first` is `False`, else (bs, num_query, dim). + """ + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + query = self.ffn(query) + query = self.norms[1](query) + + return query + + +class DetrTransformerDecoderLayer(BaseModule): + """Implements decoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + batch_first (bool, optional): If `True`, the output will have shape + (bs, h*w, dim), otherwise (h*w, bs, dim). Defaults to False. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + cross_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None, + batch_first: bool = False) -> None: + + super().__init__(init_cfg=init_cfg) + for attn_cfg in (self_attn_cfg, cross_attn_cfg): + if 'batch_first' in attn_cfg: + assert batch_first == attn_cfg['batch_first'] + else: + attn_cfg['batch_first'] = batch_first + self.batch_first = batch_first + self.self_attn_cfg = self_attn_cfg + self.cross_attn_cfg = cross_attn_cfg + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_masks: Tensor = None, + cross_attn_masks: Tensor = None, + key_padding_mask: Tensor = None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (num_query, bs, dim) + if `self.batch_first` is `False`, else (bs, num_query, dim). + key (Tensor, optional): The input key, has shape (num_key, bs, dim) + if `self.batch_first` is `False`, else (bs, num_key, dim). + If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_masks (Tensor, optional): ByteTensor mask, has shape + (num_query, num_key), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_masks (Tensor, optional): ByteTensor mask, has shape + (num_query, num_key), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (num_value, bs). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (num_query, bs, dim) if + `self.batch_first` is `False`, else (bs, num_query, dim). + """ + + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_masks, + **kwargs) + query = self.norms[0](query) + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_masks, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query diff --git a/mmdet/models/layers/transformer/utils.py b/mmdet/models/layers/transformer/utils.py new file mode 100644 index 00000000000..a15e83a9dca --- /dev/null +++ b/mmdet/models/layers/transformer/utils.py @@ -0,0 +1,527 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from torch import Tensor, nn + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType + + +def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor: + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len does not match H, W' + return x.transpose(1, 2).reshape(B, C, H, W).contiguous() + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. + """ + + def __init__(self, + in_channels: int = 3, + embed_dims: int = 768, + conv_type: str = 'Conv2d', + kernel_size: int = 16, + stride: int = 16, + padding: Union[int, tuple, str] = 'corner', + dilation: int = 1, + bias: bool = True, + norm_cfg: OptConfigType = None, + input_size: Union[int, tuple] = None, + init_cfg: OptConfigType = None) -> None: + super(PatchEmbed, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]: + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Optional[Union[int, tuple]] = 2, + stride: Optional[Union[int, tuple]] = None, + padding: Union[int, tuple, str] = 'corner', + dilation: Optional[Union[int, tuple]] = 1, + bias: Optional[bool] = False, + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x: Tensor, + input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]: + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor: + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the inverse. + eps (float): EPS avoid numerical overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse function of sigmoid, has the same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class MLP(BaseModule): + """Very simple multi-layer perceptron (also called FFN) with relu. Mostly + used in DETR series detectors. + + Args: + input_dim (int): Feature dim of the input tensor. + hidden_dim (int): Feature dim of the hidden layer. + output_dim (int): Feature dim of the output tensor. + num_layers (int): Number of FFN layers. As the last + layer of MLP only contains FFN (Linear). + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, + num_layers: int) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = ModuleList( + Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x: Tensor) -> Tensor: + """Forward function of MLP. + + Args: + x (Tensor): The input feature, has shape + (num_query, bs, input_dim). + Returns: + Tensor: The output feature, has shape (num_query, bs, output_dim). + """ + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@MODELS.register_module() +class DynamicConv(BaseModule): + """Implements Dynamic Convolution. + + This module generate parameters for each sample and + use bmm to implement 1*1 convolution. Code is modified + from the `official github repo `_ . + + Args: + in_channels (int): The input feature channel. + Defaults to 256. + feat_channels (int): The inner feature channel. + Defaults to 64. + out_channels (int, optional): The output feature channel. + When not specified, it will be set to `in_channels` + by default + input_feat_shape (int): The shape of input feature. + Defaults to 7. + with_proj (bool): Project two-dimentional feature to + one-dimentional feature. Default to True. + act_cfg (dict): The activation config for DynamicConv. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_channels: int = 256, + feat_channels: int = 64, + out_channels: Optional[int] = None, + input_feat_shape: int = 7, + with_proj: bool = True, + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + super(DynamicConv, self).__init__(init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.input_feat_shape = input_feat_shape + self.with_proj = with_proj + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feat_channels + self.num_params_out = self.out_channels * self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + num_output = self.out_channels * input_feat_shape**2 + if self.with_proj: + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor: + """Forward function for `DynamicConv`. + + Args: + param_feature (Tensor): The feature can be used + to generate the parameter, has shape + (num_all_proposals, in_channels). + input_feature (Tensor): Feature that + interact with parameters, has shape + (num_all_proposals, in_channels, H, W). + + Returns: + Tensor: The output feature has shape + (num_all_proposals, out_channels). + """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, :self.num_params_in].view( + -1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/tests/test_models/test_layers/test_transformer.py b/tests/test_models/test_layers/test_transformer.py index 9151e308424..c219c103f4c 100644 --- a/tests/test_models/test_layers/test_transformer.py +++ b/tests/test_models/test_layers/test_transformer.py @@ -6,8 +6,7 @@ from mmdet.models.layers.transformer import (AdaptivePadding, DetrTransformerDecoder, DetrTransformerEncoder, - PatchEmbed, PatchMerging, - Transformer) + PatchEmbed, PatchMerging) def test_adaptive_padding(): @@ -532,39 +531,3 @@ def test_detr_transformer_dencoder_encoder_layer(): DetrTransformerEncoder(**config) -def test_transformer(): - config = ConfigDict( - dict( - encoder=dict( - type='DetrTransformerEncoder', - num_layers=6, - transformerlayers=dict( - type='BaseTransformerLayer', - attn_cfgs=[ - dict( - type='MultiheadAttention', - embed_dims=256, - num_heads=8, - dropout=0.1) - ], - feedforward_channels=2048, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'ffn', 'norm'))), - decoder=dict( - type='DetrTransformerDecoder', - return_intermediate=True, - num_layers=6, - transformerlayers=dict( - type='DetrTransformerDecoderLayer', - attn_cfgs=dict( - type='MultiheadAttention', - embed_dims=256, - num_heads=8, - dropout=0.1), - feedforward_channels=2048, - ffn_dropout=0.1, - operation_order=('self_attn', 'norm', 'cross_attn', 'norm', - 'ffn', 'norm')), - ))) - transformer = Transformer(**config) - transformer.init_weights()