diff --git a/paddlex/inference/models/object_detection/modeling/__init__.py b/paddlex/inference/models/object_detection/modeling/__init__.py index b64cf01fdc..acc3858aaf 100644 --- a/paddlex/inference/models/object_detection/modeling/__init__.py +++ b/paddlex/inference/models/object_detection/modeling/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .rt_detr import RTDETR diff --git a/paddlex/inference/models/object_detection/modeling/rt_detr.py b/paddlex/inference/models/object_detection/modeling/rt_detr.py new file mode 100644 index 0000000000..ecf5d03879 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rt_detr.py @@ -0,0 +1,345 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn.functional as F + +from ...common.transformers.transformers import PretrainedConfig, PretrainedModel +from .rtdetrl_modules.detr_head import DINOHead +from .rtdetrl_modules.hgnet_v2 import PPHGNetV2 +from .rtdetrl_modules.hybrid_encoder import HybridEncoder, TransformerLayer +from .rtdetrl_modules.modules.detr_loss import DINOLoss +from .rtdetrl_modules.modules.matchers import HungarianMatcher +from .rtdetrl_modules.modules.utils import bbox_cxcywh_to_xyxy +from .rtdetrl_modules.rtdetr_transformer import RTDETRTransformer + +__all__ = ["RTDETR"] + + +class DETRPostProcess(object): + __shared__ = ["num_classes", "use_focal_loss", "with_mask"] + __inject__ = [] + + def __init__( + self, + num_classes=80, + num_top_queries=100, + dual_queries=False, + dual_groups=0, + use_focal_loss=False, + with_mask=False, + mask_stride=4, + mask_threshold=0.5, + use_avg_mask_score=False, + bbox_decode_type="origin", + ): + super(DETRPostProcess, self).__init__() + assert bbox_decode_type in ["origin", "pad"] + + self.num_classes = num_classes + self.num_top_queries = num_top_queries + self.dual_queries = dual_queries + self.dual_groups = dual_groups + self.use_focal_loss = use_focal_loss + self.with_mask = with_mask + self.mask_stride = mask_stride + self.mask_threshold = mask_threshold + self.use_avg_mask_score = use_avg_mask_score + self.bbox_decode_type = bbox_decode_type + + def _mask_postprocess(self, mask_pred, score_pred): + mask_score = F.sigmoid(mask_pred) + mask_pred = (mask_score > self.mask_threshold).astype(mask_score.dtype) + if self.use_avg_mask_score: + avg_mask_score = (mask_pred * mask_score).sum([-2, -1]) / ( + mask_pred.sum([-2, -1]) + 1e-6 + ) + score_pred *= avg_mask_score + + return mask_pred.flatten(0, 1).astype("int32"), score_pred + + def __call__(self, head_out, im_shape, scale_factor, pad_shape): + """ + Decode the bbox and mask. + + Args: + head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output. 
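+                `bbox_pred` holds normalized [cx, cy, w, h] boxes, which are
+                decoded to absolute [x1, y1, x2, y2] coordinates here.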
+            im_shape (Tensor): The shape of the input image without padding.
+            scale_factor (Tensor): The scale factor of the input image.
+            pad_shape (Tensor): The shape of the input image with padding.
+        Returns:
+            bbox_pred (Tensor): The output prediction with shape [N, 6], including
+                labels, scores and bboxes. The bbox sizes correspond to the
+                input image; the bboxes may be used in other branches.
+            bbox_num (Tensor): The number of predicted boxes for each sample in
+                the batch, with shape [bs]; each entry equals N.
+        """
+        bboxes, logits, masks = head_out
+        if self.dual_queries:
+            num_queries = logits.shape[1]
+            logits, bboxes = (
+                logits[:, : int(num_queries // (self.dual_groups + 1)), :],
+                bboxes[:, : int(num_queries // (self.dual_groups + 1)), :],
+            )
+
+        bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
+        # calculate the original shape of the image
+        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
+        img_h, img_w = paddle.split(origin_shape, 2, axis=-1)
+        if self.bbox_decode_type == "pad":
+            # calculate the shape of the image with padding
+            out_shape = pad_shape / im_shape * origin_shape
+            out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1)
+        elif self.bbox_decode_type == "origin":
+            out_shape = origin_shape.flip(1).tile([1, 2]).unsqueeze(1)
+        else:
+            raise Exception(f"Wrong `bbox_decode_type`: {self.bbox_decode_type}.")
+        bbox_pred *= out_shape
+
+        scores = (
+            F.sigmoid(logits) if self.use_focal_loss else F.softmax(logits)[:, :, :-1]
+        )
+
+        if not self.use_focal_loss:
+            scores, labels = scores.max(-1), scores.argmax(-1)
+            if scores.shape[1] > self.num_top_queries:
+                scores, index = paddle.topk(scores, self.num_top_queries, axis=-1)
+                batch_ind = (
+                    paddle.arange(end=scores.shape[0])
+                    .unsqueeze(-1)
+                    .tile([1, self.num_top_queries])
+                )
+                index = paddle.stack([batch_ind, index], axis=-1)
+                labels = paddle.gather_nd(labels, index)
+                bbox_pred = paddle.gather_nd(bbox_pred, index)
+        else:
+            scores, index = paddle.topk(
+                scores.flatten(1), self.num_top_queries, axis=-1
+            )
+            labels = index % self.num_classes
+            index = index // self.num_classes
+            batch_ind = (
+                paddle.arange(end=scores.shape[0])
+                .unsqueeze(-1)
+                .tile([1, self.num_top_queries])
+            )
+            index = paddle.stack([batch_ind, index], axis=-1)
+            bbox_pred = paddle.gather_nd(bbox_pred, index)
+
+        mask_pred = None
+        if self.with_mask:
+            assert masks is not None
+            assert masks.shape[0] == 1
+            masks = paddle.gather_nd(masks, index)
+            if self.bbox_decode_type == "pad":
+                masks = F.interpolate(
+                    masks,
+                    scale_factor=self.mask_stride,
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                # TODO: Support prediction with bs>1.
+                # remove padding for input image
+                h, w = im_shape.astype("int32")[0]
+                masks = masks[..., :h, :w]
+            # get pred_mask in the original resolution.
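+            # NOTE: interpolation is applied to the mask logits; sigmoid and
+            # the mask_threshold cut happen afterwards in _mask_postprocess.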
+ img_h = img_h[0].astype("int32") + img_w = img_w[0].astype("int32") + masks = F.interpolate( + masks, size=[img_h, img_w], mode="bilinear", align_corners=False + ) + mask_pred, scores = self._mask_postprocess(masks, scores) + + bbox_pred = paddle.concat( + [labels.unsqueeze(-1).astype("float32"), scores.unsqueeze(-1), bbox_pred], + axis=-1, + ) + bbox_num = paddle.to_tensor(self.num_top_queries, dtype="int32").tile( + [bbox_pred.shape[0]] + ) + bbox_pred = bbox_pred.reshape([-1, 6]) + return bbox_pred, bbox_num, mask_pred + + +class RTDETRConfig(PretrainedConfig): + def __init__( + self, + backbone, + HybridEncoder, + RTDETRTransformer, + DINOHead, + DETRPostProcess, + ): + if backbone["name"] == "PPHGNetV2": + self.arch = backbone["arch"] + self.return_idx = backbone["return_idx"] + self.freeze_stem_only = backbone["freeze_stem_only"] + self.freeze_at = backbone["freeze_at"] + self.freeze_norm = backbone["freeze_norm"] + self.lr_mult_list = backbone["lr_mult_list"] + else: + raise RuntimeError( + f"There is no dynamic graph implementation for backbone {backbone['name']}." + ) + self.hidden_dim = HybridEncoder["hidden_dim"] + self.use_encoder_idx = HybridEncoder["use_encoder_idx"] + self.num_encoder_layers = HybridEncoder["num_encoder_layers"] + self.el_d_model = HybridEncoder["encoder_layer"]["d_model"] + self.el_nhead = HybridEncoder["encoder_layer"]["nhead"] + self.el_dim_feedforward = HybridEncoder["encoder_layer"]["dim_feedforward"] + self.el_dropout = HybridEncoder["encoder_layer"]["dropout"] + self.el_activation = HybridEncoder["encoder_layer"]["activation"] + self.expansion = HybridEncoder["expansion"] + self.tf_num_queries = RTDETRTransformer["num_queries"] + self.tf_position_embed_type = RTDETRTransformer["position_embed_type"] + self.tf_feat_strides = RTDETRTransformer["feat_strides"] + self.tf_num_levels = RTDETRTransformer["num_levels"] + self.tf_nhead = RTDETRTransformer["nhead"] + self.tf_num_decoder_layers = RTDETRTransformer["num_decoder_layers"] + self.tf_backbone_feat_channels = RTDETRTransformer["backbone_feat_channels"] + self.tf_dim_feedforward = RTDETRTransformer["dim_feedforward"] + self.tf_dropout = RTDETRTransformer["dropout"] + self.tf_activation = RTDETRTransformer["activation"] + self.tf_num_denoising = RTDETRTransformer["num_denoising"] + self.tf_label_noise_ratio = RTDETRTransformer["label_noise_ratio"] + self.tf_box_noise_scale = RTDETRTransformer["box_noise_scale"] + self.tf_learnt_init_query = RTDETRTransformer["learnt_init_query"] + self.loss_coeff = DINOHead["loss"]["loss_coeff"] + self.aux_loss = DINOHead["loss"]["aux_loss"] + self.use_vfl = DINOHead["loss"]["use_vfl"] + self.matcher_coeff = DINOHead["loss"]["matcher"]["matcher_coeff"] + self.num_top_queries = DETRPostProcess["num_top_queries"] + self.use_focal_loss = DETRPostProcess["use_focal_loss"] + self.tensor_parallel_degree = 1 + + +class RTDETR(PretrainedModel): + + config_class = RTDETRConfig + + def __init__(self, config: RTDETRConfig): + super().__init__(config) + + self.backbone = PPHGNetV2( + arch=self.config.arch, + lr_mult_list=self.config.lr_mult_list, + return_idx=self.config.return_idx, + freeze_stem_only=self.config.freeze_stem_only, + freeze_at=self.config.freeze_at, + freeze_norm=self.config.freeze_norm, + ) + self.neck = HybridEncoder( + hidden_dim=self.config.hidden_dim, + use_encoder_idx=self.config.use_encoder_idx, + num_encoder_layers=self.config.num_encoder_layers, + encoder_layer=TransformerLayer( + d_model=self.config.el_d_model, + nhead=self.config.el_nhead, + 
dim_feedforward=self.config.el_dim_feedforward, + dropout=self.config.el_dropout, + activation=self.config.el_activation, + ), + expansion=self.config.expansion, + ) + self.transformer = RTDETRTransformer( + num_queries=self.config.tf_num_queries, + position_embed_type=self.config.tf_position_embed_type, + feat_strides=self.config.tf_feat_strides, + backbone_feat_channels=self.config.tf_backbone_feat_channels, + num_levels=self.config.tf_num_levels, + nhead=self.config.tf_nhead, + num_decoder_layers=self.config.tf_num_decoder_layers, + dim_feedforward=self.config.tf_dim_feedforward, + dropout=self.config.tf_dropout, + activation=self.config.tf_activation, + num_denoising=self.config.tf_num_denoising, + label_noise_ratio=self.config.tf_label_noise_ratio, + box_noise_scale=self.config.tf_box_noise_scale, + learnt_init_query=self.config.tf_learnt_init_query, + ) + self.head = DINOHead( + loss=DINOLoss( + loss_coeff=self.config.loss_coeff, + aux_loss=self.config.aux_loss, + use_vfl=self.config.use_vfl, + matcher=HungarianMatcher( + matcher_coeff=self.config.matcher_coeff, + ), + ) + ) + self.post_process = DETRPostProcess( + num_top_queries=self.config.num_top_queries, + use_focal_loss=self.config.use_focal_loss, + ) + + def forward(self, inputs): + x = paddle.to_tensor(inputs[1]) + x = self.backbone(x) + x_neck = self.neck(x) + x = self.transformer(x_neck) + preds = self.head(x, x_neck) + bbox, bbox_num, mask = self.post_process( + preds, + paddle.to_tensor(inputs[0]), + paddle.to_tensor(inputs[2]), + inputs[1][2:].shape, + ) + output = [bbox, bbox_num] + return output + + def get_transpose_weight_keys(self): + need_to_transpose = [] + all_weight_keys = [] + for name, param in self.neck.named_parameters(): + all_weight_keys.append("neck." + name) + for name, param in self.transformer.named_parameters(): + all_weight_keys.append("transformer." + name) + for i in range(len(all_weight_keys)): + if ("out_proj" in all_weight_keys[i]) and ( + "bias" not in all_weight_keys[i] + ): + need_to_transpose.append(all_weight_keys[i]) + return need_to_transpose + + def get_hf_state_dict(self, *args, **kwargs): + + model_state_dict = self.state_dict(*args, **kwargs) + + hf_state_dict = {} + for old_key, value in model_state_dict.items(): + if "_mean" in old_key: + new_key = old_key.replace("_mean", "running_mean") + elif "_variance" in old_key: + new_key = old_key.replace("_variance", "running_var") + else: + new_key = old_key + hf_state_dict[new_key] = value + + return hf_state_dict + + def set_hf_state_dict(self, state_dict, *args, **kwargs): + + key_mapping = {} + for old_key in list(state_dict.keys()): + if "running_mean" in old_key: + key_mapping[old_key] = old_key.replace("running_mean", "_mean") + elif "running_var" in old_key: + key_mapping[old_key] = old_key.replace("running_var", "_variance") + + for old_key, new_key in key_mapping.items(): + state_dict[new_key] = state_dict.pop(old_key) + + return self.set_state_dict(state_dict, *args, **kwargs) diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_head.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_head.py new file mode 100644 index 0000000000..d7a6c3a33b --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_head.py @@ -0,0 +1,728 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import pycocotools.mask as mask_util
+
+from .modules.detr_ops import inverse_sigmoid
+from .modules.initializer import constant_, linear_init_
+
+__all__ = ["DETRHead", "DeformableDETRHead", "DINOHead", "MaskDINOHead"]
+
+
+def get_activation(name="LeakyReLU"):
+    if name == "silu":
+        module = nn.Silu()
+    elif name == "relu":
+        module = nn.ReLU()
+    elif name in ["LeakyReLU", "leakyrelu", "lrelu"]:
+        module = nn.LeakyReLU(0.1)
+    elif name is None:
+        module = nn.Identity()
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class MLP(nn.Layer):
+    """This code is based on
+    https://github.com/facebookresearch/detr/blob/main/models/detr.py
+    """
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.LayerList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.act = get_activation(act)
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for l in self.layers:
+            linear_init_(l)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+class MultiHeadAttentionMap(nn.Layer):
+    """This code is based on
+    https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
+
+    This is a 2D attention module, which only returns the attention softmax (no multiplication by value)
+    """
+
+    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
+        super().__init__()
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.dropout = nn.Dropout(dropout)
+
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierUniform()
+        )
+        bias_attr = (
+            paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Constant())
+            if bias
+            else False
+        )
+
+        self.q_proj = nn.Linear(query_dim, hidden_dim, weight_attr, bias_attr)
+        self.k_proj = nn.Conv2D(
+            query_dim, hidden_dim, 1, weight_attr=weight_attr, bias_attr=bias_attr
+        )
+
+        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
+
+    def forward(self, q, k, mask=None):
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        bs, num_queries, n, c, h, w = (
+            q.shape[0],
+            q.shape[1],
+            self.num_heads,
+            self.hidden_dim // self.num_heads,
+            k.shape[-2],
+            k.shape[-1],
+        )
+        qh = q.reshape([bs, num_queries, n, c])
+        kh = k.reshape([bs, n, c, h, w])
+        # weights = paddle.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
+        qh = qh.transpose([0, 2, 1, 3]).reshape([-1, num_queries, c])
+        kh = kh.reshape([-1, c, h * w])
+        weights = (
+            paddle.bmm(qh * self.normalize_fact, kh)
+            .reshape([bs, n, num_queries, h, w])
+            .transpose([0, 2, 1, 3, 4])
+        )
+
+        if mask is not None:
+            weights += mask
+        # fix a potential bug: https://github.com/facebookresearch/detr/issues/247
+        weights = F.softmax(weights.flatten(3), axis=-1).reshape(weights.shape)
+        weights = 
self.dropout(weights) + return weights + + +class MaskHeadFPNConv(nn.Layer): + """This code is based on + https://github.com/facebookresearch/detr/blob/main/models/segmentation.py + + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, input_dim, fpn_dims, context_dim, num_groups=8): + super().__init__() + + inter_dims = [ + input_dim, + ] + [context_dim // (2**i) for i in range(1, 5)] + weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.KaimingUniform() + ) + bias_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant() + ) + + self.conv0 = self._make_layers( + input_dim, input_dim, 3, num_groups, weight_attr, bias_attr + ) + self.conv_inter = nn.LayerList() + for in_dims, out_dims in zip(inter_dims[:-1], inter_dims[1:]): + self.conv_inter.append( + self._make_layers( + in_dims, out_dims, 3, num_groups, weight_attr, bias_attr + ) + ) + + self.conv_out = nn.Conv2D( + inter_dims[-1], + 1, + 3, + padding=1, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + + self.adapter = nn.LayerList() + for i in range(len(fpn_dims)): + self.adapter.append( + nn.Conv2D( + fpn_dims[i], + inter_dims[i + 1], + 1, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + ) + + def _make_layers( + self, + in_dims, + out_dims, + kernel_size, + num_groups, + weight_attr=None, + bias_attr=None, + ): + return nn.Sequential( + nn.Conv2D( + in_dims, + out_dims, + kernel_size, + padding=kernel_size // 2, + weight_attr=weight_attr, + bias_attr=bias_attr, + ), + nn.GroupNorm(num_groups, out_dims), + nn.ReLU(), + ) + + def forward(self, x, bbox_attention_map, fpns): + x = paddle.concat( + [ + x.tile([bbox_attention_map.shape[1], 1, 1, 1]), + bbox_attention_map.flatten(0, 1), + ], + 1, + ) + x = self.conv0(x) + for inter_layer, adapter_layer, feat in zip( + self.conv_inter[:-1], self.adapter, fpns + ): + feat = adapter_layer(feat).tile([bbox_attention_map.shape[1], 1, 1, 1]) + x = inter_layer(x) + x = feat + F.interpolate(x, size=feat.shape[-2:]) + + x = self.conv_inter[-1](x) + x = self.conv_out(x) + return x + + +class DETRHead(nn.Layer): + __shared__ = ["num_classes", "hidden_dim", "use_focal_loss"] + __inject__ = ["loss"] + + def __init__( + self, + num_classes=80, + hidden_dim=256, + nhead=8, + num_mlp_layers=3, + loss="DETRLoss", + fpn_dims=[1024, 512, 256], + with_mask_head=False, + use_focal_loss=False, + ): + super(DETRHead, self).__init__() + # add background class + self.num_classes = num_classes if use_focal_loss else num_classes + 1 + self.hidden_dim = hidden_dim + self.loss = loss + self.with_mask_head = with_mask_head + self.use_focal_loss = use_focal_loss + + self.score_head = nn.Linear(hidden_dim, self.num_classes) + self.bbox_head = MLP( + hidden_dim, hidden_dim, output_dim=4, num_layers=num_mlp_layers + ) + if self.with_mask_head: + self.bbox_attention = MultiHeadAttentionMap(hidden_dim, hidden_dim, nhead) + self.mask_head = MaskHeadFPNConv(hidden_dim + nhead, fpn_dims, hidden_dim) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.score_head) + + @classmethod + def from_config(cls, cfg, hidden_dim, nhead, input_shape): + + return { + "hidden_dim": hidden_dim, + "nhead": nhead, + "fpn_dims": [i.channels for i in input_shape[::-1]][1:], + } + + @staticmethod + def get_gt_mask_from_polygons(gt_poly, pad_mask): + out_gt_mask = [] + for polygons, padding in zip(gt_poly, pad_mask): + height, width = int(padding[:, 0].sum()), int(padding[0, :].sum()) + masks = [] + for 
obj_poly in polygons: + rles = mask_util.frPyObjects(obj_poly, height, width) + rle = mask_util.merge(rles) + masks.append(paddle.to_tensor(mask_util.decode(rle)).astype("float32")) + masks = paddle.stack(masks) + masks_pad = paddle.zeros( + [masks.shape[0], pad_mask.shape[1], pad_mask.shape[2]] + ) + masks_pad[:, :height, :width] = masks + out_gt_mask.append(masks_pad) + return out_gt_mask + + def forward(self, out_transformer, body_feats, inputs=None): + r""" + Args: + out_transformer (Tuple): (feats: [num_levels, batch_size, + num_queries, hidden_dim], + memory: [batch_size, hidden_dim, h, w], + src_proj: [batch_size, h*w, hidden_dim], + src_mask: [batch_size, 1, 1, h, w]) + body_feats (List(Tensor)): list[[B, C, H, W]] + inputs (dict): dict(inputs) + """ + feats, memory, src_proj, src_mask = out_transformer + outputs_logit = self.score_head(feats) + outputs_bbox = F.sigmoid(self.bbox_head(feats)) + outputs_seg = None + if self.with_mask_head: + bbox_attention_map = self.bbox_attention(feats[-1], memory, src_mask) + fpn_feats = [a for a in body_feats[::-1]][1:] + outputs_seg = self.mask_head(src_proj, bbox_attention_map, fpn_feats) + outputs_seg = outputs_seg.reshape( + [ + feats.shape[1], + feats.shape[2], + outputs_seg.shape[-2], + outputs_seg.shape[-1], + ] + ) + + if self.training: + assert inputs is not None + assert "gt_bbox" in inputs and "gt_class" in inputs + gt_mask = ( + self.get_gt_mask_from_polygons(inputs["gt_poly"], inputs["pad_mask"]) + if "gt_poly" in inputs + else None + ) + return self.loss( + outputs_bbox, + outputs_logit, + inputs["gt_bbox"], + inputs["gt_class"], + masks=outputs_seg, + gt_mask=gt_mask, + ) + else: + return (outputs_bbox[-1], outputs_logit[-1], outputs_seg) + + +class DeformableDETRHead(nn.Layer): + __shared__ = ["num_classes", "hidden_dim"] + __inject__ = ["loss"] + + def __init__( + self, num_classes=80, hidden_dim=512, nhead=8, num_mlp_layers=3, loss="DETRLoss" + ): + super(DeformableDETRHead, self).__init__() + self.num_classes = num_classes + self.hidden_dim = hidden_dim + self.nhead = nhead + self.loss = loss + + self.score_head = nn.Linear(hidden_dim, self.num_classes) + self.bbox_head = MLP( + hidden_dim, hidden_dim, output_dim=4, num_layers=num_mlp_layers + ) + + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.score_head) + constant_(self.score_head.bias, -4.595) + constant_(self.bbox_head.layers[-1].weight) + + with paddle.no_grad(): + bias = paddle.zeros_like(self.bbox_head.layers[-1].bias) + bias[2:] = -2.0 + self.bbox_head.layers[-1].bias.set_value(bias) + + @classmethod + def from_config(cls, cfg, hidden_dim, nhead, input_shape): + return {"hidden_dim": hidden_dim, "nhead": nhead} + + def forward(self, out_transformer, body_feats, inputs=None): + r""" + Args: + out_transformer (Tuple): (feats: [num_levels, batch_size, + num_queries, hidden_dim], + memory: [batch_size, + \sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim], + reference_points: [batch_size, num_queries, 2]) + body_feats (List(Tensor)): list[[B, C, H, W]] + inputs (dict): dict(inputs) + """ + feats, memory, reference_points = out_transformer + reference_points = inverse_sigmoid(reference_points.unsqueeze(0)) + outputs_bbox = self.bbox_head(feats) + + # It's equivalent to "outputs_bbox[:, :, :, :2] += reference_points", + # but the gradient is wrong in paddle. 
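+        # Adding the reference points to the predicted center offsets via
+        # concat sidesteps the in-place slice assignment: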
+ outputs_bbox = paddle.concat( + [outputs_bbox[:, :, :, :2] + reference_points, outputs_bbox[:, :, :, 2:]], + axis=-1, + ) + + outputs_bbox = F.sigmoid(outputs_bbox) + outputs_logit = self.score_head(feats) + + if self.training: + assert inputs is not None + assert "gt_bbox" in inputs and "gt_class" in inputs + + return self.loss( + outputs_bbox, outputs_logit, inputs["gt_bbox"], inputs["gt_class"] + ) + else: + return (outputs_bbox[-1], outputs_logit[-1], None) + + +class DINOHead(nn.Layer): + __inject__ = ["loss"] + + def __init__(self, loss="DINOLoss", eval_idx=-1): + super(DINOHead, self).__init__() + self.loss = loss + self.eval_idx = eval_idx + + def forward(self, out_transformer, body_feats, inputs=None): + (dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta) = ( + out_transformer + ) + if self.training: + assert inputs is not None + assert "gt_bbox" in inputs and "gt_class" in inputs + + if dn_meta is not None: + if isinstance(dn_meta, list): + dual_groups = len(dn_meta) - 1 + dec_out_bboxes = paddle.split( + dec_out_bboxes, dual_groups + 1, axis=2 + ) + dec_out_logits = paddle.split( + dec_out_logits, dual_groups + 1, axis=2 + ) + enc_topk_bboxes = paddle.split( + enc_topk_bboxes, dual_groups + 1, axis=1 + ) + enc_topk_logits = paddle.split( + enc_topk_logits, dual_groups + 1, axis=1 + ) + + dec_out_bboxes_list = [] + dec_out_logits_list = [] + dn_out_bboxes_list = [] + dn_out_logits_list = [] + loss = {} + for g_id in range(dual_groups + 1): + if dn_meta[g_id] is not None: + dn_out_bboxes_gid, dec_out_bboxes_gid = paddle.split( + dec_out_bboxes[g_id], + dn_meta[g_id]["dn_num_split"], + axis=2, + ) + dn_out_logits_gid, dec_out_logits_gid = paddle.split( + dec_out_logits[g_id], + dn_meta[g_id]["dn_num_split"], + axis=2, + ) + else: + dn_out_bboxes_gid, dn_out_logits_gid = None, None + dec_out_bboxes_gid = dec_out_bboxes[g_id] + dec_out_logits_gid = dec_out_logits[g_id] + out_bboxes_gid = paddle.concat( + [enc_topk_bboxes[g_id].unsqueeze(0), dec_out_bboxes_gid] + ) + out_logits_gid = paddle.concat( + [enc_topk_logits[g_id].unsqueeze(0), dec_out_logits_gid] + ) + loss_gid = self.loss( + out_bboxes_gid, + out_logits_gid, + inputs["gt_bbox"], + inputs["gt_class"], + dn_out_bboxes=dn_out_bboxes_gid, + dn_out_logits=dn_out_logits_gid, + dn_meta=dn_meta[g_id], + ) + # sum loss + for key, value in loss_gid.items(): + loss.update({key: loss.get(key, paddle.zeros([1])) + value}) + + # average across (dual_groups + 1) + for key, value in loss.items(): + loss.update({key: value / (dual_groups + 1)}) + return loss + else: + dn_out_bboxes, dec_out_bboxes = paddle.split( + dec_out_bboxes, dn_meta["dn_num_split"], axis=2 + ) + dn_out_logits, dec_out_logits = paddle.split( + dec_out_logits, dn_meta["dn_num_split"], axis=2 + ) + else: + dn_out_bboxes, dn_out_logits = None, None + + out_bboxes = paddle.concat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes]) + out_logits = paddle.concat([enc_topk_logits.unsqueeze(0), dec_out_logits]) + + return self.loss( + out_bboxes, + out_logits, + inputs["gt_bbox"], + inputs["gt_class"], + dn_out_bboxes=dn_out_bboxes, + dn_out_logits=dn_out_logits, + dn_meta=dn_meta, + gt_score=inputs.get("gt_score", None), + ) + else: + return (dec_out_bboxes[self.eval_idx], dec_out_logits[self.eval_idx], None) + + +class MaskDINOHead(nn.Layer): + __inject__ = ["loss"] + + def __init__(self, loss="DINOLoss"): + super(MaskDINOHead, self).__init__() + self.loss = loss + + def forward(self, out_transformer, body_feats, inputs=None): + (dec_out_logits, 
dec_out_bboxes, dec_out_masks, enc_out, init_out, dn_meta) = ( + out_transformer + ) + if self.training: + assert inputs is not None + assert "gt_bbox" in inputs and "gt_class" in inputs + assert "gt_segm" in inputs + + if dn_meta is not None: + dn_out_logits, dec_out_logits = paddle.split( + dec_out_logits, dn_meta["dn_num_split"], axis=2 + ) + dn_out_bboxes, dec_out_bboxes = paddle.split( + dec_out_bboxes, dn_meta["dn_num_split"], axis=2 + ) + dn_out_masks, dec_out_masks = paddle.split( + dec_out_masks, dn_meta["dn_num_split"], axis=2 + ) + if init_out is not None: + init_out_logits, init_out_bboxes, init_out_masks = init_out + init_out_logits_dn, init_out_logits = paddle.split( + init_out_logits, dn_meta["dn_num_split"], axis=1 + ) + init_out_bboxes_dn, init_out_bboxes = paddle.split( + init_out_bboxes, dn_meta["dn_num_split"], axis=1 + ) + init_out_masks_dn, init_out_masks = paddle.split( + init_out_masks, dn_meta["dn_num_split"], axis=1 + ) + + dec_out_logits = paddle.concat( + [init_out_logits.unsqueeze(0), dec_out_logits] + ) + dec_out_bboxes = paddle.concat( + [init_out_bboxes.unsqueeze(0), dec_out_bboxes] + ) + dec_out_masks = paddle.concat( + [init_out_masks.unsqueeze(0), dec_out_masks] + ) + + dn_out_logits = paddle.concat( + [init_out_logits_dn.unsqueeze(0), dn_out_logits] + ) + dn_out_bboxes = paddle.concat( + [init_out_bboxes_dn.unsqueeze(0), dn_out_bboxes] + ) + dn_out_masks = paddle.concat( + [init_out_masks_dn.unsqueeze(0), dn_out_masks] + ) + else: + dn_out_bboxes, dn_out_logits = None, None + dn_out_masks = None + + enc_out_logits, enc_out_bboxes, enc_out_masks = enc_out + out_logits = paddle.concat([enc_out_logits.unsqueeze(0), dec_out_logits]) + out_bboxes = paddle.concat([enc_out_bboxes.unsqueeze(0), dec_out_bboxes]) + out_masks = paddle.concat([enc_out_masks.unsqueeze(0), dec_out_masks]) + + inputs["gt_segm"] = [ + gt_segm.astype(out_masks.dtype) for gt_segm in inputs["gt_segm"] + ] + + return self.loss( + out_bboxes, + out_logits, + inputs["gt_bbox"], + inputs["gt_class"], + masks=out_masks, + gt_mask=inputs["gt_segm"], + dn_out_logits=dn_out_logits, + dn_out_bboxes=dn_out_bboxes, + dn_out_masks=dn_out_masks, + dn_meta=dn_meta, + ) + else: + return (dec_out_bboxes[-1], dec_out_logits[-1], dec_out_masks[-1]) + + +class RTDETRv3Head(nn.Layer): + __inject__ = ["loss"] + __shared__ = ["o2m_branch", "num_queries_o2m"] + + def __init__( + self, loss="DINOLoss", eval_idx=-1, o2m=4, o2m_branch=False, num_queries_o2m=450 + ): + super(RTDETRv3Head, self).__init__() + self.loss = loss + self.eval_idx = eval_idx + self.o2m = o2m + self.o2m_branch = o2m_branch + self.num_queries_o2m = num_queries_o2m + + def forward(self, out_transformer, body_feats, inputs=None): + (dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta) = ( + out_transformer + ) + if self.training: + assert inputs is not None + assert "gt_bbox" in inputs and "gt_class" in inputs + + if dn_meta is not None: + num_groups = len(dn_meta) + total_dec_queries = dec_out_bboxes.shape[2] + total_enc_queries = enc_topk_bboxes.shape[1] + loss = {} + if self.o2m_branch: + dec_out_bboxes, dec_out_bboxes_o2m = paddle.split( + dec_out_bboxes, + [ + total_dec_queries - self.num_queries_o2m, + self.num_queries_o2m, + ], + axis=2, + ) + dec_out_logits, dec_out_logits_o2m = paddle.split( + dec_out_logits, + [ + total_dec_queries - self.num_queries_o2m, + self.num_queries_o2m, + ], + axis=2, + ) + enc_topk_bboxes, enc_topk_bboxes_o2m = paddle.split( + enc_topk_bboxes, + [ + total_enc_queries - 
self.num_queries_o2m, + self.num_queries_o2m, + ], + axis=1, + ) + enc_topk_logits, enc_topk_logits_o2m = paddle.split( + enc_topk_logits, + [ + total_enc_queries - self.num_queries_o2m, + self.num_queries_o2m, + ], + axis=1, + ) + + out_bboxes_o2m = paddle.concat( + [enc_topk_bboxes_o2m.unsqueeze(0), dec_out_bboxes_o2m] + ) + out_logits_o2m = paddle.concat( + [enc_topk_logits_o2m.unsqueeze(0), dec_out_logits_o2m] + ) + loss_o2m = self.loss( + out_bboxes_o2m, + out_logits_o2m, + inputs["gt_bbox"], + inputs["gt_class"], + dn_out_bboxes=None, + dn_out_logits=None, + dn_meta=None, + o2m=self.o2m, + ) + for key, value in loss_o2m.items(): + key = key + "_o2m_branch" + loss.update({key: loss.get(key, paddle.zeros([1])) + value}) + + split_dec_num = [sum(dn["dn_num_split"]) for dn in dn_meta] + split_enc_num = [dn["dn_num_split"][1] for dn in dn_meta] + dec_out_bboxes = paddle.split(dec_out_bboxes, split_dec_num, axis=2) + dec_out_logits = paddle.split(dec_out_logits, split_dec_num, axis=2) + enc_topk_bboxes = paddle.split(enc_topk_bboxes, split_enc_num, axis=1) + enc_topk_logits = paddle.split(enc_topk_logits, split_enc_num, axis=1) + + for g_id in range(num_groups): + dn_out_bboxes_gid, dec_out_bboxes_gid = paddle.split( + dec_out_bboxes[g_id], dn_meta[g_id]["dn_num_split"], axis=2 + ) + dn_out_logits_gid, dec_out_logits_gid = paddle.split( + dec_out_logits[g_id], dn_meta[g_id]["dn_num_split"], axis=2 + ) + out_bboxes_gid = paddle.concat( + [enc_topk_bboxes[g_id].unsqueeze(0), dec_out_bboxes_gid] + ) + out_logits_gid = paddle.concat( + [enc_topk_logits[g_id].unsqueeze(0), dec_out_logits_gid] + ) + + loss_gid = self.loss( + out_bboxes_gid, + out_logits_gid, + inputs["gt_bbox"], + inputs["gt_class"], + dn_out_bboxes=dn_out_bboxes_gid, + dn_out_logits=dn_out_logits_gid, + dn_meta=dn_meta[g_id], + ) + # sum loss + for key, value in loss_gid.items(): + loss.update({key: loss.get(key, paddle.zeros([1])) + value}) + + # average across (dual_groups + 1) + for key, value in loss.items(): + if "_o2m_branch" not in key: + loss.update({key: value / num_groups}) + return loss + else: + dn_out_bboxes, dn_out_logits = None, None + + out_bboxes = paddle.concat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes]) + out_logits = paddle.concat([enc_topk_logits.unsqueeze(0), dec_out_logits]) + + return self.loss( + out_bboxes, + out_logits, + inputs["gt_bbox"], + inputs["gt_class"], + dn_out_bboxes=dn_out_bboxes, + dn_out_logits=dn_out_logits, + dn_meta=dn_meta, + gt_score=inputs.get("gt_score", None), + ) + else: + return (dec_out_bboxes[self.eval_idx], dec_out_logits[self.eval_idx], None) diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_transformer.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_transformer.py new file mode 100644 index 0000000000..7489219b2f --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/detr_transformer.py @@ -0,0 +1,394 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.distributed.fleet.utils import recompute  # used by TransformerEncoder.forward
+
+from .modules.detr_ops import _get_clones
+from .modules.initializer import conv_init_, linear_init_, normal_, xavier_uniform_
+from .modules.layers import MultiHeadAttention, _convert_attention_mask
+from .modules.position_encoding import PositionEmbedding
+
+__all__ = ["TransformerEncoderLayer"]
+
+
+class TransformerEncoderLayer(nn.Layer):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        attn_dropout=None,
+        act_dropout=None,
+        normalize_before=False,
+    ):
+        super(TransformerEncoderLayer, self).__init__()
+        attn_dropout = dropout if attn_dropout is None else attn_dropout
+        act_dropout = dropout if act_dropout is None else act_dropout
+        self.normalize_before = normalize_before
+
+        self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
+        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
+        self.activation = getattr(F, activation)
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        linear_init_(self.linear1)
+        linear_init_(self.linear2)
+
+    @staticmethod
+    def with_pos_embed(tensor, pos_embed):
+        return tensor if pos_embed is None else tensor + pos_embed
+
+    def forward(self, src, src_mask=None, pos_embed=None):
+        residual = src
+        if self.normalize_before:
+            src = self.norm1(src)
+        q = k = self.with_pos_embed(src, pos_embed)
+        src = self.self_attn(q, k, value=src, attn_mask=src_mask)
+
+        src = residual + self.dropout1(src)
+        if not self.normalize_before:
+            src = self.norm1(src)
+
+        residual = src
+        if self.normalize_before:
+            src = self.norm2(src)
+        src = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = residual + self.dropout2(src)
+        if not self.normalize_before:
+            src = self.norm2(src)
+        return src
+
+
+class TransformerEncoder(nn.Layer):
+    def __init__(self, encoder_layer, num_layers, norm=None, with_rp=-1):
+        super(TransformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        assert with_rp <= num_layers
+        self.with_rp = with_rp
+
+    def forward(self, src, src_mask=None, pos_embed=None):
+        output = src
+        for i, layer in enumerate(self.layers):
+            if self.training and i < self.with_rp:
+                output = recompute(
+                    layer,
+                    output,
+                    src_mask=src_mask,
+                    pos_embed=pos_embed,
+                    **{"preserve_rng_state": True, "use_reentrant": False},
+                )
+            else:
+                output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoderLayer(nn.Layer):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        attn_dropout=None,
+        act_dropout=None,
+        normalize_before=False,
+    ):
+        super(TransformerDecoderLayer, self).__init__()
+        attn_dropout = dropout if 
attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + self.cross_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train") + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout3 = nn.Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + pos_embed=None, + query_pos_embed=None, + ): + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + q = k = self.with_pos_embed(tgt, query_pos_embed) + tgt = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + q = self.with_pos_embed(tgt, query_pos_embed) + k = self.with_pos_embed(memory, pos_embed) + tgt = self.cross_attn(q, k, value=memory, attn_mask=memory_mask) + tgt = residual + self.dropout2(tgt) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +class TransformerDecoder(nn.Layer): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + pos_embed=None, + query_pos_embed=None, + ): + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + + output = tgt + intermediate = [] + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + pos_embed=pos_embed, + query_pos_embed=query_pos_embed, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return paddle.stack(intermediate) + + return output.unsqueeze(0) + + +class DETRTransformer(nn.Layer): + __shared__ = ["hidden_dim"] + + def __init__( + self, + num_queries=100, + position_embed_type="sine", + return_intermediate_dec=True, + backbone_num_channels=2048, + hidden_dim=256, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + pe_temperature=10000, + pe_offset=0.0, + attn_dropout=None, + act_dropout=None, + normalize_before=False, + ): + super(DETRTransformer, 
self).__init__()
+        assert position_embed_type in [
+            "sine",
+            "learned",
+        ], f"ValueError: position_embed_type not supported {position_embed_type}!"
+        self.hidden_dim = hidden_dim
+        self.nhead = nhead
+
+        encoder_layer = TransformerEncoderLayer(
+            hidden_dim,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            attn_dropout,
+            act_dropout,
+            normalize_before,
+        )
+        encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        decoder_layer = TransformerDecoderLayer(
+            hidden_dim,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            attn_dropout,
+            act_dropout,
+            normalize_before,
+        )
+        decoder_norm = nn.LayerNorm(hidden_dim)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self.input_proj = nn.Conv2D(backbone_num_channels, hidden_dim, kernel_size=1)
+        self.query_pos_embed = nn.Embedding(num_queries, hidden_dim)
+        self.position_embedding = PositionEmbedding(
+            hidden_dim // 2,
+            temperature=pe_temperature,
+            normalize=True if position_embed_type == "sine" else False,
+            embed_type=position_embed_type,
+            offset=pe_offset,
+        )
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                xavier_uniform_(p)
+        conv_init_(self.input_proj)
+        normal_(self.query_pos_embed.weight)
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {
+            "backbone_num_channels": [i.channels for i in input_shape][-1],
+        }
+
+    def _convert_attention_mask(self, mask):
+        return (mask - 1.0) * 1e9
+
+    def forward(self, src, src_mask=None, *args, **kwargs):
+        r"""
+        Applies a Transformer model on the inputs.
+
+        Parameters:
+            src (List(Tensor)): Backbone feature maps with shape [[bs, c, h, w]].
+            src_mask (Tensor, optional): A tensor used in multi-head attention
+                to prevent attention to some unwanted positions, usually the
+                paddings or the subsequent positions. It is a tensor with shape
+                [bs, H, W]. When the data type is bool, the unwanted positions
+                have `False` values and the others have `True` values. When the
+                data type is int, the unwanted positions have 0 values and the
+                others have 1 values. When the data type is float, the unwanted
+                positions have `-INF` values and the others have 0 values. It
+                can be None when no position needs to be masked. Default None.
+ + Returns: + output (Tensor): [num_levels, batch_size, num_queries, hidden_dim] + memory (Tensor): [batch_size, hidden_dim, h, w] + """ + # use last level feature map + src_proj = self.input_proj(src[-1]) + bs, c, h, w = src_proj.shape + # flatten [B, C, H, W] to [B, HxW, C] + src_flatten = src_proj.flatten(2).transpose([0, 2, 1]) + if src_mask is not None: + src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0] + else: + src_mask = paddle.ones([bs, h, w]) + pos_embed = self.position_embedding(src_mask).flatten(1, 2) + + if self.training: + src_mask = self._convert_attention_mask(src_mask) + src_mask = src_mask.reshape([bs, 1, 1, h * w]) + else: + src_mask = None + + memory = self.encoder(src_flatten, src_mask=src_mask, pos_embed=pos_embed) + + query_pos_embed = self.query_pos_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + tgt = paddle.zeros_like(query_pos_embed) + output = self.decoder( + tgt, + memory, + memory_mask=src_mask, + pos_embed=pos_embed, + query_pos_embed=query_pos_embed, + ) + + if self.training: + src_mask = src_mask.reshape([bs, 1, 1, h, w]) + else: + src_mask = None + + return ( + output, + memory.transpose([0, 2, 1]).reshape([bs, c, h, w]), + src_proj, + src_mask, + ) diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hgnet_v2.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hgnet_v2.py new file mode 100644 index 0000000000..3c60cd6831 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hgnet_v2.py @@ -0,0 +1,513 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
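+#
+# Usage sketch (illustrative, not part of the module): PPHGNetV2 returns the
+# feature maps selected by `return_idx`. With arch="L", return_idx=[1, 2, 3]
+# and a 640x640 input, the outputs are at strides [8, 16, 32]:
+#
+#   backbone = PPHGNetV2(arch="L", return_idx=[1, 2, 3])
+#   feats = backbone(paddle.randn([1, 3, 640, 640]))
+#   # feats[0]: [1, 512, 80, 80]
+#   # feats[1]: [1, 1024, 40, 40]
+#   # feats[2]: [1, 2048, 20, 20]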
+ +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn import BatchNorm2D, Conv2D, ReLU +from paddle.nn.initializer import Constant, KaimingNormal +from paddle.regularizer import L2Decay + +from .modules.detr_ops import ShapeSpec + +__all__ = ["PPHGNetV2"] + +kaiming_normal_ = KaimingNormal() +zeros_ = Constant(value=0.0) +ones_ = Constant(value=1.0) + + +class LearnableAffineBlock(nn.Layer): + def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01): + super().__init__() + self.scale = self.create_parameter( + shape=[ + 1, + ], + default_initializer=Constant(value=scale_value), + attr=ParamAttr(learning_rate=lr_mult * lab_lr), + ) + self.add_parameter("scale", self.scale) + self.bias = self.create_parameter( + shape=[ + 1, + ], + default_initializer=Constant(value=bias_value), + attr=ParamAttr(learning_rate=lr_mult * lab_lr), + ) + self.add_parameter("bias", self.bias) + + def forward(self, x): + return self.scale * x + self.bias + + +class ConvBNAct(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + groups=1, + use_act=True, + use_lab=False, + lr_mult=1.0, + ): + super().__init__() + self.use_act = use_act + self.use_lab = use_lab + self.conv = Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=False, + ) + self.bn = BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult), + bias_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult), + ) + if self.use_act: + self.act = ReLU() + if self.use_lab: + self.lab = LearnableAffineBlock(lr_mult=lr_mult) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.use_act: + x = self.act(x) + if self.use_lab: + x = self.lab(x) + return x + + +class LightConvBNAct(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + use_lab=False, + lr_mult=1.0, + ): + super().__init__() + self.conv1 = ConvBNAct( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_act=False, + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.conv2 = ConvBNAct( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=out_channels, + use_act=True, + use_lab=use_lab, + lr_mult=lr_mult, + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class StemBlock(nn.Layer): + def __init__( + self, in_channels, mid_channels, out_channels, use_lab=False, lr_mult=1.0 + ): + super().__init__() + self.stem1 = ConvBNAct( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=3, + stride=2, + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.stem2a = ConvBNAct( + in_channels=mid_channels, + out_channels=mid_channels // 2, + kernel_size=2, + stride=1, + padding="SAME", + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.stem2b = ConvBNAct( + in_channels=mid_channels // 2, + out_channels=mid_channels, + kernel_size=2, + stride=1, + padding="SAME", + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.stem3 = ConvBNAct( + in_channels=mid_channels * 2, + out_channels=mid_channels, + kernel_size=3, + stride=2, + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.stem4 = ConvBNAct( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.pool = 
nn.MaxPool2D( + kernel_size=2, stride=1, ceil_mode=True, padding="SAME" + ) + + def forward(self, x): + x = self.stem1(x) + x2 = self.stem2a(x) + x2 = self.stem2b(x2) + x1 = self.pool(x) + x = paddle.concat([x1, x2], 1) + x = self.stem3(x) + x = self.stem4(x) + + return x + + +class HG_Block(nn.Layer): + def __init__( + self, + in_channels, + mid_channels, + out_channels, + kernel_size=3, + layer_num=6, + identity=False, + light_block=True, + use_lab=False, + lr_mult=1.0, + ): + super().__init__() + self.identity = identity + + self.layers = nn.LayerList() + block_type = "LightConvBNAct" if light_block else "ConvBNAct" + for i in range(layer_num): + self.layers.append( + eval(block_type)( + in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + stride=1, + kernel_size=kernel_size, + use_lab=use_lab, + lr_mult=lr_mult, + ) + ) + # feature aggregation + total_channels = in_channels + layer_num * mid_channels + self.aggregation_squeeze_conv = ConvBNAct( + in_channels=total_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + use_lab=use_lab, + lr_mult=lr_mult, + ) + self.aggregation_excitation_conv = ConvBNAct( + in_channels=out_channels // 2, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_lab=use_lab, + lr_mult=lr_mult, + ) + + def forward(self, x): + identity = x + output = [] + output.append(x) + for layer in self.layers: + x = layer(x) + output.append(x) + x = paddle.concat(output, axis=1) + x = self.aggregation_squeeze_conv(x) + x = self.aggregation_excitation_conv(x) + if self.identity: + x += identity + return x + + +class HG_Stage(nn.Layer): + def __init__( + self, + in_channels, + mid_channels, + out_channels, + block_num, + layer_num=6, + downsample=True, + light_block=True, + kernel_size=3, + use_lab=False, + lr_mult=1.0, + ): + super().__init__() + self.downsample = downsample + if downsample: + self.downsample = ConvBNAct( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=2, + groups=in_channels, + use_act=False, + use_lab=use_lab, + lr_mult=lr_mult, + ) + + blocks_list = [] + for i in range(block_num): + blocks_list.append( + HG_Block( + in_channels=in_channels if i == 0 else out_channels, + mid_channels=mid_channels, + out_channels=out_channels, + kernel_size=kernel_size, + layer_num=layer_num, + identity=False if i == 0 else True, + light_block=light_block, + use_lab=use_lab, + lr_mult=lr_mult, + ) + ) + self.blocks = nn.Sequential(*blocks_list) + + def forward(self, x): + if self.downsample: + x = self.downsample(x) + x = self.blocks(x) + return x + + +def _freeze_norm(m: nn.BatchNorm2D): + param_attr = ParamAttr(learning_rate=0.0, regularizer=L2Decay(0.0), trainable=False) + bias_attr = ParamAttr(learning_rate=0.0, regularizer=L2Decay(0.0), trainable=False) + global_stats = True + norm = nn.BatchNorm2D( + m._num_features, + weight_attr=param_attr, + bias_attr=bias_attr, + use_global_stats=global_stats, + ) + for param in norm.parameters(): + param.stop_gradient = True + return norm + + +def reset_bn(model: nn.Layer, reset_func=_freeze_norm): + if isinstance(model, nn.BatchNorm2D): + model = reset_func(model) + else: + for name, child in model.named_children(): + _child = reset_bn(child, reset_func) + if _child is not child: + setattr(model, name, _child) + return model + + +class PPHGNetV2(nn.Layer): + """ + PPHGNetV2 + Args: + stem_channels: list. Number of channels for the stem block. + stage_type: str. The stage configuration of PPHGNet. 
such as the number of channels, stride, etc. + use_lab: boolean. Whether to use LearnableAffineBlock in network. + lr_mult_list: list. Control the learning rate of different stages. + Returns: + model: nn.Layer. Specific PPHGNetV2 model depends on args. + """ + + arch_configs = { + "N": { + "stem_channels": [3, 16, 16], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [16, 16, 64, 1, False, False, 3, 3], + "stage2": [64, 32, 256, 1, True, False, 3, 3], + "stage3": [256, 64, 512, 2, True, True, 5, 3], + "stage4": [512, 128, 1024, 1, True, True, 5, 3], + }, + }, + "S": { + "stem_channels": [3, 24, 32], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 64, 1, False, False, 3, 3], + "stage2": [64, 48, 256, 1, True, False, 3, 3], + "stage3": [256, 96, 512, 2, True, True, 5, 3], + "stage4": [512, 192, 1024, 1, True, True, 5, 3], + }, + }, + "M": { + "stem_channels": [3, 24, 32], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 96, 1, False, False, 3, 4], + "stage2": [96, 64, 384, 1, True, False, 3, 4], + "stage3": [384, 128, 768, 3, True, True, 5, 4], + "stage4": [768, 256, 1536, 1, True, True, 5, 4], + }, + }, + "L": { + "stem_channels": [3, 32, 48], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [48, 48, 128, 1, False, False, 3, 6], + "stage2": [128, 96, 512, 1, True, False, 3, 6], + "stage3": [512, 192, 1024, 3, True, True, 5, 6], + "stage4": [1024, 384, 2048, 1, True, True, 5, 6], + }, + }, + "X": { + "stem_channels": [3, 32, 64], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [64, 64, 128, 1, False, False, 3, 6], + "stage2": [128, 128, 512, 2, True, False, 3, 6], + "stage3": [512, 256, 1024, 5, True, True, 5, 6], + "stage4": [1024, 512, 2048, 2, True, True, 5, 6], + }, + }, + "H": { + "stem_channels": [3, 48, 96], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [96, 96, 192, 2, False, False, 3, 6], + "stage2": [192, 192, 512, 3, True, False, 3, 6], + "stage3": [512, 384, 1024, 6, True, True, 5, 6], + "stage4": [1024, 768, 2048, 3, True, True, 5, 6], + }, + }, + } + + def __init__( + self, + arch, + use_lab=False, + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], + return_idx=[1, 2, 3], + freeze_stem_only=True, + freeze_at=0, + freeze_norm=True, + ): + super().__init__() + self.use_lab = use_lab + self.return_idx = return_idx + + stem_channels = self.arch_configs[arch]["stem_channels"] + stage_config = self.arch_configs[arch]["stage_config"] + + self._out_strides = [4, 8, 16, 32] + self._out_channels = [stage_config[k][2] for k in stage_config] + + # stem + self.stem = StemBlock( + in_channels=stem_channels[0], + mid_channels=stem_channels[1], + out_channels=stem_channels[2], + use_lab=use_lab, + lr_mult=lr_mult_list[0], + ) + + # stages + self.stages = nn.LayerList() + for i, k in enumerate(stage_config): + ( + in_channels, + mid_channels, + out_channels, + block_num, + downsample, + light_block, + kernel_size, + layer_num, + ) = stage_config[k] + self.stages.append( + HG_Stage( + in_channels, + mid_channels, + out_channels, + block_num, + layer_num, + 
downsample, + light_block, + kernel_size, + use_lab, + lr_mult=lr_mult_list[i + 1], + ) + ) + + if freeze_at >= 0: + self._freeze_parameters(self.stem) + if not freeze_stem_only: + for i in range(min(freeze_at + 1, len(self.stages))): + self._freeze_parameters(self.stages[i]) + + if freeze_norm: + reset_bn(self, reset_func=_freeze_norm) + + self._init_weights() + + def _freeze_parameters(self, m): + for p in m.parameters(): + p.stop_gradient = True + + def _init_weights(self): + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2D)): + ones_(m.weight) + zeros_(m.bias) + elif isinstance(m, nn.Linear): + zeros_(m.bias) + + @property + def out_shape(self): + return [ + ShapeSpec(channels=self._out_channels[i], stride=self._out_strides[i]) + for i in self.return_idx + ] + + def forward(self, inputs): + x = inputs + x = self.stem(x) + outs = [] + for idx, stage in enumerate(self.stages): + x = stage(x) + if idx in self.return_idx: + outs.append(x) + return outs diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hybrid_encoder.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hybrid_encoder.py new file mode 100644 index 0000000000..257f583b3d --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/hybrid_encoder.py @@ -0,0 +1,314 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
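+#
+# Usage sketch (illustrative, not part of the module): HybridEncoder projects
+# each input map to hidden_dim channels, runs self-attention only on the maps
+# in use_encoder_idx (the stride-32 map by default), then fuses them top-down
+# (FPN) and bottom-up (PAN). Spatial sizes are preserved:
+#
+#   encoder = HybridEncoder(
+#       encoder_layer=TransformerLayer(d_model=256, nhead=8)
+#   )
+#   feats = [paddle.randn([1, c, s, s]) for c, s in
+#            [(512, 80), (1024, 40), (2048, 20)]]
+#   outs = encoder(feats)  # [1, 256, 80, 80], [1, 256, 40, 40], [1, 256, 20, 20]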
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from .detr_transformer import TransformerEncoder +from .modules.csp_darknet import BaseConv +from .modules.cspresnet import RepVggBlock +from .modules.detr_ops import ShapeSpec +from .modules.initializer import linear_init_ +from .modules.layers import MultiHeadAttention +from .modules.ops import get_act_fn + +__all__ = ["HybridEncoder"] + + +class CSPRepLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + num_blocks=3, + expansion=1.0, + bias=False, + act="silu", + ): + super(CSPRepLayer, self).__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.conv2 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.bottlenecks = nn.Sequential( + *[ + RepVggBlock(hidden_channels, hidden_channels, act=act) + for _ in range(num_blocks) + ] + ) + if hidden_channels != out_channels: + self.conv3 = BaseConv( + hidden_channels, out_channels, ksize=1, stride=1, bias=bias, act=act + ) + else: + self.conv3 = nn.Identity() + + def forward(self, x): + x_1 = self.conv1(x) + x_1 = self.bottlenecks(x_1) + x_2 = self.conv2(x) + return self.conv3(x_1 + x_2) + + +class TransformerLayer(nn.Layer): + def __init__( + self, + d_model, + nhead, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False, + ): + super(TransformerLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train") + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, src, src_mask=None, pos_embed=None): + residual = src + if self.normalize_before: + src = self.norm1(src) + q = k = self.with_pos_embed(src, pos_embed) + src = self.self_attn(q, k, value=src, attn_mask=src_mask) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class HybridEncoder(nn.Layer): + __shared__ = ["depth_mult", "act", "trt", "eval_size"] + __inject__ = ["encoder_layer"] + + def __init__( + self, + in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + hidden_dim=256, + use_encoder_idx=[2], + num_encoder_layers=1, + encoder_layer="TransformerLayer", + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + act="silu", + trt=False, + 
eval_size=None, + with_rp=-1, + ): + super(HybridEncoder, self).__init__() + self.in_channels = in_channels + self.feat_strides = feat_strides + self.hidden_dim = hidden_dim + self.use_encoder_idx = use_encoder_idx + self.num_encoder_layers = num_encoder_layers + self.pe_temperature = pe_temperature + self.eval_size = eval_size + + # channel projection + self.input_proj = nn.LayerList() + for in_channel in in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2D(in_channel, hidden_dim, kernel_size=1, bias_attr=False), + nn.BatchNorm2D( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ), + ) + ) + # encoder transformer + self.encoder = nn.LayerList( + [ + TransformerEncoder(encoder_layer, num_encoder_layers, with_rp=with_rp) + for _ in range(len(use_encoder_idx)) + ] + ) + + act = ( + get_act_fn(act, trt=trt) + if act is None or isinstance(act, (str, dict)) + else act + ) + # top-down fpn + self.lateral_convs = nn.LayerList() + self.fpn_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1, 0, -1): + self.lateral_convs.append(BaseConv(hidden_dim, hidden_dim, 1, 1, act=act)) + self.fpn_blocks.append( + CSPRepLayer( + hidden_dim * 2, + hidden_dim, + round(3 * depth_mult), + act=act, + expansion=expansion, + ) + ) + + # bottom-up pan + self.downsample_convs = nn.LayerList() + self.pan_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1): + self.downsample_convs.append( + BaseConv(hidden_dim, hidden_dim, 3, stride=2, act=act) + ) + self.pan_blocks.append( + CSPRepLayer( + hidden_dim * 2, + hidden_dim, + round(3 * depth_mult), + act=act, + expansion=expansion, + ) + ) + + self._reset_parameters() + + def _reset_parameters(self): + if self.eval_size: + for idx in self.use_encoder_idx: + stride = self.feat_strides[idx] + pos_embed = self.build_2d_sincos_position_embedding( + self.eval_size[1] // stride, + self.eval_size[0] // stride, + self.hidden_dim, + self.pe_temperature, + ) + setattr(self, f"pos_embed{idx}", pos_embed) + + @staticmethod + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + grid_w = paddle.arange(int(w), dtype=paddle.float32) + grid_h = paddle.arange(int(h), dtype=paddle.float32) + grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) + assert ( + embed_dim % 4 == 0 + ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + pos_dim = embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return paddle.concat( + [ + paddle.sin(out_w), + paddle.cos(out_w), + paddle.sin(out_h), + paddle.cos(out_h), + ], + axis=1, + )[None, :, :] + + def forward(self, feats, for_mot=False, is_teacher=False): + assert len(feats) == len(self.in_channels) + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + # encoder + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.use_encoder_idx): + h, w = proj_feats[enc_ind].shape[2:] + # flatten [B, C, H, W] to [B, HxW, C] + src_flatten = proj_feats[enc_ind].flatten(2).transpose([0, 2, 1]) + if self.training or self.eval_size is None or is_teacher: + pos_embed = self.build_2d_sincos_position_embedding( + w, h, self.hidden_dim, self.pe_temperature + ) + else: + pos_embed = getattr(self, f"pos_embed{enc_ind}", None) + memory = self.encoder[i](src_flatten, pos_embed=pos_embed) + 
proj_feats[enc_ind] = memory.transpose([0, 2, 1]).reshape(
+                    [-1, self.hidden_dim, h, w]
+                )
+
+        # top-down fpn
+        inner_outs = [proj_feats[-1]]
+        for idx in range(len(self.in_channels) - 1, 0, -1):
+            feat_high = inner_outs[0]
+            feat_low = proj_feats[idx - 1]
+            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)
+            inner_outs[0] = feat_high
+            upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest")
+            inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
+                paddle.concat([upsample_feat, feat_low], axis=1)
+            )
+            inner_outs.insert(0, inner_out)
+
+        # bottom-up pan
+        outs = [inner_outs[0]]
+        for idx in range(len(self.in_channels) - 1):
+            feat_low = outs[-1]
+            feat_high = inner_outs[idx + 1]
+            downsample_feat = self.downsample_convs[idx](feat_low)
+            out = self.pan_blocks[idx](
+                paddle.concat([downsample_feat, feat_high], axis=1)
+            )
+            outs.append(out)
+
+        return outs
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {
+            "in_channels": [i.channels for i in input_shape],
+            "feat_strides": [i.stride for i in input_shape],
+        }
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(channels=self.hidden_dim, stride=self.feat_strides[idx])
+            for idx in range(len(self.in_channels))
+        ]
diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/csp_darknet.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/csp_darknet.py
new file mode 100644
index 0000000000..8a8531aeb4
--- /dev/null
+++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/csp_darknet.py
@@ -0,0 +1,396 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
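
Before the next file, a standalone sketch of the 2D sin-cos positional embedding that HybridEncoder builds per encoder level: embed_dim is split into four sin/cos bands over the x and y grids, which is why the code asserts embed_dim % 4 == 0. The math below mirrors build_2d_sincos_position_embedding above, reproduced in isolation so the shape contract is easy to verify:

import paddle

# Same math as HybridEncoder.build_2d_sincos_position_embedding above,
# shown standalone so the output shape can be checked directly.
def sincos_pos_embed_2d(w, h, embed_dim=256, temperature=10000.0):
    grid_w = paddle.arange(int(w), dtype=paddle.float32)
    grid_h = paddle.arange(int(h), dtype=paddle.float32)
    grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
    pos_dim = embed_dim // 4  # embed_dim must be divisible by 4
    omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
    omega = 1.0 / (temperature**omega)
    out_w = grid_w.flatten()[..., None] @ omega[None]  # [w*h, pos_dim]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    # sin/cos over x, then sin/cos over y -> 4 * pos_dim == embed_dim channels
    return paddle.concat(
        [paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), paddle.cos(out_h)],
        axis=1,
    )[None, :, :]

pe = sincos_pos_embed_2d(20, 20, embed_dim=256)
print(pe.shape)  # [1, 400, 256]: one embedding per flattened H*W token

Since the result depends only on (w, h, embed_dim, temperature), _reset_parameters can precompute one pos_embed{idx} per level when eval_size is fixed, which is exactly what the eval-time branch in forward relies on.
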
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from .detr_ops import ShapeSpec +from .initializer import conv_init_ + +__all__ = ["CSPDarkNet", "BaseConv", "DWConv", "BottleNeck", "SPPLayer", "SPPFLayer"] + + +class BaseConv(nn.Layer): + def __init__( + self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu" + ): + super(BaseConv, self).__init__() + self.conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=(ksize - 1) // 2, + groups=groups, + bias_attr=bias, + ) + self.bn = nn.BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ) + + self._init_weights() + + def _init_weights(self): + conv_init_(self.conv) + + def forward(self, x): + # use 'x * F.sigmoid(x)' replace 'silu' + x = self.bn(self.conv(x)) + y = x * F.sigmoid(x) + return y + + +class DWConv(nn.Layer): + """Depthwise Conv""" + + def __init__( + self, in_channels, out_channels, ksize, stride=1, bias=False, act="silu" + ): + super(DWConv, self).__init__() + self.dw_conv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + bias=bias, + act=act, + ) + self.pw_conv = BaseConv( + in_channels, out_channels, ksize=1, stride=1, groups=1, bias=bias, act=act + ) + + def forward(self, x): + return self.pw_conv(self.dw_conv(x)) + + +class Focus(nn.Layer): + """Focus width and height information into channel space, used in YOLOX.""" + + def __init__( + self, in_channels, out_channels, ksize=3, stride=1, bias=False, act="silu" + ): + super(Focus, self).__init__() + self.conv = BaseConv( + in_channels * 4, + out_channels, + ksize=ksize, + stride=stride, + bias=bias, + act=act, + ) + + def forward(self, inputs): + # inputs [bs, C, H, W] -> outputs [bs, 4C, W/2, H/2] + top_left = inputs[:, :, 0::2, 0::2] + top_right = inputs[:, :, 0::2, 1::2] + bottom_left = inputs[:, :, 1::2, 0::2] + bottom_right = inputs[:, :, 1::2, 1::2] + outputs = paddle.concat([top_left, bottom_left, top_right, bottom_right], 1) + return self.conv(outputs) + + +class BottleNeck(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + bias=False, + act="silu", + ): + super(BottleNeck, self).__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.conv2 = Conv( + hidden_channels, out_channels, ksize=3, stride=1, bias=bias, act=act + ) + self.add_shortcut = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.add_shortcut: + y = y + x + return y + + +class SPPLayer(nn.Layer): + """Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX""" + + def __init__( + self, in_channels, out_channels, kernel_sizes=(5, 9, 13), bias=False, act="silu" + ): + super(SPPLayer, self).__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.maxpoolings = nn.LayerList( + [ + nn.MaxPool2D(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ] + ) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv( + conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act + ) + + def 
forward(self, x): + x = self.conv1(x) + x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1) + x = self.conv2(x) + return x + + +class SPPFLayer(nn.Layer): + """Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher, + equivalent to SPP(k=(5, 9, 13)) + """ + + def __init__(self, in_channels, out_channels, ksize=5, bias=False, act="silu"): + super(SPPFLayer, self).__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.maxpooling = nn.MaxPool2D(kernel_size=ksize, stride=1, padding=ksize // 2) + conv2_channels = hidden_channels * 4 + self.conv2 = BaseConv( + conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act + ) + + def forward(self, x): + x = self.conv1(x) + y1 = self.maxpooling(x) + y2 = self.maxpooling(y1) + y3 = self.maxpooling(y2) + concats = paddle.concat([x, y1, y2, y3], axis=1) + out = self.conv2(concats) + return out + + +class CSPLayer(nn.Layer): + """CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5""" + + def __init__( + self, + in_channels, + out_channels, + num_blocks=1, + shortcut=True, + expansion=0.5, + depthwise=False, + bias=False, + act="silu", + ): + super(CSPLayer, self).__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.conv2 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act + ) + self.bottlenecks = nn.Sequential( + *[ + BottleNeck( + hidden_channels, + hidden_channels, + shortcut=shortcut, + expansion=1.0, + depthwise=depthwise, + bias=bias, + act=act, + ) + for _ in range(num_blocks) + ] + ) + self.conv3 = BaseConv( + hidden_channels * 2, out_channels, ksize=1, stride=1, bias=bias, act=act + ) + + def forward(self, x): + x_1 = self.conv1(x) + x_1 = self.bottlenecks(x_1) + x_2 = self.conv2(x) + x = paddle.concat([x_1, x_2], axis=1) + x = self.conv3(x) + return x + + +class CSPDarkNet(nn.Layer): + """ + CSPDarkNet backbone. + Args: + arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X, + and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5. + depth_mult (float): Depth multiplier, multiply number of channels in + each layer, default as 1.0. + width_mult (float): Width multiplier, multiply number of blocks in + CSPLayer, default as 1.0. + depthwise (bool): Whether to use depth-wise conv layer. + act (str): Activation function type, default as 'silu'. + return_idx (list): Index of stages whose feature maps are returned. + """ + + __shared__ = ["depth_mult", "width_mult", "act", "trt"] + + # in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf) + # 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5. 
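+    # Channel widths scale with width_mult and block counts with depth_mult in
+    # __init__ below; the rows here are the base (width_mult=depth_mult=1.0) settings.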
+    arch_settings = {
+        "X": [
+            [64, 128, 3, True, False],
+            [128, 256, 9, True, False],
+            [256, 512, 9, True, False],
+            [512, 1024, 3, False, True],
+        ],
+        "P5": [
+            [64, 128, 3, True, False],
+            [128, 256, 6, True, False],
+            [256, 512, 9, True, False],
+            [512, 1024, 3, True, True],
+        ],
+        "P6": [
+            [64, 128, 3, True, False],
+            [128, 256, 6, True, False],
+            [256, 512, 9, True, False],
+            [512, 768, 3, True, False],
+            [768, 1024, 3, True, True],
+        ],
+    }
+
+    def __init__(
+        self,
+        arch="X",
+        depth_mult=1.0,
+        width_mult=1.0,
+        depthwise=False,
+        act="silu",
+        trt=False,
+        return_idx=[2, 3, 4],
+    ):
+        super(CSPDarkNet, self).__init__()
+        self.arch = arch
+        self.return_idx = return_idx
+        Conv = DWConv if depthwise else BaseConv
+        arch_setting = self.arch_settings[arch]
+        base_channels = int(arch_setting[0][0] * width_mult)
+
+        # Note: differences between the latest YOLOv5 and the original YOLOX
+        # 1. stem: Conv stem in YOLOv5 (P5/P6) vs. Focus stem in YOLOX (X)
+        # 2. SPPF in YOLOv5 vs. SPP in YOLOX
+        # 3. SPP is placed before (YOLOX) / SPPF after (YOLOv5) the last stage's CSPLayer
+        # 4. whether that stage's CSPLayer adds a shortcut: True in YOLOv5, False in YOLOX
+        if arch in ["P5", "P6"]:
+            # in the latest YOLOv5, use Conv stem, and SPPF (fast, only a single spp kernel size)
+            self.stem = Conv(3, base_channels, ksize=6, stride=2, bias=False, act=act)
+            spp_kernel_sizes = 5
+        elif arch in ["X"]:
+            # in the original YOLOX, use Focus stem, and SPP (three spp kernel sizes)
+            self.stem = Focus(3, base_channels, ksize=3, stride=1, bias=False, act=act)
+            spp_kernel_sizes = (5, 9, 13)
+        else:
+            raise AttributeError("Unsupported arch type: {}".format(arch))
+
+        _out_channels = [base_channels]
+        layers_num = 1
+        self.csp_dark_blocks = []
+
+        for i, (in_channels, out_channels, num_blocks, shortcut, use_spp) in enumerate(
+            arch_setting
+        ):
+            in_channels = int(in_channels * width_mult)
+            out_channels = int(out_channels * width_mult)
+            _out_channels.append(out_channels)
+            num_blocks = max(round(num_blocks * depth_mult), 1)
+            stage = []
+
+            conv_layer = self.add_sublayer(
+                "layers{}.stage{}.conv_layer".format(layers_num, i + 1),
+                Conv(in_channels, out_channels, 3, 2, bias=False, act=act),
+            )
+            stage.append(conv_layer)
+            layers_num += 1
+
+            if use_spp and arch in ["X"]:
+                # in YOLOX use SPPLayer
+                spp_layer = self.add_sublayer(
+                    "layers{}.stage{}.spp_layer".format(layers_num, i + 1),
+                    SPPLayer(
+                        out_channels,
+                        out_channels,
+                        kernel_sizes=spp_kernel_sizes,
+                        bias=False,
+                        act=act,
+                    ),
+                )
+                stage.append(spp_layer)
+                layers_num += 1
+
+            csp_layer = self.add_sublayer(
+                "layers{}.stage{}.csp_layer".format(layers_num, i + 1),
+                CSPLayer(
+                    out_channels,
+                    out_channels,
+                    num_blocks=num_blocks,
+                    shortcut=shortcut,
+                    depthwise=depthwise,
+                    bias=False,
+                    act=act,
+                ),
+            )
+            stage.append(csp_layer)
+            layers_num += 1
+
+            if use_spp and arch in ["P5", "P6"]:
+                # in the latest YOLOv5 use SPPFLayer instead of SPPLayer
+                sppf_layer = self.add_sublayer(
+                    "layers{}.stage{}.sppf_layer".format(layers_num, i + 1),
+                    SPPFLayer(out_channels, out_channels, ksize=5, bias=False, act=act),
+                )
+                stage.append(sppf_layer)
+                layers_num += 1
+
+            self.csp_dark_blocks.append(nn.Sequential(*stage))
+
+        self._out_channels = [_out_channels[i] for i in self.return_idx]
+        self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
+
+    def forward(self, inputs):
+        x = inputs["image"]
+        outputs = []
+        x = self.stem(x)
+        for i, layer in enumerate(self.csp_dark_blocks):
+            x = layer(x)
+            if i + 1 in self.return_idx:
+                outputs.append(x)
+
return outputs + + @property + def out_shape(self): + return [ + ShapeSpec(channels=c, stride=s) + for c, s in zip(self._out_channels, self.strides) + ] diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/cspresnet.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/cspresnet.py new file mode 100644 index 0000000000..7ac992aebc --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/cspresnet.py @@ -0,0 +1,338 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn.initializer import Constant +from paddle.regularizer import L2Decay + +from .detr_ops import ShapeSpec +from .ops import get_act_fn + +__all__ = ["CSPResNet", "BasicBlock", "EffectiveSELayer", "ConvBNLayer"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, ch_in, ch_out, filter_size=3, stride=1, groups=1, padding=0, act=None + ): + super(ConvBNLayer, self).__init__() + + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=False, + ) + + self.bn = nn.BatchNorm2D( + ch_out, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ) + self.act = ( + get_act_fn(act) if act is None or isinstance(act, (str, dict)) else act + ) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + return x + + +class RepVggBlock(nn.Layer): + def __init__(self, ch_in, ch_out, act="relu", alpha=False): + super(RepVggBlock, self).__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=None) + self.conv2 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=None) + self.act = ( + get_act_fn(act) if act is None or isinstance(act, (str, dict)) else act + ) + if alpha: + self.alpha = self.create_parameter( + shape=[1], + attr=ParamAttr(initializer=Constant(value=1.0)), + dtype="float32", + ) + else: + self.alpha = None + + def forward(self, x): + if hasattr(self, "conv"): + y = self.conv(x) + else: + if self.alpha is not None: + y = self.conv1(x) + self.alpha * self.conv2(x) + else: + y = self.conv1(x) + self.conv2(x) + y = self.act(y) + return y + + def convert_to_deploy(self): + if not hasattr(self, "conv"): + self.conv = nn.Conv2D( + in_channels=self.ch_in, + out_channels=self.ch_out, + kernel_size=3, + stride=1, + padding=1, + groups=1, + ) + kernel, bias = self.get_equivalent_kernel_bias() + self.conv.weight.set_value(kernel) + self.conv.bias.set_value(bias) + self.__delattr__("conv1") + self.__delattr__("conv2") + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + if self.alpha is not 
None: + return ( + kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(kernel1x1), + bias3x3 + self.alpha * bias1x1, + ) + else: + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.bn._mean + running_var = branch.bn._variance + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn._epsilon + std = (running_var + eps).sqrt() + t = (gamma / std).reshape((-1, 1, 1, 1)) + return kernel * t, beta - running_mean * gamma / std + + +class BasicBlock(nn.Layer): + def __init__(self, ch_in, ch_out, act="relu", shortcut=True, use_alpha=False): + super(BasicBlock, self).__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act) + self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv1(x) + y = self.conv2(y) + if self.shortcut: + return paddle.add(x, y) + else: + return y + + +class EffectiveSELayer(nn.Layer): + """Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + + def __init__(self, channels, act="hardsigmoid"): + super(EffectiveSELayer, self).__init__() + self.fc = nn.Conv2D(channels, channels, kernel_size=1, padding=0) + self.act = ( + get_act_fn(act) if act is None or isinstance(act, (str, dict)) else act + ) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.act(x_se) + + +class CSPResStage(nn.Layer): + def __init__( + self, + block_fn, + ch_in, + ch_out, + n, + stride, + act="relu", + attn="eca", + use_alpha=False, + ): + super(CSPResStage, self).__init__() + + ch_mid = (ch_in + ch_out) // 2 + if stride == 2: + self.conv_down = ConvBNLayer(ch_in, ch_mid, 3, stride=2, padding=1, act=act) + else: + self.conv_down = None + self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act) + self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act) + self.blocks = nn.Sequential( + *[ + block_fn( + ch_mid // 2, + ch_mid // 2, + act=act, + shortcut=True, + use_alpha=use_alpha, + ) + for i in range(n) + ] + ) + if attn: + self.attn = EffectiveSELayer(ch_mid, act="hardsigmoid") + else: + self.attn = None + + self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act) + + def forward(self, x): + if self.conv_down is not None: + x = self.conv_down(x) + y1 = self.conv1(x) + y2 = self.blocks(self.conv2(x)) + y = paddle.concat([y1, y2], axis=1) + if self.attn is not None: + y = self.attn(y) + y = self.conv3(y) + return y + + +class CSPResNet(nn.Layer): + __shared__ = ["width_mult", "depth_mult", "trt"] + + def __init__( + self, + layers=[3, 6, 6, 3], + channels=[64, 128, 256, 512, 1024], + act="swish", + return_idx=[1, 2, 3], + depth_wise=False, + use_large_stem=False, + width_mult=1.0, + depth_mult=1.0, + trt=False, + use_checkpoint=False, + use_alpha=False, + **args + ): + super(CSPResNet, self).__init__() + self.use_checkpoint = use_checkpoint + channels = [max(round(c * width_mult), 1) for c in channels] + layers = [max(round(l * depth_mult), 1) for l in layers] + act = ( + get_act_fn(act, trt=trt) + if act is None or isinstance(act, (str, dict)) + else act + ) + + if use_large_stem: + self.stem = nn.Sequential( + ( + "conv1", + ConvBNLayer(3, 
channels[0] // 2, 3, stride=2, padding=1, act=act), + ), + ( + "conv2", + ConvBNLayer( + channels[0] // 2, + channels[0] // 2, + 3, + stride=1, + padding=1, + act=act, + ), + ), + ( + "conv3", + ConvBNLayer( + channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act + ), + ), + ) + else: + self.stem = nn.Sequential( + ( + "conv1", + ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act), + ), + ( + "conv2", + ConvBNLayer( + channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act + ), + ), + ) + + n = len(channels) - 1 + self.stages = nn.Sequential( + *[ + ( + str(i), + CSPResStage( + BasicBlock, + channels[i], + channels[i + 1], + layers[i], + 2, + act=act, + use_alpha=use_alpha, + ), + ) + for i in range(n) + ] + ) + + self._out_channels = channels[1:] + self._out_strides = [4 * 2**i for i in range(n)] + self.return_idx = return_idx + if use_checkpoint: + paddle.seed(0) + + def forward(self, inputs): + x = inputs["image"] + x = self.stem(x) + outs = [] + for idx, stage in enumerate(self.stages): + if self.use_checkpoint and self.training: + x = paddle.distributed.fleet.utils.recompute( + stage, x, **{"preserve_rng_state": True} + ) + else: + x = stage(x) + if idx in self.return_idx: + outs.append(x) + + return outs + + @property + def out_shape(self): + return [ + ShapeSpec(channels=self._out_channels[i], stride=self._out_strides[i]) + for i in self.return_idx + ] diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/deformable_transformer.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/deformable_transformer.py new file mode 100644 index 0000000000..8d5045e153 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/deformable_transformer.py @@ -0,0 +1,758 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
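
The RepVggBlock added in cspresnet.py above is re-parameterizable: convert_to_deploy() folds conv1 (3x3 + BN) and conv2 (1x1 + BN, zero-padded to 3x3) into a single 3x3 conv. A small equivalence check, assuming the module is importable at the path this diff adds it under:

import paddle

# Import path assumes the location introduced by this diff.
from paddlex.inference.models.object_detection.modeling.rtdetrl_modules.modules.cspresnet import (
    RepVggBlock,
)

block = RepVggBlock(64, 64, act="relu")
block.eval()  # BN must use running statistics for the fusion to be exact
x = paddle.randn([2, 64, 32, 32])
y_multi_branch = block(x)   # act(conv1(x) + conv2(x))
block.convert_to_deploy()   # creates the fused self.conv, drops conv1/conv2
y_single_branch = block(x)  # act(conv(x)) via the fused 3x3 branch
print(float((y_multi_branch - y_single_branch).abs().max()))  # ~1e-6

The check only holds in eval mode, because _fuse_bn_tensor folds the BN running statistics into the conv weights; in train mode BN uses batch statistics and the two forms diverge.
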
+ +from __future__ import absolute_import, division, print_function + +import math + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.distributed.fleet.utils import recompute + +from .detr_ops import _get_clones, get_valid_ratio +from .initializer import constant_, linear_init_, normal_, xavier_uniform_ +from .layers import MultiHeadAttention +from .position_encoding import PositionEmbedding +from .utils import deformable_attention_core_func as ms_deformable_attn + +__all__ = ["DeformableTransformer"] + + +class MSDeformableAttention(nn.Layer): + def __init__( + self, embed_dim=256, num_heads=8, num_levels=4, num_points=4, lr_mult=0.1 + ): + """ + Multi-Scale Deformable Attention Module + """ + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.total_points = num_heads * num_levels * num_points + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear( + embed_dim, + self.total_points * 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult), + ) + + self.attention_weights = nn.Linear(embed_dim, self.total_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + self.ms_deformable_attn_core = ms_deformable_attn + + self._reset_parameters() + + def _reset_parameters(self): + # sampling_offsets + constant_(self.sampling_offsets.weight) + thetas = paddle.arange(self.num_heads, dtype=paddle.float32) * ( + 2.0 * math.pi / self.num_heads + ) + grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True) + grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile( + [1, self.num_levels, self.num_points, 1] + ) + scaling = paddle.arange(1, self.num_points + 1, dtype=paddle.float32).reshape( + [1, 1, -1, 1] + ) + grid_init *= scaling + self.sampling_offsets.bias.set_value(grid_init.flatten()) + # attention_weights + constant_(self.attention_weights.weight) + constant_(self.attention_weights.bias) + # proj + xavier_uniform_(self.value_proj.weight) + constant_(self.value_proj.bias) + xavier_uniform_(self.output_proj.weight) + constant_(self.output_proj.bias) + + def forward( + self, + query, + reference_points, + value, + value_spatial_shapes, + value_level_start_index, + value_mask=None, + ): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] 
+ value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + assert int(value_spatial_shapes.prod(1).sum()) == Len_v + + value = self.value_proj(value) + if value_mask is not None: + value_mask = value_mask.astype(value.dtype).unsqueeze(-1) + value *= value_mask + value = value.reshape([bs, Len_v, self.num_heads, self.head_dim]) + + sampling_offsets = self.sampling_offsets(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2] + ) + attention_weights = self.attention_weights(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels * self.num_points] + ) + attention_weights = F.softmax(attention_weights).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points] + ) + + if reference_points.shape[-1] == 2: + offset_normalizer = value_spatial_shapes.flip([1]).reshape( + [1, 1, 1, self.num_levels, 1, 2] + ) + sampling_locations = reference_points.reshape( + [bs, Len_q, 1, self.num_levels, 1, 2] + ) + sampling_offsets / offset_normalizer.astype(sampling_offsets.dtype) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.num_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) + + output = self.ms_deformable_attn_core( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + output = self.output_proj(output) + + return output + + +class DeformableTransformerEncoderLayer(nn.Layer): + def __init__( + self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_points=4, + lr_mult=0.1, + weight_attr=None, + bias_attr=None, + ): + super(DeformableTransformerEncoderLayer, self).__init__() + # self attention + self.self_attn = MSDeformableAttention( + d_model, n_head, n_levels, n_points, lr_mult + ) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model, weight_attr=weight_attr, bias_attr=bias_attr) + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = getattr(F, activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model, weight_attr=weight_attr, bias_attr=bias_attr) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward( + self, + src, + reference_points, + spatial_shapes, + level_start_index, + src_mask=None, + query_pos_embed=None, + ): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, query_pos_embed), + reference_points, + src, + spatial_shapes, + level_start_index, + src_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Layer): + __inject__ = 
["encoder_layer"] + + def __init__(self, encoder_layer, num_layers, with_rp=-1): + super(DeformableTransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + assert with_rp <= num_layers + self.with_rp = with_rp + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, offset=0.5): + valid_ratios = valid_ratios.unsqueeze(1) + reference_points = [] + for i, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = paddle.meshgrid( + paddle.arange(end=H) + offset, paddle.arange(end=W) + offset + ) + ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] * H) + ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] * W) + reference_points.append(paddle.stack((ref_x, ref_y), axis=-1)) + reference_points = paddle.concat(reference_points, 1).unsqueeze(2) + reference_points = reference_points * valid_ratios + return reference_points + + def forward( + self, + feat, + spatial_shapes, + level_start_index, + feat_mask=None, + query_pos_embed=None, + valid_ratios=None, + ): + if valid_ratios is None: + valid_ratios = paddle.ones([feat.shape[0], spatial_shapes.shape[0], 2]) + reference_points = self.get_reference_points(spatial_shapes, valid_ratios) + for i, layer in enumerate(self.layers): + if self.training and i < self.with_rp: + feat = recompute( + layer, + feat, + reference_points, + spatial_shapes, + level_start_index, + feat_mask, + query_pos_embed, + **{"preserve_rng_state": True, "use_reentrant": False}, + ) + else: + feat = layer( + feat, + reference_points, + spatial_shapes, + level_start_index, + feat_mask, + query_pos_embed, + ) + + return feat + + +class DeformableTransformerDecoderLayer(nn.Layer): + def __init__( + self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_points=4, + lr_mult=0.1, + weight_attr=None, + bias_attr=None, + ): + super(DeformableTransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model, weight_attr=weight_attr, bias_attr=bias_attr) + + # cross attention + self.cross_attn = MSDeformableAttention( + d_model, n_head, n_levels, n_points, lr_mult + ) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model, weight_attr=weight_attr, bias_attr=bias_attr) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model, weight_attr=weight_attr, bias_attr=bias_attr) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + memory_mask=None, + query_pos_embed=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + tgt2 = self.self_attn(q, k, value=tgt) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + 
# cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + memory_mask, + ) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Layer): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super(DeformableTransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + memory_mask=None, + query_pos_embed=None, + ): + output = tgt + intermediate = [] + for lid, layer in enumerate(self.layers): + output = layer( + output, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + memory_mask, + query_pos_embed, + ) + + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return paddle.stack(intermediate) + + return output.unsqueeze(0) + + +class DeformableTransformer(nn.Layer): + __shared__ = ["hidden_dim"] + + def __init__( + self, + num_queries=300, + position_embed_type="sine", + return_intermediate_dec=True, + in_feats_channel=[512, 1024, 2048], + num_feature_levels=4, + num_encoder_points=4, + num_decoder_points=4, + hidden_dim=256, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + lr_mult=0.1, + pe_temperature=10000, + pe_offset=-0.5, + ): + super(DeformableTransformer, self).__init__() + assert position_embed_type in [ + "sine", + "learned", + ], f"ValueError: position_embed_type not supported {position_embed_type}!" 
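+        # `<=` rather than `==`: when num_feature_levels exceeds the number of
+        # backbone maps, the missing levels are synthesized by the extra
+        # stride-2 3x3 convs appended to input_proj below.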
+ assert len(in_feats_channel) <= num_feature_levels + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.num_feature_levels = num_feature_levels + + encoder_layer = DeformableTransformerEncoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_feature_levels, + num_encoder_points, + lr_mult, + ) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_feature_levels, + num_decoder_points, + ) + self.decoder = DeformableTransformerDecoder( + decoder_layer, num_decoder_layers, return_intermediate_dec + ) + + self.level_embed = nn.Embedding(num_feature_levels, hidden_dim) + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_embed = nn.Embedding(num_queries, hidden_dim) + + self.reference_points = nn.Linear( + hidden_dim, + 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult), + ) + + self.input_proj = nn.LayerList() + for in_channels in in_feats_channel: + self.input_proj.append( + nn.Sequential( + nn.Conv2D(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = in_feats_channel[-1] + for _ in range(num_feature_levels - len(in_feats_channel)): + self.input_proj.append( + nn.Sequential( + nn.Conv2D( + in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 + ), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = hidden_dim + + self.position_embedding = PositionEmbedding( + hidden_dim // 2, + temperature=pe_temperature, + normalize=True if position_embed_type == "sine" else False, + embed_type=position_embed_type, + offset=pe_offset, + eps=1e-4, + ) + + self._reset_parameters() + + def _reset_parameters(self): + normal_(self.level_embed.weight) + normal_(self.tgt_embed.weight) + normal_(self.query_pos_embed.weight) + xavier_uniform_(self.reference_points.weight) + constant_(self.reference_points.bias) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + constant_(l[0].bias) + + @classmethod + def from_config(cls, cfg, input_shape): + return { + "in_feats_channel": [i.channels for i in input_shape], + } + + def forward(self, src_feats, src_mask=None, *args, **kwargs): + srcs = [] + for i in range(len(src_feats)): + srcs.append(self.input_proj[i](src_feats[i])) + if self.num_feature_levels > len(srcs): + len_srcs = len(srcs) + for i in range(len_srcs, self.num_feature_levels): + if i == len_srcs: + srcs.append(self.input_proj[i](src_feats[-1])) + else: + srcs.append(self.input_proj[i](srcs[-1])) + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + valid_ratios = [] + for level, src in enumerate(srcs): + src_shape = paddle.shape(src) + bs = src_shape[0:1] + h = src_shape[2:3] + w = src_shape[3:4] + spatial_shapes.append(paddle.concat([h, w])) + src = src.flatten(2).transpose([0, 2, 1]) + src_flatten.append(src) + if src_mask is not None: + mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0] + else: + mask = paddle.ones([bs, h, w]) + valid_ratios.append(get_valid_ratio(mask)) + pos_embed = self.position_embedding(mask).flatten(1, 2) + lvl_pos_embed = pos_embed + self.level_embed.weight[level] + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask = mask.flatten(1) + mask_flatten.append(mask) + src_flatten = paddle.concat(src_flatten, 1) + mask_flatten = None if src_mask is None else paddle.concat(mask_flatten, 1) + lvl_pos_embed_flatten = 
paddle.concat(lvl_pos_embed_flatten, 1)
+        # [l, 2]
+        spatial_shapes = paddle.to_tensor(paddle.stack(spatial_shapes).astype("int64"))
+        # [l], start index of each level
+        level_start_index = paddle.concat(
+            [paddle.zeros([1], dtype="int64"), spatial_shapes.prod(1).cumsum(0)[:-1]]
+        )
+        # [b, l, 2]
+        valid_ratios = paddle.stack(valid_ratios, 1)
+
+        # encoder
+        memory = self.encoder(
+            src_flatten,
+            spatial_shapes,
+            level_start_index,
+            mask_flatten,
+            lvl_pos_embed_flatten,
+            valid_ratios,
+        )
+
+        # prepare input for decoder
+        bs, _, c = memory.shape
+        query_embed = self.query_pos_embed.weight.unsqueeze(0).tile([bs, 1, 1])
+        tgt = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
+        reference_points = F.sigmoid(self.reference_points(query_embed))
+        reference_points_input = reference_points.unsqueeze(2) * valid_ratios.unsqueeze(
+            1
+        )
+
+        # decoder
+        hs = self.decoder(
+            tgt,
+            reference_points_input,
+            memory,
+            spatial_shapes,
+            level_start_index,
+            mask_flatten,
+            query_embed,
+        )
+
+        return (hs, memory, reference_points)
+
+
+class QRDeformableTransformerDecoder(DeformableTransformerDecoder):
+    def __init__(
+        self,
+        decoder_layer,
+        num_layers,
+        start_q=None,
+        end_q=None,
+        return_intermediate=False,
+    ):
+        super(QRDeformableTransformerDecoder, self).__init__(
+            decoder_layer, num_layers, return_intermediate=return_intermediate
+        )
+        self.start_q = start_q
+        self.end_q = end_q
+
+    def forward(
+        self,
+        tgt,
+        reference_points,
+        memory,
+        memory_spatial_shapes,
+        memory_level_start_index,
+        memory_mask=None,
+        query_pos_embed=None,
+    ):
+
+        if not self.training:
+            return super(QRDeformableTransformerDecoder, self).forward(
+                tgt,
+                reference_points,
+                memory,
+                memory_spatial_shapes,
+                memory_level_start_index,
+                memory_mask=memory_mask,
+                query_pos_embed=query_pos_embed,
+            )
+
+        batchsize = tgt.shape[0]
+        query_list_reserve = [tgt]
+        intermediate = []
+        for lid, layer in enumerate(self.layers):
+
+            start_q = self.start_q[lid]
+            end_q = self.end_q[lid]
+            query_list = query_list_reserve.copy()[start_q:end_q]
+
+            # prepare for parallel process
+            output = paddle.concat(query_list, axis=0)
+            fakesetsize = int(output.shape[0] / batchsize)
+            reference_points_tiled = reference_points.tile([fakesetsize, 1, 1, 1])
+
+            memory_tiled = memory.tile([fakesetsize, 1, 1])
+            query_pos_embed_tiled = query_pos_embed.tile([fakesetsize, 1, 1])
+            memory_mask_tiled = memory_mask.tile([fakesetsize, 1])
+
+            output = layer(
+                output,
+                reference_points_tiled,
+                memory_tiled,
+                memory_spatial_shapes,
+                memory_level_start_index,
+                memory_mask_tiled,
+                query_pos_embed_tiled,
+            )
+
+            for i in range(fakesetsize):
+                query_list_reserve.append(output[batchsize * i : batchsize * (i + 1)])
+
+            if self.return_intermediate:
+                for i in range(fakesetsize):
+                    intermediate.append(output[batchsize * i : batchsize * (i + 1)])
+
+        if self.return_intermediate:
+            return paddle.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class QRDeformableTransformer(DeformableTransformer):
+
+    def __init__(
+        self,
+        num_queries=300,
+        position_embed_type="sine",
+        return_intermediate_dec=True,
+        in_feats_channel=[512, 1024, 2048],
+        num_feature_levels=4,
+        num_encoder_points=4,
+        num_decoder_points=4,
+        hidden_dim=256,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=1024,
+        dropout=0.1,
+        activation="relu",
+        lr_mult=0.1,
+        pe_temperature=10000,
+        pe_offset=-0.5,
+        start_q=None,
+        end_q=None,
+    ):
+        super(QRDeformableTransformer, self).__init__(
+            num_queries=num_queries,
+
position_embed_type=position_embed_type, + return_intermediate_dec=return_intermediate_dec, + in_feats_channel=in_feats_channel, + num_feature_levels=num_feature_levels, + num_encoder_points=num_encoder_points, + num_decoder_points=num_decoder_points, + hidden_dim=hidden_dim, + nhead=nhead, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + lr_mult=lr_mult, + pe_temperature=pe_temperature, + pe_offset=pe_offset, + ) + + decoder_layer = DeformableTransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_feature_levels, + num_decoder_points, + ) + self.decoder = QRDeformableTransformerDecoder( + decoder_layer, num_decoder_layers, start_q, end_q, return_intermediate_dec + ) diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_loss.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_loss.py new file mode 100644 index 0000000000..c18b59a4eb --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_loss.py @@ -0,0 +1,944 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .detr_ops import bbox_iou +from .utils import ( + bbox_cxcywh_to_xyxy, + mal_loss_with_logits, + sigmoid_focal_loss, + varifocal_loss_with_logits, +) + +__all__ = ["DETRLoss", "DINOLoss", "RTDETRv3Loss"] + + +class GIoULoss(object): + """ + Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630 + Args: + loss_weight (float): giou loss weight, default as 1 + eps (float): epsilon to avoid divide by zero, default as 1e-10 + reduction (string): Options are "none", "mean" and "sum". 
default as none + """ + + def __init__(self, loss_weight=1.0, eps=1e-10, reduction="none"): + self.loss_weight = loss_weight + self.eps = eps + assert reduction in ("none", "mean", "sum") + self.reduction = reduction + + def bbox_overlap(self, box1, box2, eps=1e-10): + """calculate the iou of box1 and box2 + Args: + box1 (Tensor): box1 with the shape (..., 4) + box2 (Tensor): box1 with the shape (..., 4) + eps (float): epsilon to avoid divide by zero + Return: + iou (Tensor): iou of box1 and box2 + overlap (Tensor): overlap of box1 and box2 + union (Tensor): union of box1 and box2 + """ + x1, y1, x2, y2 = box1 + x1g, y1g, x2g, y2g = box2 + + xkis1 = paddle.maximum(x1, x1g) + ykis1 = paddle.maximum(y1, y1g) + xkis2 = paddle.minimum(x2, x2g) + ykis2 = paddle.minimum(y2, y2g) + w_inter = (xkis2 - xkis1).clip(0) + h_inter = (ykis2 - ykis1).clip(0) + overlap = w_inter * h_inter + + area1 = (x2 - x1) * (y2 - y1) + area2 = (x2g - x1g) * (y2g - y1g) + union = area1 + area2 - overlap + eps + iou = overlap / union + + return iou, overlap, union + + def __call__(self, pbox, gbox, iou_weight=1.0, loc_reweight=None): + x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1) + x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1) + box1 = [x1, y1, x2, y2] + box2 = [x1g, y1g, x2g, y2g] + iou, overlap, union = self.bbox_overlap(box1, box2, self.eps) + xc1 = paddle.minimum(x1, x1g) + yc1 = paddle.minimum(y1, y1g) + xc2 = paddle.maximum(x2, x2g) + yc2 = paddle.maximum(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps + miou = iou - ((area_c - union) / area_c) + if loc_reweight is not None: + loc_reweight = paddle.reshape(loc_reweight, shape=(-1, 1)) + loc_thresh = 0.9 + giou = 1 - (1 - loc_thresh) * miou - loc_thresh * miou * loc_reweight + else: + giou = 1 - miou + if self.reduction == "none": + loss = giou + elif self.reduction == "sum": + loss = paddle.sum(giou * iou_weight) + else: + loss = paddle.mean(giou * iou_weight) + return loss * self.loss_weight + + +class DETRLoss(nn.Layer): + __shared__ = ["num_classes", "use_focal_loss"] + __inject__ = ["matcher"] + + def __init__( + self, + num_classes=80, + matcher="HungarianMatcher", + loss_coeff={ + "class": 1, + "bbox": 5, + "giou": 2, + "no_object": 0.1, + "mask": 1, + "dice": 1, + }, + aux_loss=True, + use_focal_loss=True, + use_mal=False, + use_vfl=False, + vfl_iou_type="bbox", + use_uni_match=False, + uni_match_ind=0, + ): + r""" + Args: + num_classes (int): The number of classes. + matcher (HungarianMatcher): It computes an assignment between the targets + and the predictions of the network. + loss_coeff (dict): The coefficient of loss. + aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used. + use_focal_loss (bool): Use focal loss or not. 
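+            use_mal (bool): Whether to use MAL (see mal_loss_with_logits) with
+                IoU-aware targets for classification. Default False.
+            use_vfl (bool): Whether to use varifocal loss with IoU-aware targets
+                for classification. Default False.
+            vfl_iou_type (str): Where the IoU target comes from, 'bbox' or 'mask'.
+                Default 'bbox'.
+            use_uni_match (bool): Whether to reuse a single layer's matching result
+                for all auxiliary (per-decoder-layer) losses. Default False.
+            uni_match_ind (int): Index of the decoder layer whose matching is
+                reused when use_uni_match is True. Default 0.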
+ """ + super(DETRLoss, self).__init__() + + self.num_classes = num_classes + self.matcher = matcher + self.loss_coeff = loss_coeff + self.aux_loss = aux_loss + self.use_focal_loss = use_focal_loss + self.use_mal = use_mal + self.use_vfl = use_vfl + self.vfl_iou_type = vfl_iou_type + self.use_uni_match = use_uni_match + self.uni_match_ind = uni_match_ind + + if not self.use_focal_loss: + self.loss_coeff["class"] = paddle.full( + [num_classes + 1], loss_coeff["class"] + ) + self.loss_coeff["class"][-1] = loss_coeff["no_object"] + self.giou_loss = GIoULoss() + + def _get_loss_class( + self, + logits, + gt_class, + match_indices, + bg_index, + num_gts, + postfix="", + iou_score=None, + gt_score=None, + ): + # logits: [b, query, num_classes], gt_class: list[[n, 1]] + name_class = "loss_class" + postfix + + target_label = paddle.full(logits.shape[:2], bg_index, dtype="int64") + bs, num_query_objects = target_label.shape + num_gt = sum(len(a) for a in gt_class) + if num_gt > 0: + index, updates = self._get_index_updates( + num_query_objects, gt_class, match_indices + ) + target_label = paddle.scatter( + target_label.reshape([-1, 1]), index, updates.astype("int64") + ) + target_label = target_label.reshape([bs, num_query_objects]) + if self.use_focal_loss: + target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1] + if iou_score is not None and (self.use_vfl or self.use_mal): + if gt_score is not None: + target_score = paddle.zeros([bs, num_query_objects]) + target_score = paddle.scatter( + target_score.reshape([-1, 1]), index, gt_score + ) + target_score = ( + target_score.reshape([bs, num_query_objects, 1]) * target_label + ) + + target_score_iou = paddle.zeros([bs, num_query_objects]) + target_score_iou = paddle.scatter( + target_score_iou.reshape([-1, 1]), index, iou_score + ) + target_score_iou = ( + target_score_iou.reshape([bs, num_query_objects, 1]) + * target_label + ) + target_score = paddle.multiply(target_score, target_score_iou) + if self.use_mal: + loss_ = self.loss_coeff["class"] * mal_loss_with_logits( + logits, + target_score, + target_label, + num_gts / num_query_objects, + ) + else: + loss_ = self.loss_coeff["class"] * varifocal_loss_with_logits( + logits, + target_score, + target_label, + num_gts / num_query_objects, + ) + else: + target_score = paddle.zeros([bs, num_query_objects]) + if num_gt > 0: + target_score = paddle.scatter( + target_score.reshape([-1, 1]), index, iou_score + ) + target_score = ( + target_score.reshape([bs, num_query_objects, 1]) * target_label + ) + if self.use_mal: + loss_ = self.loss_coeff["class"] * mal_loss_with_logits( + logits, + target_score, + target_label, + num_gts / num_query_objects, + ) + else: + loss_ = self.loss_coeff["class"] * varifocal_loss_with_logits( + logits, + target_score, + target_label, + num_gts / num_query_objects, + ) + else: + loss_ = self.loss_coeff["class"] * sigmoid_focal_loss( + logits, target_label, num_gts / num_query_objects + ) + else: + loss_ = F.cross_entropy( + logits, target_label, weight=self.loss_coeff["class"] + ) + return {name_class: loss_} + + def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts, postfix=""): + # boxes: [b, query, 4], gt_bbox: list[[n, 4]] + name_bbox = "loss_bbox" + postfix + name_giou = "loss_giou" + postfix + + loss = dict() + if sum(len(a) for a in gt_bbox) == 0: + loss[name_bbox] = paddle.to_tensor([0.0]) + loss[name_giou] = paddle.to_tensor([0.0]) + return loss + + src_bbox, target_bbox = self._get_src_target_assign( + boxes, gt_bbox, match_indices + ) + 
loss[name_bbox] = ( + self.loss_coeff["bbox"] + * F.l1_loss(src_bbox, target_bbox, reduction="sum") + / num_gts + ) + loss[name_giou] = self.giou_loss( + bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox) + ) + loss[name_giou] = loss[name_giou].sum() / num_gts + loss[name_giou] = self.loss_coeff["giou"] * loss[name_giou] + return loss + + def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts, postfix=""): + # masks: [b, query, h, w], gt_mask: list[[n, H, W]] + name_mask = "loss_mask" + postfix + name_dice = "loss_dice" + postfix + + loss = dict() + if sum(len(a) for a in gt_mask) == 0: + loss[name_mask] = paddle.to_tensor([0.0]) + loss[name_dice] = paddle.to_tensor([0.0]) + return loss + + src_masks, target_masks = self._get_src_target_assign( + masks, gt_mask, match_indices + ) + src_masks = F.interpolate( + src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode="bilinear" + )[0] + loss[name_mask] = self.loss_coeff["mask"] * F.sigmoid_focal_loss( + src_masks, target_masks, paddle.to_tensor([num_gts], dtype="float32") + ) + loss[name_dice] = self.loss_coeff["dice"] * self._dice_loss( + src_masks, target_masks, num_gts + ) + return loss + + def _dice_loss(self, inputs, targets, num_gts): + inputs = F.sigmoid(inputs) + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_gts + + def _get_loss_aux( + self, + boxes, + logits, + gt_bbox, + gt_class, + bg_index, + num_gts, + dn_match_indices=None, + postfix="", + masks=None, + gt_mask=None, + gt_score=None, + ): + loss_class = [] + loss_bbox, loss_giou = [], [] + loss_mask, loss_dice = [], [] + if dn_match_indices is not None: + match_indices = dn_match_indices + elif self.use_uni_match: + match_indices = self.matcher( + boxes[self.uni_match_ind], + logits[self.uni_match_ind], + gt_bbox, + gt_class, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask, + ) + for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)): + aux_masks = masks[i] if masks is not None else None + if not self.use_uni_match and dn_match_indices is None: + match_indices = self.matcher( + aux_boxes, + aux_logits, + gt_bbox, + gt_class, + masks=aux_masks, + gt_mask=gt_mask, + ) + if self.use_vfl or self.use_mal: + if sum(len(a) for a in gt_bbox) > 0: + src_bbox, target_bbox = self._get_src_target_assign( + aux_boxes.detach(), gt_bbox, match_indices + ) + iou_score = bbox_iou( + bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), + bbox_cxcywh_to_xyxy(target_bbox).split(4, -1), + ) + else: + iou_score = None + if gt_score is not None: + _, target_score = self._get_src_target_assign( + logits[-1].detach(), gt_score, match_indices + ) + else: + iou_score = None + loss_class.append( + self._get_loss_class( + aux_logits, + gt_class, + match_indices, + bg_index, + num_gts, + postfix, + iou_score, + gt_score=target_score if gt_score is not None else None, + )["loss_class" + postfix] + ) + loss_ = self._get_loss_bbox( + aux_boxes, gt_bbox, match_indices, num_gts, postfix + ) + loss_bbox.append(loss_["loss_bbox" + postfix]) + loss_giou.append(loss_["loss_giou" + postfix]) + if masks is not None and gt_mask is not None: + loss_ = self._get_loss_mask( + aux_masks, gt_mask, match_indices, num_gts, postfix + ) + loss_mask.append(loss_["loss_mask" + postfix]) + loss_dice.append(loss_["loss_dice" + postfix]) + loss = { + "loss_class_aux" + postfix: 
+            "loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
+            "loss_giou_aux" + postfix: paddle.add_n(loss_giou),
+        }
+        if masks is not None and gt_mask is not None:
+            loss["loss_mask_aux" + postfix] = paddle.add_n(loss_mask)
+            loss["loss_dice_aux" + postfix] = paddle.add_n(loss_dice)
+        return loss
+
+    def _get_index_updates(self, num_query_objects, target, match_indices):
+        batch_idx = paddle.concat(
+            [paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)]
+        )
+        src_idx = paddle.concat([src for (src, _) in match_indices])
+        src_idx += batch_idx * num_query_objects
+        if "npu" in paddle.device.get_device():
+            target_assign = paddle.concat(
+                [
+                    paddle.gather(t.to(paddle.int32), dst.to(paddle.int32), axis=0)
+                    for t, (_, dst) in zip(target, match_indices)
+                ]
+            )
+        else:
+            target_assign = paddle.concat(
+                [
+                    paddle.gather(t, dst, axis=0)
+                    for t, (_, dst) in zip(target, match_indices)
+                ]
+            )
+        return src_idx, target_assign
+
+    def _get_src_target_assign(self, src, target, match_indices):
+        src_assign = paddle.concat(
+            [
+                (
+                    paddle.gather(t, I, axis=0)
+                    if len(I) > 0
+                    else paddle.zeros([0, t.shape[-1]])
+                )
+                for t, (I, _) in zip(src, match_indices)
+            ]
+        )
+        target_assign = paddle.concat(
+            [
+                (
+                    paddle.gather(t, J, axis=0)
+                    if len(J) > 0
+                    else paddle.zeros([0, t.shape[-1]])
+                )
+                for t, (_, J) in zip(target, match_indices)
+            ]
+        )
+        return src_assign, target_assign
+
+    def _get_num_gts(self, targets, dtype="float32"):
+        num_gts = sum(len(a) for a in targets)
+        num_gts = paddle.to_tensor([num_gts], dtype=dtype)
+        if paddle.distributed.get_world_size() > 1:
+            paddle.distributed.all_reduce(num_gts)
+            num_gts /= paddle.distributed.get_world_size()
+        num_gts = paddle.clip(num_gts, min=1.0)
+        return num_gts
+
+    def _get_prediction_loss(
+        self,
+        boxes,
+        logits,
+        gt_bbox,
+        gt_class,
+        masks=None,
+        gt_mask=None,
+        postfix="",
+        dn_match_indices=None,
+        num_gts=1,
+        gt_score=None,
+    ):
+        if dn_match_indices is None:
+            match_indices = self.matcher(
+                boxes, logits, gt_bbox, gt_class, masks=masks, gt_mask=gt_mask
+            )
+        else:
+            match_indices = dn_match_indices
+
+        # Defaults for the branches below that assign only one of the two.
+        iou_score = target_score = None
+        if self.use_vfl or self.use_mal:
+            if gt_score is not None:  # ssod
+                _, target_score = self._get_src_target_assign(
+                    logits[-1].detach(), gt_score, match_indices
+                )
+            elif sum(len(a) for a in gt_bbox) > 0:
+                if self.vfl_iou_type == "bbox":
+                    src_bbox, target_bbox = self._get_src_target_assign(
+                        boxes.detach(), gt_bbox, match_indices
+                    )
+                    iou_score = bbox_iou(
+                        bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
+                        bbox_cxcywh_to_xyxy(target_bbox).split(4, -1),
+                    )
+                elif self.vfl_iou_type == "mask":
+                    assert (
+                        masks is not None and gt_mask is not None
+                    ), "Make sure the input has `mask` and `gt_mask`"
+                    assert sum(len(a) for a in gt_mask) > 0
+                    src_mask, target_mask = self._get_src_target_assign(
+                        masks.detach(), gt_mask, match_indices
+                    )
+                    src_mask = F.interpolate(
+                        src_mask.unsqueeze(0),
+                        scale_factor=2,
+                        mode="bilinear",
+                        align_corners=False,
+                    ).squeeze(0)
+                    target_mask = F.interpolate(
+                        target_mask.unsqueeze(0),
+                        size=src_mask.shape[-2:],
+                        mode="bilinear",
+                        align_corners=False,
+                    ).squeeze(0)
+                    src_mask = src_mask.flatten(1)
+                    src_mask = F.sigmoid(src_mask)
+                    src_mask = paddle.where(src_mask > 0.5, 1.0, 0.0).astype(
+                        masks.dtype
+                    )
+                    target_mask = target_mask.flatten(1)
+                    target_mask = paddle.where(target_mask > 0.5, 1.0, 0.0).astype(
+                        masks.dtype
+                    )
+                    inter = (src_mask * target_mask).sum(1)
+                    union = src_mask.sum(1) + target_mask.sum(1) - inter
+                    iou_score = (inter +
1e-2) / (union + 1e-2) + iou_score = iou_score.unsqueeze(-1) + else: + iou_score = None + else: + iou_score = None + else: + iou_score = None + + loss = dict() + loss.update( + self._get_loss_class( + logits, + gt_class, + match_indices, + self.num_classes, + num_gts, + postfix, + iou_score, + gt_score=target_score if gt_score is not None else None, + ) + ) + loss.update( + self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts, postfix) + ) + if masks is not None and gt_mask is not None: + loss.update( + self._get_loss_mask(masks, gt_mask, match_indices, num_gts, postfix) + ) + return loss + + def forward( + self, + boxes, + logits, + gt_bbox, + gt_class, + masks=None, + gt_mask=None, + postfix="", + gt_score=None, + **kwargs + ): + r""" + Args: + boxes (Tensor): [l, b, query, 4] + logits (Tensor): [l, b, query, num_classes] + gt_bbox (List(Tensor)): list[[n, 4]] + gt_class (List(Tensor)): list[[n, 1]] + masks (Tensor, optional): [l, b, query, h, w] + gt_mask (List(Tensor), optional): list[[n, H, W]] + postfix (str): postfix of loss name + """ + + dn_match_indices = kwargs.get("dn_match_indices", None) + num_gts = kwargs.get("num_gts", None) + if num_gts is None: + num_gts = self._get_num_gts(gt_class) + + total_loss = self._get_prediction_loss( + boxes[-1], + logits[-1], + gt_bbox, + gt_class, + masks=masks[-1] if masks is not None else None, + gt_mask=gt_mask, + postfix=postfix, + dn_match_indices=dn_match_indices, + num_gts=num_gts, + gt_score=gt_score if gt_score is not None else None, + ) + + if self.aux_loss: + total_loss.update( + self._get_loss_aux( + boxes[:-1], + logits[:-1], + gt_bbox, + gt_class, + self.num_classes, + num_gts, + dn_match_indices, + postfix, + masks=masks[:-1] if masks is not None else None, + gt_mask=gt_mask, + gt_score=gt_score if gt_score is not None else None, + ) + ) + + return total_loss + + +class DINOLoss(DETRLoss): + def forward( + self, + boxes, + logits, + gt_bbox, + gt_class, + masks=None, + gt_mask=None, + postfix="", + dn_out_bboxes=None, + dn_out_logits=None, + dn_meta=None, + gt_score=None, + **kwargs + ): + num_gts = self._get_num_gts(gt_class) + total_loss = super(DINOLoss, self).forward( + boxes, logits, gt_bbox, gt_class, num_gts=num_gts, gt_score=gt_score + ) + + if dn_meta is not None: + dn_positive_idx, dn_num_group = ( + dn_meta["dn_positive_idx"], + dn_meta["dn_num_group"], + ) + assert len(gt_class) == len(dn_positive_idx) + + # denoising match indices + dn_match_indices = self.get_dn_match_indices( + gt_class, dn_positive_idx, dn_num_group + ) + + # compute denoising training loss + num_gts *= dn_num_group + dn_loss = super(DINOLoss, self).forward( + dn_out_bboxes, + dn_out_logits, + gt_bbox, + gt_class, + postfix="_dn", + dn_match_indices=dn_match_indices, + num_gts=num_gts, + gt_score=gt_score, + ) + total_loss.update(dn_loss) + else: + total_loss.update( + {k + "_dn": paddle.to_tensor([0.0]) for k in total_loss.keys()} + ) + + return total_loss + + @staticmethod + def get_dn_match_indices(labels, dn_positive_idx, dn_num_group): + dn_match_indices = [] + for i in range(len(labels)): + num_gt = len(labels[i]) + if num_gt > 0: + gt_idx = paddle.arange(end=num_gt, dtype="int64") + gt_idx = gt_idx.tile([dn_num_group]) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append( + (paddle.zeros([0], dtype="int64"), paddle.zeros([0], dtype="int64")) + ) + return dn_match_indices + + +class RTDETRv3Loss(DETRLoss): + def forward( + self, + boxes, + logits, + 
gt_bbox, + gt_class, + masks=None, + gt_mask=None, + postfix="", + dn_out_bboxes=None, + dn_out_logits=None, + dn_meta=None, + gt_score=None, + o2m=1, + **kwargs + ): + if o2m != 1: + gt_boxes_copy = [box.tile([o2m, 1]) for box in gt_bbox] + gt_class_copy = [label.tile([o2m, 1]) for label in gt_class] + else: + gt_boxes_copy = gt_bbox + gt_class_copy = gt_class + num_gts_copy = self._get_num_gts(gt_class_copy) + total_loss = self._get_prediction_loss( + boxes[-1], + logits[-1], + gt_boxes_copy, + gt_class_copy, + masks=masks[-1] if masks is not None else None, + gt_mask=gt_mask, + postfix=postfix, + dn_match_indices=None, + num_gts=num_gts_copy, + gt_score=gt_score if gt_score is not None else None, + ) + + if self.aux_loss: + total_loss.update( + self._get_loss_aux( + boxes[:-1], + logits[:-1], + gt_boxes_copy, + gt_class_copy, + self.num_classes, + num_gts_copy, + dn_match_indices=None, + postfix=postfix, + masks=masks[:-1] if masks is not None else None, + gt_mask=gt_mask, + gt_score=gt_score if gt_score is not None else None, + ) + ) + + if dn_meta is not None: + num_gts = self._get_num_gts(gt_class) + dn_positive_idx, dn_num_group = ( + dn_meta["dn_positive_idx"], + dn_meta["dn_num_group"], + ) + assert len(gt_class) == len(dn_positive_idx) + + # denoising match indices + dn_match_indices = self.get_dn_match_indices( + gt_class, dn_positive_idx, dn_num_group + ) + + # compute denoising training loss + num_gts *= dn_num_group + dn_loss = super(RTDETRv3Loss, self).forward( + dn_out_bboxes, + dn_out_logits, + gt_bbox, + gt_class, + postfix="_dn", + dn_match_indices=dn_match_indices, + num_gts=num_gts, + gt_score=gt_score, + ) + total_loss.update(dn_loss) + else: + total_loss.update( + {k + "_dn": paddle.to_tensor([0.0]) for k in total_loss.keys()} + ) + + return total_loss + + @staticmethod + def get_dn_match_indices(labels, dn_positive_idx, dn_num_group): + dn_match_indices = [] + for i in range(len(labels)): + num_gt = len(labels[i]) + if num_gt > 0: + gt_idx = paddle.arange(end=num_gt, dtype="int64") + gt_idx = gt_idx.tile([dn_num_group]) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append( + (paddle.zeros([0], dtype="int64"), paddle.zeros([0], dtype="int64")) + ) + return dn_match_indices + + +class MaskDINOLoss(DETRLoss): + __shared__ = ["num_classes", "use_focal_loss", "num_sample_points"] + __inject__ = ["matcher"] + + def __init__( + self, + num_classes=80, + matcher="HungarianMatcher", + loss_coeff={"class": 4, "bbox": 5, "giou": 2, "mask": 5, "dice": 5}, + aux_loss=True, + use_focal_loss=False, + use_vfl=False, + vfl_iou_type="bbox", + num_sample_points=12544, + oversample_ratio=3.0, + important_sample_ratio=0.75, + ): + super(MaskDINOLoss, self).__init__( + num_classes, + matcher, + loss_coeff, + aux_loss, + use_focal_loss, + use_vfl, + vfl_iou_type, + ) + assert oversample_ratio >= 1 + assert important_sample_ratio <= 1 and important_sample_ratio >= 0 + + self.num_sample_points = num_sample_points + self.oversample_ratio = oversample_ratio + self.important_sample_ratio = important_sample_ratio + self.num_oversample_points = int(num_sample_points * oversample_ratio) + self.num_important_points = int(num_sample_points * important_sample_ratio) + self.num_random_points = num_sample_points - self.num_important_points + + def forward( + self, + boxes, + logits, + gt_bbox, + gt_class, + masks=None, + gt_mask=None, + postfix="", + dn_out_bboxes=None, + dn_out_logits=None, + dn_out_masks=None, + 
dn_meta=None, + **kwargs + ): + num_gts = self._get_num_gts(gt_class) + total_loss = super(MaskDINOLoss, self).forward( + boxes, + logits, + gt_bbox, + gt_class, + masks=masks, + gt_mask=gt_mask, + num_gts=num_gts, + ) + + if dn_meta is not None: + dn_positive_idx, dn_num_group = ( + dn_meta["dn_positive_idx"], + dn_meta["dn_num_group"], + ) + assert len(gt_class) == len(dn_positive_idx) + + # denoising match indices + dn_match_indices = DINOLoss.get_dn_match_indices( + gt_class, dn_positive_idx, dn_num_group + ) + + # compute denoising training loss + num_gts *= dn_num_group + dn_loss = super(MaskDINOLoss, self).forward( + dn_out_bboxes, + dn_out_logits, + gt_bbox, + gt_class, + masks=dn_out_masks, + gt_mask=gt_mask, + postfix="_dn", + dn_match_indices=dn_match_indices, + num_gts=num_gts, + ) + total_loss.update(dn_loss) + else: + total_loss.update( + {k + "_dn": paddle.to_tensor([0.0]) for k in total_loss.keys()} + ) + + return total_loss + + def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts, postfix=""): + # masks: [b, query, h, w], gt_mask: list[[n, H, W]] + name_mask = "loss_mask" + postfix + name_dice = "loss_dice" + postfix + + loss = dict() + if sum(len(a) for a in gt_mask) == 0: + loss[name_mask] = paddle.to_tensor([0.0]) + loss[name_dice] = paddle.to_tensor([0.0]) + return loss + + src_masks, target_masks = self._get_src_target_assign( + masks, gt_mask, match_indices + ) + # sample points + sample_points = self._get_point_coords_by_uncertainty(src_masks) + sample_points = 2.0 * sample_points.unsqueeze(1) - 1.0 + + src_masks = F.grid_sample( + src_masks.unsqueeze(1), sample_points, align_corners=False + ).squeeze([1, 2]) + + target_masks = ( + F.grid_sample(target_masks.unsqueeze(1), sample_points, align_corners=False) + .squeeze([1, 2]) + .detach() + ) + + loss[name_mask] = ( + self.loss_coeff["mask"] + * F.binary_cross_entropy_with_logits( + src_masks, target_masks, reduction="none" + ) + .mean(1) + .sum() + / num_gts + ) + loss[name_dice] = self.loss_coeff["dice"] * self._dice_loss( + src_masks, target_masks, num_gts + ) + return loss + + def _get_point_coords_by_uncertainty(self, masks): + # Sample points based on their uncertainty. + masks = masks.detach() + num_masks = masks.shape[0] + sample_points = paddle.rand([num_masks, 1, self.num_oversample_points, 2]) + + out_mask = F.grid_sample( + masks.unsqueeze(1), 2.0 * sample_points - 1.0, align_corners=False + ).squeeze([1, 2]) + out_mask = -paddle.abs(out_mask) + + _, topk_ind = paddle.topk(out_mask, self.num_important_points, axis=1) + batch_ind = paddle.arange(end=num_masks, dtype=topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_important_points]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + + sample_points = paddle.gather_nd(sample_points.squeeze(1), topk_ind) + if self.num_random_points > 0: + sample_points = paddle.concat( + [sample_points, paddle.rand([num_masks, self.num_random_points, 2])], + axis=1, + ) + return sample_points diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_ops.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_ops.py new file mode 100644 index 0000000000..9b9e73b496 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/detr_ops.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from collections import namedtuple
+
+import paddle
+import paddle.nn as nn
+
+
+class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
+    def __new__(cls, channels=None, height=None, width=None, stride=None):
+        return super(ShapeSpec, cls).__new__(cls, channels, height, width, stride)
+
+
+def delta2bbox(deltas, boxes, weights=[1.0, 1.0, 1.0, 1.0], max_shape=None):
+    """Decode deltas to boxes. Used in RCNNBox, CascadeHead, RCNNHead, RetinaHead.
+    Note: the returned tensor has shape [n, 1, 4].
+    If you want to add a reshape, please add it after the calling code instead of here.
+    """
+    clip_scale = math.log(1000.0 / 16)
+
+    widths = boxes[:, 2] - boxes[:, 0]
+    heights = boxes[:, 3] - boxes[:, 1]
+    ctr_x = boxes[:, 0] + 0.5 * widths
+    ctr_y = boxes[:, 1] + 0.5 * heights
+
+    wx, wy, ww, wh = weights
+    dx = deltas[:, 0::4] / wx
+    dy = deltas[:, 1::4] / wy
+    dw = deltas[:, 2::4] / ww
+    dh = deltas[:, 3::4] / wh
+    # Prevent sending too large values into paddle.exp()
+    dw = paddle.clip(dw, max=clip_scale)
+    dh = paddle.clip(dh, max=clip_scale)
+
+    pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1)
+    pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1)
+    pred_w = paddle.exp(dw) * widths.unsqueeze(1)
+    pred_h = paddle.exp(dh) * heights.unsqueeze(1)
+
+    pred_boxes = []
+    pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
+    pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
+    pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
+    pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
+    pred_boxes = paddle.stack(pred_boxes, axis=-1)
+
+    if max_shape is not None:
+        pred_boxes[..., 0::2] = pred_boxes[..., 0::2].clip(min=0, max=max_shape[1])
+        pred_boxes[..., 1::2] = pred_boxes[..., 1::2].clip(min=0, max=max_shape[0])
+    return pred_boxes
+
+
+def _get_clones(module, N):
+    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
+
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clip(min=0.0, max=1.0)
+    return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps))
+
+
+def get_valid_ratio(mask):
+    _, H, W = mask.shape
+    valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
+    valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
+    # [b, 2]
+    return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
+
+
+def get_sine_pos_embed(
+    pos_tensor, num_pos_feats=128, temperature=10000, exchange_xy=True
+):
+    """Generate a sine position embedding from a position tensor.
+
+    Args:
+        pos_tensor (Tensor): Shape as `(None, n)`.
+        num_pos_feats (int): projected shape for each float in the tensor. Default: 128.
+        temperature (int): The temperature used for scaling
+            the position embedding. Default: 10000.
+        exchange_xy (bool, optional): exchange pos x and pos y. For example,
+            if the input tensor is `[x, y]`, the results will be
+            `[pos(y), pos(x)]`. Default: True.
+
+    Returns:
+        Tensor: Returned position embedding
+        with shape `(None, n * num_pos_feats)`.
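+
+    Example (illustrative; in practice the function receives a 3-D tensor,
+    e.g. reference points of shape `[bs, num_queries, n]`):
+
+        pos = paddle.rand([2, 100, 2])
+        emb = get_sine_pos_embed(pos, num_pos_feats=128)  # -> [2, 100, 256]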
+ """ + scale = 2.0 * math.pi + dim_t = 2.0 * paddle.floor_divide(paddle.arange(num_pos_feats), paddle.to_tensor(2)) + dim_t = scale / temperature ** (dim_t / num_pos_feats) + + def sine_func(x): + x *= dim_t + return paddle.stack((x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten( + 2 + ) + + pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)] + if exchange_xy: + pos_res[0], pos_res[1] = pos_res[1], pos_res[0] + pos_res = paddle.concat(pos_res, axis=2) + return pos_res + + +def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9): + """calculate the iou of box1 and box2 + + Args: + box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] + box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] + giou (bool): whether use giou or not, default False + diou (bool): whether use diou or not, default False + ciou (bool): whether use ciou or not, default False + eps (float): epsilon to avoid divide by zero + + Return: + iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1] + """ + px1, py1, px2, py2 = box1 + gx1, gy1, gx2, gy2 = box2 + x1 = paddle.maximum(px1, gx1) + y1 = paddle.maximum(py1, gy1) + x2 = paddle.minimum(px2, gx2) + y2 = paddle.minimum(py2, gy2) + + overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0)) + + area1 = (px2 - px1) * (py2 - py1) + area1 = area1.clip(0) + + area2 = (gx2 - gx1) * (gy2 - gy1) + area2 = area2.clip(0) + + union = area1 + area2 - overlap + eps + iou = overlap / union + + if giou or ciou or diou: + # convex w, h + cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1) + ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1) + if giou: + c_area = cw * ch + eps + return iou - (c_area - union) / c_area + else: + # convex diagonal squared + c2 = cw**2 + ch**2 + eps + # center distance + rho2 = ((px1 + px2 - gx1 - gx2) ** 2 + (py1 + py2 - gy1 - gy2) ** 2) / 4 + if diou: + return iou - rho2 / c2 + else: + w1, h1 = px2 - px1, py2 - py1 + eps + w2, h2 = gx2 - gx1, gy2 - gy1 + eps + delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2) + v = (4 / math.pi**2) * paddle.pow(delta, 2) + alpha = v / (1 + eps - iou + v) + alpha.stop_gradient = True + return iou - (rho2 / c2 + v * alpha) + else: + return iou diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/initializer.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/initializer.py new file mode 100644 index 0000000000..edc3642096 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/initializer.py @@ -0,0 +1,330 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is based on https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py +Ths copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file. 
+""" + +import math + +import numpy as np +import paddle +import paddle.nn as nn + +__all__ = [ + "uniform_", + "normal_", + "constant_", + "ones_", + "zeros_", + "xavier_uniform_", + "xavier_normal_", + "kaiming_uniform_", + "kaiming_normal_", + "linear_init_", + "conv_init_", + "reset_initialized_parameter", +] + + +def _no_grad_uniform_(tensor, a, b): + with paddle.no_grad(): + tensor.set_value( + paddle.uniform(shape=tensor.shape, dtype=tensor.dtype, min=a, max=b) + ) + return tensor + + +def _no_grad_normal_(tensor, mean=0.0, std=1.0): + with paddle.no_grad(): + tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape)) + return tensor + + +def _no_grad_fill_(tensor, value=0.0): + with paddle.no_grad(): + tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype)) + return tensor + + +def uniform_(tensor, a, b): + """ + Modified tensor inspace using uniform_ + Args: + tensor (paddle.Tensor): paddle Tensor + a (float|int): min value. + b (float|int): max value. + Return: + tensor + """ + return _no_grad_uniform_(tensor, a, b) + + +def normal_(tensor, mean=0.0, std=1.0): + """ + Modified tensor inspace using normal_ + Args: + tensor (paddle.Tensor): paddle Tensor + mean (float|int): mean value. + std (float|int): std value. + Return: + tensor + """ + return _no_grad_normal_(tensor, mean, std) + + +def constant_(tensor, value=0.0): + """ + Modified tensor inspace using constant_ + Args: + tensor (paddle.Tensor): paddle Tensor + value (float|int): value to fill tensor. + Return: + tensor + """ + return _no_grad_fill_(tensor, value) + + +def ones_(tensor): + """ + Modified tensor inspace using ones_ + Args: + tensor (paddle.Tensor): paddle Tensor + Return: + tensor + """ + return _no_grad_fill_(tensor, 1) + + +def zeros_(tensor): + """ + Modified tensor inspace using zeros_ + Args: + tensor (paddle.Tensor): paddle Tensor + Return: + tensor + """ + return _no_grad_fill_(tensor, 0) + + +def vector_(tensor, vector): + with paddle.no_grad(): + tensor.set_value(paddle.to_tensor(vector, dtype=tensor.dtype)) + return tensor + + +def _calculate_fan_in_and_fan_out(tensor, reverse=False): + """ + Calculate (fan_in, _fan_out) for tensor + + Args: + tensor (Tensor): paddle.Tensor + reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. e.g. : conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] is True + + Return: + Tuple[fan_in, fan_out] + """ + if tensor.ndim < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + if reverse: + num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1] + else: + num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0] + + receptive_field_size = 1 + if tensor.ndim > 2: + receptive_field_size = np.prod(tensor.shape[2:]) + + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_(tensor, gain=1.0, reverse=False): + """ + Modified tensor inspace using xavier_uniform_ + Args: + tensor (paddle.Tensor): paddle Tensor + gain (float): super parameter, 1. default. + reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. 
+    Return:
+        tensor
+    """
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
+    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+    k = math.sqrt(3.0) * std
+    return _no_grad_uniform_(tensor, -k, k)
+
+
+def xavier_normal_(tensor, gain=1.0, reverse=False):
+    """
+    Modify the tensor in place using xavier_normal_
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        gain (float): hyperparameter, 1. by default.
+        reverse (bool): tensor data format order; False by default as
+            [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
+    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+    return _no_grad_normal_(tensor, 0, std)
+
+
+# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
+def _calculate_correct_fan(tensor, mode, reverse=False):
+    mode = mode.lower()
+    valid_modes = ["fan_in", "fan_out"]
+    if mode not in valid_modes:
+        raise ValueError(
+            "Mode {} not supported, please use one of {}".format(mode, valid_modes)
+        )
+
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)
+
+    return fan_in if mode == "fan_in" else fan_out
+
+
+def _calculate_gain(nonlinearity, param=None):
+    linear_fns = [
+        "linear",
+        "conv1d",
+        "conv2d",
+        "conv3d",
+        "conv_transpose1d",
+        "conv_transpose2d",
+        "conv_transpose3d",
+    ]
+    if nonlinearity in linear_fns or nonlinearity == "sigmoid":
+        return 1
+    elif nonlinearity == "tanh":
+        return 5.0 / 3
+    elif nonlinearity == "relu":
+        return math.sqrt(2.0)
+    elif nonlinearity == "leaky_relu":
+        if param is None:
+            negative_slope = 0.01
+        elif (
+            not isinstance(param, bool)
+            and isinstance(param, int)
+            or isinstance(param, float)
+        ):
+            # True/False are instances of int, hence check above
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(param))
+        return math.sqrt(2.0 / (1 + negative_slope**2))
+    elif nonlinearity == "selu":
+        return 3.0 / 4
+    else:
+        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+
+def kaiming_uniform_(
+    tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False
+):
+    """
+    Modify the tensor in place using the kaiming_uniform method
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
+        nonlinearity (str): nonlinearity method name
+        reverse (bool): tensor data format order; False by default as
+            [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan = _calculate_correct_fan(tensor, mode, reverse)
+    gain = _calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    k = math.sqrt(3.0) * std
+    return _no_grad_uniform_(tensor, -k, k)
+
+
+def kaiming_normal_(
+    tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False
+):
+    """
+    Modify the tensor in place using kaiming_normal_
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
+        nonlinearity (str): nonlinearity method name
+        reverse (bool): tensor data format order; False by default as
+            [fout, fin, ...].
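+    Note:
+        As implemented below, values are drawn from N(0, std^2) with
+        std = gain / sqrt(fan), where `fan` is chosen by `mode` and `gain`
+        by `nonlinearity` (see `_calculate_gain`), matching
+        torch.nn.init.kaiming_normal_.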
+    Return:
+        tensor
+    """
+    fan = _calculate_correct_fan(tensor, mode, reverse)
+    gain = _calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    return _no_grad_normal_(tensor, 0, std)
+
+
+def linear_init_(module):
+    bound = 1 / math.sqrt(module.weight.shape[0])
+    uniform_(module.weight, -bound, bound)
+    if hasattr(module, "bias") and module.bias is not None:
+        uniform_(module.bias, -bound, bound)
+
+
+def conv_init_(module):
+    bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
+    uniform_(module.weight, -bound, bound)
+    if module.bias is not None:
+        uniform_(module.bias, -bound, bound)
+
+
+def bias_init_with_prob(prior_prob=0.01):
+    """Initialize the conv/fc bias value according to a given probability value."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
+
+
+@paddle.no_grad()
+def reset_initialized_parameter(model, include_self=True):
+    """
+    Reset the initialized parameters using the following methods for
+    [conv, linear, embedding, bn]
+
+    Args:
+        model (paddle.Layer): paddle Layer
+        include_self (bool: True): passed through to the
+            Layer.named_sublayers method; indicates whether the layer
+            itself is included
+    Return:
+        None
+    """
+    for _, m in model.named_sublayers(include_self=include_self):
+        if isinstance(m, nn.Conv2D):
+            k = float(m._groups) / (
+                m._in_channels * m._kernel_size[0] * m._kernel_size[1]
+            )
+            k = math.sqrt(k)
+            _no_grad_uniform_(m.weight, -k, k)
+            if hasattr(m, "bias") and getattr(m, "bias") is not None:
+                _no_grad_uniform_(m.bias, -k, k)
+
+        elif isinstance(m, nn.Linear):
+            k = math.sqrt(1.0 / m.weight.shape[0])
+            _no_grad_uniform_(m.weight, -k, k)
+            if hasattr(m, "bias") and getattr(m, "bias") is not None:
+                _no_grad_uniform_(m.bias, -k, k)
+
+        elif isinstance(m, nn.Embedding):
+            _no_grad_normal_(m.weight, mean=0.0, std=1.0)
+
+        elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
+            _no_grad_fill_(m.weight, 1.0)
+            if hasattr(m, "bias") and getattr(m, "bias") is not None:
+                _no_grad_fill_(m.bias, 0)
diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/layers.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/layers.py
new file mode 100644
index 0000000000..97241d41b2
--- /dev/null
+++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/layers.py
@@ -0,0 +1,1392 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import six
+from paddle import ParamAttr
+from paddle.nn.initializer import Constant, Normal, XavierUniform
+from paddle.regularizer import L2Decay
+from paddle.vision.ops import DeformConv2D
+
+from .
import ops +from .detr_ops import delta2bbox +from .initializer import constant_, xavier_uniform_ + + +def _to_list(l): + if isinstance(l, (list, tuple)): + return list(l) + return [l] + + +class AlignConv(nn.Layer): + def __init__(self, in_channels, out_channels, kernel_size=3, groups=1): + super(AlignConv, self).__init__() + self.kernel_size = kernel_size + self.align_conv = paddle.vision.ops.DeformConv2D( + in_channels, + out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + bias_attr=None, + ) + + @paddle.no_grad() + def get_offset(self, anchors, featmap_size, stride): + """ + Args: + anchors: [B, L, 5] xc,yc,w,h,angle + featmap_size: (feat_h, feat_w) + stride: 8 + Returns: + + """ + batch = anchors.shape[0] + dtype = anchors.dtype + feat_h, feat_w = featmap_size + pad = (self.kernel_size - 1) // 2 + idx = paddle.arange(-pad, pad + 1, dtype=dtype) + + yy, xx = paddle.meshgrid(idx, idx) + xx = paddle.reshape(xx, [-1]) + yy = paddle.reshape(yy, [-1]) + + # get sampling locations of default conv + xc = paddle.arange(0, feat_w, dtype=dtype) + yc = paddle.arange(0, feat_h, dtype=dtype) + yc, xc = paddle.meshgrid(yc, xc) + + xc = paddle.reshape(xc, [-1, 1]) + yc = paddle.reshape(yc, [-1, 1]) + x_conv = xc + xx + y_conv = yc + yy + + # get sampling locations of anchors + x_ctr, y_ctr, w, h, a = paddle.split(anchors, 5, axis=-1) + x_ctr = x_ctr / stride + y_ctr = y_ctr / stride + w_s = w / stride + h_s = h / stride + cos, sin = paddle.cos(a), paddle.sin(a) + dw, dh = w_s / self.kernel_size, h_s / self.kernel_size + x, y = dw * xx, dh * yy + xr = cos * x - sin * y + yr = sin * x + cos * y + x_anchor, y_anchor = xr + x_ctr, yr + y_ctr + # get offset filed + offset_x = x_anchor - x_conv + offset_y = y_anchor - y_conv + offset = paddle.stack([offset_y, offset_x], axis=-1) + offset = offset.reshape( + [batch, feat_h, feat_w, self.kernel_size * self.kernel_size * 2] + ) + offset = offset.transpose([0, 3, 1, 2]) + + return offset + + def forward(self, x, refine_anchors, featmap_size, stride): + batch = x.shape[0].numpy() + offset = self.get_offset(refine_anchors, featmap_size, stride) + if self.training: + x = F.relu(self.align_conv(x, offset.detach())) + else: + x = F.relu(self.align_conv(x, offset)) + return x + + +class DeformableConvV2(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + weight_attr=None, + bias_attr=None, + lr_scale=1, + regularizer=None, + skip_quant=False, + dcn_bias_regularizer=L2Decay(0.0), + dcn_bias_lr_scale=2.0, + ): + super(DeformableConvV2, self).__init__() + self.offset_channel = 2 * kernel_size**2 + self.mask_channel = kernel_size**2 + + if lr_scale == 1 and regularizer is None: + offset_bias_attr = ParamAttr(initializer=Constant(0.0)) + else: + offset_bias_attr = ParamAttr( + initializer=Constant(0.0), + learning_rate=lr_scale, + regularizer=regularizer, + ) + self.conv_offset = nn.Conv2D( + in_channels, + 3 * kernel_size**2, + kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + weight_attr=ParamAttr(initializer=Constant(0.0)), + bias_attr=offset_bias_attr, + ) + if skip_quant: + self.conv_offset.skip_quant = True + + if bias_attr: + # in FCOS-DCN head, specifically need learning_rate and regularizer + dcn_bias_attr = ParamAttr( + initializer=Constant(value=0), + regularizer=dcn_bias_regularizer, + learning_rate=dcn_bias_lr_scale, + ) + else: + # in ResNet backbone, do 
not need bias + dcn_bias_attr = False + self.conv_dcn = DeformConv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2 * dilation, + dilation=dilation, + groups=groups, + weight_attr=weight_attr, + bias_attr=dcn_bias_attr, + ) + + def forward(self, x): + offset_mask = self.conv_offset(x) + offset, mask = paddle.split( + offset_mask, + num_or_sections=[self.offset_channel, self.mask_channel], + axis=1, + ) + mask = F.sigmoid(mask) + y = self.conv_dcn(x, offset, mask=mask) + return y + + +class ConvNormLayer(nn.Layer): + def __init__( + self, + ch_in, + ch_out, + filter_size, + stride, + groups=1, + norm_type="bn", + norm_decay=0.0, + norm_groups=32, + use_dcn=False, + bias_on=False, + lr_scale=1.0, + freeze_norm=False, + initializer=Normal(mean=0.0, std=0.01), + skip_quant=False, + dcn_lr_scale=2.0, + dcn_regularizer=L2Decay(0.0), + ): + super(ConvNormLayer, self).__init__() + assert norm_type in ["bn", "sync_bn", "gn", None] + + if bias_on: + bias_attr = ParamAttr( + initializer=Constant(value=0.0), learning_rate=lr_scale + ) + else: + bias_attr = False + + if not use_dcn: + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=initializer, learning_rate=1.0), + bias_attr=bias_attr, + ) + if skip_quant: + self.conv.skip_quant = True + else: + # in FCOS-DCN head, specifically need learning_rate and regularizer + self.conv = DeformableConvV2( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=initializer, learning_rate=1.0), + bias_attr=True, + lr_scale=dcn_lr_scale, + regularizer=dcn_regularizer, + dcn_bias_regularizer=dcn_regularizer, + dcn_bias_lr_scale=dcn_lr_scale, + skip_quant=skip_quant, + ) + + norm_lr = 0.0 if freeze_norm else 1.0 + param_attr = ParamAttr( + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None, + ) + bias_attr = ParamAttr( + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None, + ) + if norm_type in ["bn", "sync_bn"]: + self.norm = nn.BatchNorm2D( + ch_out, weight_attr=param_attr, bias_attr=bias_attr + ) + elif norm_type == "gn": + self.norm = nn.GroupNorm( + num_groups=norm_groups, + num_channels=ch_out, + weight_attr=param_attr, + bias_attr=bias_attr, + ) + else: + self.norm = None + + def forward(self, inputs): + out = self.conv(inputs) + if self.norm is not None: + out = self.norm(out) + return out + + +class LiteConv(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + stride=1, + with_act=True, + norm_type="sync_bn", + name=None, + ): + super(LiteConv, self).__init__() + self.lite_conv = nn.Sequential() + conv1 = ConvNormLayer( + in_channels, + in_channels, + filter_size=5, + stride=stride, + groups=in_channels, + norm_type=norm_type, + initializer=XavierUniform(), + ) + conv2 = ConvNormLayer( + in_channels, + out_channels, + filter_size=1, + stride=stride, + norm_type=norm_type, + initializer=XavierUniform(), + ) + conv3 = ConvNormLayer( + out_channels, + out_channels, + filter_size=1, + stride=stride, + norm_type=norm_type, + initializer=XavierUniform(), + ) + conv4 = ConvNormLayer( + out_channels, + out_channels, + filter_size=5, + stride=stride, + groups=out_channels, + norm_type=norm_type, + initializer=XavierUniform(), + ) + conv_list = [conv1, conv2, 
conv3, conv4] + self.lite_conv.add_sublayer("conv1", conv1) + self.lite_conv.add_sublayer("relu6_1", nn.ReLU6()) + self.lite_conv.add_sublayer("conv2", conv2) + if with_act: + self.lite_conv.add_sublayer("relu6_2", nn.ReLU6()) + self.lite_conv.add_sublayer("conv3", conv3) + self.lite_conv.add_sublayer("relu6_3", nn.ReLU6()) + self.lite_conv.add_sublayer("conv4", conv4) + if with_act: + self.lite_conv.add_sublayer("relu6_4", nn.ReLU6()) + + def forward(self, inputs): + out = self.lite_conv(inputs) + return out + + +class DropBlock(nn.Layer): + def __init__(self, block_size, keep_prob, name=None, data_format="NCHW"): + """ + DropBlock layer, see https://arxiv.org/abs/1810.12890 + + Args: + block_size (int): block size + keep_prob (int): keep probability + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ + super(DropBlock, self).__init__() + self.block_size = block_size + self.keep_prob = keep_prob + self.name = name + self.data_format = data_format + + def forward(self, x): + if not self.training or self.keep_prob == 1: + return x + else: + gamma = (1.0 - self.keep_prob) / (self.block_size**2) + if self.data_format == "NCHW": + shape = x.shape[2:] + else: + shape = x.shape[1:3] + for s in shape: + gamma *= s / (s - self.block_size + 1) + + matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype) + mask_inv = F.max_pool2d( + matrix, + self.block_size, + stride=1, + padding=self.block_size // 2, + data_format=self.data_format, + ) + mask = 1.0 - mask_inv + mask = mask.astype("float32") + x = x.astype("float32") + y = x * mask * (mask.numel() / mask.sum()) + return y + + +class AnchorGeneratorSSD(object): + def __init__( + self, + steps=[8, 16, 32, 64, 100, 300], + aspect_ratios=[[2.0], [2.0, 3.0], [2.0, 3.0], [2.0, 3.0], [2.0], [2.0]], + min_ratio=15, + max_ratio=90, + base_size=300, + min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0], + max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0], + offset=0.5, + flip=True, + clip=False, + min_max_aspect_ratios_order=False, + ): + self.steps = steps + self.aspect_ratios = aspect_ratios + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.base_size = base_size + self.min_sizes = min_sizes + self.max_sizes = max_sizes + self.offset = offset + self.flip = flip + self.clip = clip + self.min_max_aspect_ratios_order = min_max_aspect_ratios_order + + if self.min_sizes == [] and self.max_sizes == []: + num_layer = len(aspect_ratios) + step = int( + math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2)) + ) + for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1, step): + self.min_sizes.append(self.base_size * ratio / 100.0) + self.max_sizes.append(self.base_size * (ratio + step) / 100.0) + self.min_sizes = [self.base_size * 0.10] + self.min_sizes + self.max_sizes = [self.base_size * 0.20] + self.max_sizes + + self.num_priors = [] + for aspect_ratio, min_size, max_size in zip( + aspect_ratios, self.min_sizes, self.max_sizes + ): + if isinstance(min_size, (list, tuple)): + self.num_priors.append( + len(_to_list(min_size)) + len(_to_list(max_size)) + ) + else: + self.num_priors.append( + (len(aspect_ratio) * 2 + 1) * len(_to_list(min_size)) + + len(_to_list(max_size)) + ) + + def __call__(self, inputs, image): + boxes = [] + for input, min_size, max_size, aspect_ratio, step in zip( + inputs, self.min_sizes, self.max_sizes, self.aspect_ratios, self.steps + ): + box, _ = ops.prior_box( + input=input, + image=image, + min_sizes=_to_list(min_size), + max_sizes=_to_list(max_size), + 
aspect_ratios=aspect_ratio, + flip=self.flip, + clip=self.clip, + steps=[step, step], + offset=self.offset, + min_max_aspect_ratios_order=self.min_max_aspect_ratios_order, + ) + boxes.append(paddle.reshape(box, [-1, 4])) + return boxes + + +class RCNNBox(object): + __shared__ = ["num_classes", "export_onnx"] + + def __init__( + self, + prior_box_var=[10.0, 10.0, 5.0, 5.0], + code_type="decode_center_size", + box_normalized=False, + num_classes=80, + export_onnx=False, + ): + super(RCNNBox, self).__init__() + self.prior_box_var = prior_box_var + self.code_type = code_type + self.box_normalized = box_normalized + self.num_classes = num_classes + self.export_onnx = export_onnx + + def __call__(self, bbox_head_out, rois, im_shape, scale_factor): + bbox_pred = bbox_head_out[0] + cls_prob = bbox_head_out[1] + roi = rois[0] + rois_num = rois[1] + + if self.export_onnx: + onnx_rois_num_per_im = rois_num[0] + origin_shape = paddle.expand(im_shape[0, :], [onnx_rois_num_per_im, 2]) + + else: + origin_shape_list = [] + if isinstance(roi, list): + batch_size = len(roi) + else: + batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1]) + + # bbox_pred.shape: [N, C*4] + for idx in range(batch_size): + rois_num_per_im = rois_num[idx] + expand_im_shape = paddle.expand(im_shape[idx, :], [rois_num_per_im, 2]) + origin_shape_list.append(expand_im_shape) + + origin_shape = paddle.concat(origin_shape_list) + + # bbox_pred.shape: [N, C*4] + # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head) + bbox = paddle.concat(roi) + bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var) + scores = cls_prob[:, :-1] + + # bbox.shape: [N, C, 4] + # bbox.shape[1] must be equal to scores.shape[1] + total_num = bbox.shape[0] + bbox_dim = bbox.shape[-1] + bbox = paddle.expand(bbox, [total_num, self.num_classes, bbox_dim]) + + origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1) + origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1) + zeros = paddle.zeros_like(origin_h) + x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros) + y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros) + x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros) + y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros) + bbox = paddle.stack([x1, y1, x2, y2], axis=-1) + bboxes = (bbox, rois_num) + return bboxes, scores + + +class MultiClassNMS(object): + def __init__( + self, + score_threshold=0.05, + nms_top_k=-1, + keep_top_k=100, + nms_threshold=0.5, + normalized=True, + nms_eta=1.0, + return_index=False, + return_rois_num=True, + trt=False, + cpu=False, + ): + super(MultiClassNMS, self).__init__() + self.score_threshold = score_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.nms_threshold = nms_threshold + self.normalized = normalized + self.nms_eta = nms_eta + self.return_index = return_index + self.return_rois_num = return_rois_num + self.trt = trt + self.cpu = cpu + + def __call__(self, bboxes, score, background_label=-1): + """ + bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape + [N, M, 4], N is the batch size and M + is the number of bboxes + 2. (List[Tensor]) bboxes and bbox_num, + bboxes have shape of [M, C, 4], C + is the class number and bbox_num means + the number of bboxes of each batch with + shape [N,] + score (Tensor): Predicted scores with shape [N, C, M] or [M, C] + background_label (int): Ignore the background label; For example, RCNN + is num_classes and YOLO is -1. 
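+
+        Returns (sketch; the exact tuple is produced by `ops.multiclass_nms`):
+            bbox (Tensor): [M, 6] laid out as [label, score, x1, y1, x2, y2];
+            bbox_num (Tensor): detections per image, when `return_rois_num`
+            is True; and an index tensor when `return_index` is True.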
+ """ + kwargs = self.__dict__.copy() + if isinstance(bboxes, tuple): + bboxes, bbox_num = bboxes + kwargs.update({"rois_num": bbox_num}) + if background_label > -1: + kwargs.update({"background_label": background_label}) + kwargs.pop("trt") + kwargs.pop("cpu") + + # TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt + if self.trt and ( + int(paddle.version.major) == 0 + or (int(paddle.version.major) >= 2 and int(paddle.version.minor) >= 3) + ): + # TODO(wangxinxin08): tricky switch to run nms on tensorrt + kwargs.update({"nms_eta": 1.1}) + bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs) + bbox = bbox.reshape([1, -1, 6]) + idx = paddle.nonzero(bbox[..., 0] != -1) + bbox = paddle.gather_nd(bbox, idx) + return bbox, bbox_num, None + else: + if self.cpu: + device = paddle.device.get_device() + paddle.set_device("cpu") + outputs = ops.multiclass_nms(bboxes, score, **kwargs) + paddle.set_device(device) + return outputs + else: + return ops.multiclass_nms(bboxes, score, **kwargs) + + +class MatrixNMS(object): + __append_doc__ = True + + def __init__( + self, + score_threshold=0.05, + post_threshold=0.05, + nms_top_k=-1, + keep_top_k=100, + use_gaussian=False, + gaussian_sigma=2.0, + normalized=False, + background_label=0, + ): + super(MatrixNMS, self).__init__() + self.score_threshold = score_threshold + self.post_threshold = post_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.normalized = normalized + self.use_gaussian = use_gaussian + self.gaussian_sigma = gaussian_sigma + self.background_label = background_label + + def __call__(self, bbox, score, *args): + return ops.matrix_nms( + bboxes=bbox, + scores=score, + score_threshold=self.score_threshold, + post_threshold=self.post_threshold, + nms_top_k=self.nms_top_k, + keep_top_k=self.keep_top_k, + use_gaussian=self.use_gaussian, + gaussian_sigma=self.gaussian_sigma, + background_label=self.background_label, + normalized=self.normalized, + ) + + +class YOLOBox(object): + __shared__ = ["num_classes"] + + def __init__( + self, + num_classes=80, + conf_thresh=0.005, + downsample_ratio=32, + clip_bbox=True, + scale_x_y=1.0, + ): + self.num_classes = num_classes + self.conf_thresh = conf_thresh + self.downsample_ratio = downsample_ratio + self.clip_bbox = clip_bbox + self.scale_x_y = scale_x_y + + def __call__(self, yolo_head_out, anchors, im_shape, scale_factor, var_weight=None): + boxes_list = [] + scores_list = [] + origin_shape = im_shape / scale_factor + origin_shape = paddle.cast(origin_shape, "int32") + for i, head_out in enumerate(yolo_head_out): + boxes, scores = paddle.vision.ops.yolo_box( + head_out, + origin_shape, + anchors[i], + self.num_classes, + self.conf_thresh, + self.downsample_ratio // 2**i, + self.clip_bbox, + scale_x_y=self.scale_x_y, + ) + boxes_list.append(boxes) + scores_list.append(paddle.transpose(scores, perm=[0, 2, 1])) + yolo_boxes = paddle.concat(boxes_list, axis=1) + yolo_scores = paddle.concat(scores_list, axis=2) + return yolo_boxes, yolo_scores + + +class SSDBox(object): + def __init__( + self, + is_normalized=True, + prior_box_var=[0.1, 0.1, 0.2, 0.2], + use_fuse_decode=False, + ): + self.is_normalized = is_normalized + self.norm_delta = float(not self.is_normalized) + self.prior_box_var = prior_box_var + self.use_fuse_decode = use_fuse_decode + + def __call__(self, preds, prior_boxes, im_shape, scale_factor, var_weight=None): + boxes, scores = preds + boxes = paddle.concat(boxes, axis=1) + prior_boxes = 
paddle.concat(prior_boxes) + if self.use_fuse_decode: + output_boxes = ops.box_coder( + prior_boxes, + self.prior_box_var, + boxes, + code_type="decode_center_size", + box_normalized=self.is_normalized, + ) + else: + pb_w = prior_boxes[:, 2] - prior_boxes[:, 0] + self.norm_delta + pb_h = prior_boxes[:, 3] - prior_boxes[:, 1] + self.norm_delta + pb_x = prior_boxes[:, 0] + pb_w * 0.5 + pb_y = prior_boxes[:, 1] + pb_h * 0.5 + out_x = pb_x + boxes[:, :, 0] * pb_w * self.prior_box_var[0] + out_y = pb_y + boxes[:, :, 1] * pb_h * self.prior_box_var[1] + out_w = paddle.exp(boxes[:, :, 2] * self.prior_box_var[2]) * pb_w + out_h = paddle.exp(boxes[:, :, 3] * self.prior_box_var[3]) * pb_h + output_boxes = paddle.stack( + [ + out_x - out_w / 2.0, + out_y - out_h / 2.0, + out_x + out_w / 2.0, + out_y + out_h / 2.0, + ], + axis=-1, + ) + + if self.is_normalized: + h = (im_shape[:, 0] / scale_factor[:, 0]).unsqueeze(-1) + w = (im_shape[:, 1] / scale_factor[:, 1]).unsqueeze(-1) + im_shape = paddle.stack([w, h, w, h], axis=-1) + output_boxes *= im_shape + else: + output_boxes[..., -2:] -= 1.0 + output_scores = F.softmax(paddle.concat(scores, axis=1)).transpose([0, 2, 1]) + + return output_boxes, output_scores + + +class TTFBox(object): + __shared__ = ["down_ratio"] + + def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4): + super(TTFBox, self).__init__() + self.max_per_img = max_per_img + self.score_thresh = score_thresh + self.down_ratio = down_ratio + + def _simple_nms(self, heat, kernel=3): + """ + Use maxpool to filter the max score, get local peaks. + """ + pad = (kernel - 1) // 2 + hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad) + keep = paddle.cast(hmax == heat, "float32") + return heat * keep + + def _topk(self, scores): + """ + Select top k scores and decode to get xy coordinates. 
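+
+        Selection is two-stage, as implemented below: top-k is taken per
+        category first, then a second top-k over the flattened
+        (category, location) scores keeps the final k = max_per_img peaks.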
+ """ + k = self.max_per_img + shape_fm = paddle.shape(scores) + shape_fm.stop_gradient = True + cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3] + # batch size is 1 + scores_r = paddle.reshape(scores, [cat, -1]) + topk_scores, topk_inds = paddle.topk(scores_r, k) + topk_ys = topk_inds // width + topk_xs = topk_inds % width + + topk_score_r = paddle.reshape(topk_scores, [-1]) + topk_score, topk_ind = paddle.topk(topk_score_r, k) + k_t = paddle.full(topk_ind.shape, k, dtype="int64") + topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), "float32") + + topk_inds = paddle.reshape(topk_inds, [-1]) + topk_ys = paddle.reshape(topk_ys, [-1, 1]) + topk_xs = paddle.reshape(topk_xs, [-1, 1]) + topk_inds = paddle.gather(topk_inds, topk_ind) + topk_ys = paddle.gather(topk_ys, topk_ind) + topk_xs = paddle.gather(topk_xs, topk_ind) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + def _decode(self, hm, wh, im_shape, scale_factor): + heatmap = F.sigmoid(hm) + heat = self._simple_nms(heatmap) + scores, inds, clses, ys, xs = self._topk(heat) + ys = paddle.cast(ys, "float32") * self.down_ratio + xs = paddle.cast(xs, "float32") * self.down_ratio + scores = paddle.tensor.unsqueeze(scores, [1]) + clses = paddle.tensor.unsqueeze(clses, [1]) + + wh_t = paddle.transpose(wh, [0, 2, 3, 1]) + wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]]) + wh = paddle.gather(wh, inds) + + x1 = xs - wh[:, 0:1] + y1 = ys - wh[:, 1:2] + x2 = xs + wh[:, 2:3] + y2 = ys + wh[:, 3:4] + + bboxes = paddle.concat([x1, y1, x2, y2], axis=1) + + scale_y = scale_factor[:, 0:1] + scale_x = scale_factor[:, 1:2] + scale_expand = paddle.concat([scale_x, scale_y, scale_x, scale_y], axis=1) + boxes_shape = paddle.shape(bboxes) + boxes_shape.stop_gradient = True + scale_expand = paddle.expand(scale_expand, shape=boxes_shape) + bboxes = paddle.divide(bboxes, scale_expand) + results = paddle.concat([clses, scores, bboxes], axis=1) + # hack: append result with cls=-1 and score=1. to avoid all scores + # are less than score_thresh which may cause error in gather. 
+ fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]])) + fill_r = paddle.cast(fill_r, results.dtype) + results = paddle.concat([results, fill_r]) + scores = results[:, 1] + valid_ind = paddle.nonzero(scores > self.score_thresh) + results = paddle.gather(results, valid_ind) + return results, results.shape[0:1] + + def __call__(self, hm, wh, im_shape, scale_factor): + results = [] + results_num = [] + for i in range(scale_factor.shape[0]): + result, num = self._decode( + hm[i : i + 1,], + wh[i : i + 1,], + im_shape[i : i + 1,], + scale_factor[i : i + 1,], + ) + results.append(result) + results_num.append(num) + results = paddle.concat(results, axis=0) + results_num = paddle.concat(results_num, axis=0) + return results, results_num + + +class JDEBox(object): + __shared__ = ["num_classes"] + + def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32): + self.num_classes = num_classes + self.conf_thresh = conf_thresh + self.downsample_ratio = downsample_ratio + + def generate_anchor(self, nGh, nGw, anchor_wh): + nA = len(anchor_wh) + yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)]) + mesh = paddle.stack((xv, yv), axis=0).cast(dtype="float32") # 2 x nGh x nGw + meshs = paddle.tile(mesh, [nA, 1, 1, 1]) + + anchor_offset_mesh = ( + anchor_wh[:, :, None][:, :, :, None] + .repeat(int(nGh), axis=-2) + .repeat(int(nGw), axis=-1) + ) + anchor_offset_mesh = paddle.to_tensor(anchor_offset_mesh.astype(np.float32)) + # nA x 2 x nGh x nGw + + anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1) + anchor_mesh = paddle.transpose( + anchor_mesh, [0, 2, 3, 1] + ) # (nA x nGh x nGw) x 4 + return anchor_mesh + + def decode_delta(self, delta, fg_anchor_list): + px, py, pw, ph = ( + fg_anchor_list[:, 0], + fg_anchor_list[:, 1], + fg_anchor_list[:, 2], + fg_anchor_list[:, 3], + ) + dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3] + gx = pw * dx + px + gy = ph * dy + py + gw = pw * paddle.exp(dw) + gh = ph * paddle.exp(dh) + gx1 = gx - gw * 0.5 + gy1 = gy - gh * 0.5 + gx2 = gx + gw * 0.5 + gy2 = gy + gh * 0.5 + return paddle.stack([gx1, gy1, gx2, gy2], axis=1) + + def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec): + anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec) + anchor_mesh = paddle.unsqueeze(anchor_mesh, 0) + pred_list = self.decode_delta( + paddle.reshape(delta_map, shape=[-1, 4]), + paddle.reshape(anchor_mesh, shape=[-1, 4]), + ) + pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4]) + return pred_map + + def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec): + boxes_shape = head_out.shape # [nB, nA*6, nGh, nGw] + nGh, nGw = boxes_shape[-2], boxes_shape[-1] + nB = 1 # TODO: only support bs=1 now + boxes_list, scores_list = [], [] + for idx in range(nB): + p = paddle.reshape( + head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw] + ) + p = paddle.transpose(p, perm=[0, 2, 3, 1]) # [nA, nGh, nGw, 6] + delta_map = p[:, :, :, :4] + boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec) + # [nA * nGh * nGw, 4] + boxes_list.append(boxes * stride) + + p_conf = paddle.transpose( + p[:, :, :, 4:6], perm=[3, 0, 1, 2] + ) # [2, nA, nGh, nGw] + p_conf = F.softmax(p_conf, axis=0)[1, :, :, :].unsqueeze( + -1 + ) # [nA, nGh, nGw, 1] + scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1]) + scores_list.append(scores) + + boxes_results = paddle.stack(boxes_list) + scores_results = paddle.stack(scores_list) + return boxes_results, scores_results + + def __call__(self, yolo_head_out, anchors): 
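+        # Decode level by level: with downsample_ratio=32, level i uses
+        # stride 32 // 2**i (e.g. 32, 16, 8); anchors are rescaled into grid
+        # units before the deltas are decoded, and boxes are mapped back to
+        # pixels via `boxes * stride` below.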
+        bbox_pred_list = []
+        for i, head_out in enumerate(yolo_head_out):
+            stride = self.downsample_ratio // 2**i
+            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
+            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
+            nA = len(anc_w)
+            boxes, scores = self._postprocessing_by_level(
+                nA, stride, head_out, anchor_vec
+            )
+            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))
+
+        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
+        boxes_idx_over_conf_thr = paddle.nonzero(
+            yolo_boxes_scores[:, :, -1] > self.conf_thresh
+        )
+        boxes_idx_over_conf_thr.stop_gradient = True
+
+        return boxes_idx_over_conf_thr, yolo_boxes_scores
+
+
+class MaskMatrixNMS(object):
+    """
+    Matrix NMS for multi-class masks.
+    Args:
+        update_threshold (float): Updated threshold of category score in the second round.
+        pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
+        post_nms_top_n (int): Number of total instances to be kept per image after NMS.
+        kernel (str): 'linear' or 'gaussian'.
+        sigma (float): std in gaussian method.
+    Input:
+        seg_preds (Variable): shape (n, h, w), segmentation feature maps
+        seg_masks (Variable): shape (n, h, w), segmentation feature maps
+        cate_labels (Variable): shape (n), mask labels in descending order
+        cate_scores (Variable): shape (n), mask scores in descending order
+        sum_masks (Variable): a float tensor of the sum of seg_masks
+    Returns:
+        Variable: cate_scores, tensors of shape (n)
+    """
+
+    def __init__(
+        self,
+        update_threshold=0.05,
+        pre_nms_top_n=500,
+        post_nms_top_n=100,
+        kernel="gaussian",
+        sigma=2.0,
+    ):
+        super(MaskMatrixNMS, self).__init__()
+        self.update_threshold = update_threshold
+        self.pre_nms_top_n = pre_nms_top_n
+        self.post_nms_top_n = post_nms_top_n
+        self.kernel = kernel
+        self.sigma = sigma
+
+    def _sort_score(self, scores, top_num):
+        if scores.shape[0] > top_num:
+            return paddle.topk(scores, top_num)[1]
+        else:
+            return paddle.argsort(scores, descending=True)
+
+    def __call__(self, seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=None):
+        # sort and keep top nms_pre
+        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
+        seg_masks = paddle.gather(seg_masks, index=sort_inds)
+        seg_preds = paddle.gather(seg_preds, index=sort_inds)
+        sum_masks = paddle.gather(sum_masks, index=sort_inds)
+        cate_scores = paddle.gather(cate_scores, index=sort_inds)
+        cate_labels = paddle.gather(cate_labels, index=sort_inds)
+
+        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
+        # inter.
+        inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0]))
+        n_samples = cate_labels.shape
+        n_samples = paddle.to_tensor(n_samples, dtype="int32")
+        # union.
+        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
+        # iou.
+        iou_matrix = inter_matrix / (
+            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix
+        )
+        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
+        # label_specific matrix.
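+        # Matrix NMS (SOLOv2): each candidate i is decayed by
+        # min_j f(iou_ij) / f(comp_j) over higher-scoring masks j of the same
+        # class, where comp_j = max_k iou_kj and f is either
+        # exp(-sigma * iou^2) (gaussian) or (1 - iou) (linear); the
+        # expand/triu blocks here build the pairwise IoU and same-label
+        # masks that feed that computation.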
+ cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples]) + label_matrix = paddle.cast( + (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])), "float32" + ) + label_matrix = paddle.triu(label_matrix, diagonal=1) + + # IoU compensation + compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0) + compensate_iou = paddle.expand(compensate_iou, shape=[n_samples, n_samples]) + compensate_iou = paddle.transpose(compensate_iou, [1, 0]) + + # IoU decay + decay_iou = iou_matrix * label_matrix + + # matrix nms + if self.kernel == "gaussian": + decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2)) + compensate_matrix = paddle.exp(-1 * self.sigma * (compensate_iou**2)) + decay_coefficient = paddle.min(decay_matrix / compensate_matrix, axis=0) + elif self.kernel == "linear": + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient = paddle.min(decay_matrix, axis=0) + else: + raise NotImplementedError + + # update the score. + cate_scores = cate_scores * decay_coefficient + y = paddle.zeros(shape=cate_scores.shape, dtype="float32") + keep = paddle.where(cate_scores >= self.update_threshold, cate_scores, y) + keep = paddle.nonzero(keep) + keep = paddle.squeeze(keep, axis=[1]) + # Prevent empty and increase fake data + keep = paddle.concat( + [keep, paddle.cast(paddle.shape(cate_scores)[0:1] - 1, "int64")] + ) + + seg_preds = paddle.gather(seg_preds, index=keep) + cate_scores = paddle.gather(cate_scores, index=keep) + cate_labels = paddle.gather(cate_labels, index=keep) + + # sort and keep top_k + sort_inds = self._sort_score(cate_scores, self.post_nms_top_n) + seg_preds = paddle.gather(seg_preds, index=sort_inds) + cate_scores = paddle.gather(cate_scores, index=sort_inds) + cate_labels = paddle.gather(cate_labels, index=sort_inds) + return seg_preds, cate_scores, cate_labels + + +def Conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + weight_init=Normal(std=0.001), + bias_init=Constant(0.0), +): + weight_attr = paddle.framework.ParamAttr(initializer=weight_init) + if bias: + bias_attr = paddle.framework.ParamAttr(initializer=bias_init) + else: + bias_attr = False + conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + return conv + + +def ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + weight_init=Normal(std=0.001), + bias_init=Constant(0.0), +): + weight_attr = paddle.framework.ParamAttr(initializer=weight_init) + if bias: + bias_attr = paddle.framework.ParamAttr(initializer=bias_init) + else: + bias_attr = False + conv = nn.Conv2DTranspose( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + return conv + + +def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True): + if not affine: + weight_attr = False + bias_attr = False + else: + weight_attr = None + bias_attr = None + batchnorm = nn.BatchNorm2D( + num_features, momentum, eps, weight_attr=weight_attr, bias_attr=bias_attr + ) + return batchnorm + + +def ReLU(): + return nn.ReLU() + + +def Upsample(scale_factor=None, mode="nearest", align_corners=False): + return nn.Upsample(None, scale_factor, mode, align_corners) + + +def MaxPool(kernel_size, stride, padding, ceil_mode=False): + return 
nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode) + + +class Concat(nn.Layer): + def __init__(self, dim=0): + super(Concat, self).__init__() + self.dim = dim + + def forward(self, inputs): + return paddle.concat(inputs, axis=self.dim) + + def extra_repr(self): + return "dim={}".format(self.dim) + + +def _convert_attention_mask(attn_mask, dtype): + """ + Convert the attention mask to the target dtype we expect. + Parameters: + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. + dtype (VarType): The target type of `attn_mask` we expect. + Returns: + Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`. + """ + return nn.layer.transformer._convert_attention_mask(attn_mask, dtype) + + +class MultiHeadAttention(nn.Layer): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. + + Please refer to `Attention Is All You Need `_ + for more details. + + Parameters: + embed_dim (int): The expected feature size in the input and output. + num_heads (int): The number of heads in multi-head attention. + dropout (float, optional): The dropout probability used on attention + weights to drop some attention targets. 0 for no dropout. Default 0 + kdim (int, optional): The feature size in key. If None, assumed equal to + `embed_dim`. Default None. + vdim (int, optional): The feature size in value. If None, assumed equal to + `embed_dim`. Default None. + need_weights (bool, optional): Indicate whether to return the attention + weights. Default False. + + Examples: + + .. 
code-block:: python + + import paddle + + # encoder input: [batch_size, sequence_length, d_model] + query = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, num_heads, query_len, query_len] + attn_mask = paddle.rand((2, 2, 4, 4)) + multi_head_attn = paddle.nn.MultiHeadAttention(128, 2) + output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] + """ + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + kdim=None, + vdim=None, + need_weights=False, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.need_weights = need_weights + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim: + self.in_proj_weight = self.create_parameter( + shape=[embed_dim, 3 * embed_dim], + attr=None, + dtype=self._dtype, + is_bias=False, + ) + self.in_proj_bias = self.create_parameter( + shape=[3 * embed_dim], attr=None, dtype=self._dtype, is_bias=True + ) + else: + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(self.kdim, embed_dim) + self.v_proj = nn.Linear(self.vdim, embed_dim) + + self.out_proj = nn.Linear(embed_dim, embed_dim) + self._type_list = ("q_proj", "k_proj", "v_proj") + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + else: + constant_(p) + + def compute_qkv(self, tensor, index): + if self._qkv_same_embed_dim: + tensor = F.linear( + x=tensor, + weight=self.in_proj_weight[ + :, index * self.embed_dim : (index + 1) * self.embed_dim + ], + bias=( + self.in_proj_bias[ + index * self.embed_dim : (index + 1) * self.embed_dim + ] + if self.in_proj_bias is not None + else None + ), + ) + else: + tensor = getattr(self, self._type_list[index])(tensor) + tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + return tensor + + def forward(self, query, key=None, value=None, attn_mask=None): + r""" + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. + + Parameters: + query (Tensor): The queries for multi-head attention. It is a + tensor with shape `[batch_size, query_length, embed_dim]`. The + data type should be float32 or float64. + key (Tensor, optional): The keys for multi-head attention. It is + a tensor with shape `[batch_size, key_length, kdim]`. The + data type should be float32 or float64. If None, use `query` as + `key`. Default None. + value (Tensor, optional): The values for multi-head attention. It + is a tensor with shape `[batch_size, value_length, vdim]`. + The data type should be float32 or float64. If None, use `query` as + `value`. Default None. + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. 
When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. + + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `query`, representing attention output. Or a tuple if \ + `need_weights` is True or `cache` is not None. If `need_weights` \ + is True, except for attention output, the tuple also includes \ + the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \ + If `cache` is not None, the tuple then includes the new cache \ + having the same type as `cache`, and if it is `StaticCache`, it \ + is same as the input `cache`, if it is `Cache`, the new cache \ + reserves tensors concatanating raw tensors with intermediate \ + results of current query. + """ + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + q, k, v = (self.compute_qkv(t, i) for i, t in enumerate([query, key, value])) + + # scale dot product attention + product = paddle.matmul(x=q, y=k, transpose_y=True) + scaling = float(self.head_dim) ** -0.5 + product = product * scaling + + if attn_mask is not None: + # Support bool or int mask + attn_mask = _convert_attention_mask(attn_mask, product.dtype) + product = product + attn_mask + weights = F.softmax(product) + if self.dropout: + weights = F.dropout( + weights, self.dropout, training=self.training, mode="upscale_in_train" + ) + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + outs = [out] + if self.need_weights: + outs.append(weights) + return out if len(outs) == 1 else tuple(outs) + + +class ConvMixer(nn.Layer): + def __init__( + self, + dim, + depth, + kernel_size=3, + ): + super().__init__() + self.dim = dim + self.depth = depth + self.kernel_size = kernel_size + + self.mixer = self.conv_mixer(dim, depth, kernel_size) + + def forward(self, x): + return self.mixer(x) + + @staticmethod + def conv_mixer( + dim, + depth, + kernel_size, + ): + Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim)) + Residual = type("Residual", (Seq,), {"forward": lambda self, x: self[0](x) + x}) + return Seq( + *[ + Seq( + Residual( + ActBn( + nn.Conv2D(dim, dim, kernel_size, groups=dim, padding="same") + ) + ), + ActBn(nn.Conv2D(dim, dim, 1)), + ) + for i in range(depth) + ] + ) diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/matchers.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/matchers.py new file mode 100644 index 0000000000..7ded61b064 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/matchers.py @@ -0,0 +1,275 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from .utils import bbox_cxcywh_to_xyxy + +__all__ = ["HungarianMatcher"] + + +class GIoULoss(object): + """ + Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630 + Args: + loss_weight (float): giou loss weight, default as 1 + eps (float): epsilon to avoid divide by zero, default as 1e-10 + reduction (string): Options are "none", "mean" and "sum". default as none + """ + + def __init__(self, loss_weight=1.0, eps=1e-10, reduction="none"): + self.loss_weight = loss_weight + self.eps = eps + assert reduction in ("none", "mean", "sum") + self.reduction = reduction + + def bbox_overlap(self, box1, box2, eps=1e-10): + """calculate the iou of box1 and box2 + Args: + box1 (Tensor): box1 with the shape (..., 4) + box2 (Tensor): box1 with the shape (..., 4) + eps (float): epsilon to avoid divide by zero + Return: + iou (Tensor): iou of box1 and box2 + overlap (Tensor): overlap of box1 and box2 + union (Tensor): union of box1 and box2 + """ + x1, y1, x2, y2 = box1 + x1g, y1g, x2g, y2g = box2 + + xkis1 = paddle.maximum(x1, x1g) + ykis1 = paddle.maximum(y1, y1g) + xkis2 = paddle.minimum(x2, x2g) + ykis2 = paddle.minimum(y2, y2g) + w_inter = (xkis2 - xkis1).clip(0) + h_inter = (ykis2 - ykis1).clip(0) + overlap = w_inter * h_inter + + area1 = (x2 - x1) * (y2 - y1) + area2 = (x2g - x1g) * (y2g - y1g) + union = area1 + area2 - overlap + eps + iou = overlap / union + + return iou, overlap, union + + def __call__(self, pbox, gbox, iou_weight=1.0, loc_reweight=None): + x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1) + x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1) + box1 = [x1, y1, x2, y2] + box2 = [x1g, y1g, x2g, y2g] + iou, overlap, union = self.bbox_overlap(box1, box2, self.eps) + xc1 = paddle.minimum(x1, x1g) + yc1 = paddle.minimum(y1, y1g) + xc2 = paddle.maximum(x2, x2g) + yc2 = paddle.maximum(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps + miou = iou - ((area_c - union) / area_c) + if loc_reweight is not None: + loc_reweight = paddle.reshape(loc_reweight, shape=(-1, 1)) + loc_thresh = 0.9 + giou = 1 - (1 - loc_thresh) * miou - loc_thresh * miou * loc_reweight + else: + giou = 1 - miou + if self.reduction == "none": + loss = giou + elif self.reduction == "sum": + loss = paddle.sum(giou * iou_weight) + else: + loss = paddle.mean(giou * iou_weight) + return loss * self.loss_weight + + +class HungarianMatcher(nn.Layer): + __shared__ = ["use_focal_loss", "with_mask", "num_sample_points"] + + def __init__( + self, + matcher_coeff={"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}, + use_focal_loss=False, + with_mask=False, + num_sample_points=12544, + alpha=0.25, + gamma=2.0, + ): + r""" + Args: + matcher_coeff (dict): The coefficient of hungarian matcher cost. 
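+            use_focal_loss (bool): Whether the classification cost is computed
+                in focal-loss style. Default False.
+            with_mask (bool): Whether mask and dice costs are added to the
+                cost matrix. Default False.
+            num_sample_points (int): Number of points sampled per mask when
+                computing the mask cost. Default 12544.
+            alpha (float): Alpha used in the focal classification cost.
+                Default 0.25.
+            gamma (float): Gamma used in the focal classification cost.
+                Default 2.0.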
+ """ + super(HungarianMatcher, self).__init__() + self.matcher_coeff = matcher_coeff + self.use_focal_loss = use_focal_loss + self.with_mask = with_mask + self.num_sample_points = num_sample_points + self.alpha = alpha + self.gamma = gamma + + self.giou_loss = GIoULoss() + + def forward(self, boxes, logits, gt_bbox, gt_class, masks=None, gt_mask=None): + r""" + Args: + boxes (Tensor): [b, query, 4] + logits (Tensor): [b, query, num_classes] + gt_bbox (List(Tensor)): list[[n, 4]] + gt_class (List(Tensor)): list[[n, 1]] + masks (Tensor|None): [b, query, h, w] + gt_mask (List(Tensor)): list[[n, H, W]] + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = boxes.shape[:2] + + num_gts = [len(a) for a in gt_class] + if sum(num_gts) == 0: + return [ + ( + paddle.to_tensor([], dtype=paddle.int64), + paddle.to_tensor([], dtype=paddle.int64), + ) + for _ in range(bs) + ] + + # We flatten to compute the cost matrices in a batch + # [batch_size * num_queries, num_classes] + logits = logits.detach() + out_prob = ( + F.sigmoid(logits.flatten(0, 1)) + if self.use_focal_loss + else F.softmax(logits.flatten(0, 1)) + ) + # [batch_size * num_queries, 4] + out_bbox = boxes.detach().flatten(0, 1) + + # Also concat the target labels and boxes + if "npu" in paddle.device.get_device(): + gt_class = [tensor.to(paddle.int32) for tensor in gt_class] + + tgt_ids = paddle.concat(gt_class).flatten() + tgt_bbox = paddle.concat(gt_bbox) + + # Compute the classification cost + out_prob = paddle.gather(out_prob, tgt_ids, axis=1) + if self.use_focal_loss: + neg_cost_class = ( + (1 - self.alpha) + * (out_prob**self.gamma) + * (-(1 - out_prob + 1e-8).log()) + ) + pos_cost_class = ( + self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log()) + ) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -out_prob + + # Compute the L1 cost between boxes + cost_bbox = (out_bbox.unsqueeze(1) - tgt_bbox.unsqueeze(0)).abs().sum(-1) + + # Compute the giou cost betwen boxes + giou_loss = self.giou_loss( + bbox_cxcywh_to_xyxy(out_bbox.unsqueeze(1)), + bbox_cxcywh_to_xyxy(tgt_bbox.unsqueeze(0)), + ).squeeze(-1) + cost_giou = giou_loss - 1 + + # Final cost matrix + C = ( + self.matcher_coeff["class"] * cost_class + + self.matcher_coeff["bbox"] * cost_bbox + + self.matcher_coeff["giou"] * cost_giou + ) + # Compute the mask cost and dice cost + if self.with_mask: + assert ( + masks is not None and gt_mask is not None, + "Make sure the input has `mask` and `gt_mask`", + ) + # all masks share the same set of points for efficient matching + sample_points = paddle.rand([bs, 1, self.num_sample_points, 2]) + sample_points = 2.0 * sample_points - 1.0 + + out_mask = F.grid_sample( + masks.detach(), sample_points, align_corners=False + ).squeeze(-2) + out_mask = out_mask.flatten(0, 1) + + tgt_mask = paddle.concat(gt_mask).unsqueeze(1) + sample_points = paddle.concat( + [a.tile([b, 1, 1, 1]) for a, b in zip(sample_points, num_gts) if b > 0] + ) + tgt_mask = F.grid_sample( + tgt_mask, sample_points, align_corners=False + ).squeeze([1, 2]) + + with paddle.amp.auto_cast(enable=False): + # binary cross entropy cost + pos_cost_mask = F.binary_cross_entropy_with_logits( + out_mask, paddle.ones_like(out_mask), reduction="none" + 
) + neg_cost_mask = F.binary_cross_entropy_with_logits( + out_mask, paddle.zeros_like(out_mask), reduction="none" + ) + cost_mask = paddle.matmul( + pos_cost_mask, tgt_mask, transpose_y=True + ) + paddle.matmul(neg_cost_mask, 1 - tgt_mask, transpose_y=True) + cost_mask /= self.num_sample_points + + # dice cost + out_mask = F.sigmoid(out_mask) + numerator = 2 * paddle.matmul(out_mask, tgt_mask, transpose_y=True) + denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum( + -1 + ).unsqueeze(0) + cost_dice = 1 - (numerator + 1) / (denominator + 1) + + C = ( + C + + self.matcher_coeff["mask"] * cost_mask + + self.matcher_coeff["dice"] * cost_dice + ) + + C = C.reshape([bs, num_queries, -1]) + C = [a.squeeze(0) for a in C.chunk(bs)] + sizes = [a.shape[0] for a in gt_bbox] + if hasattr(paddle.Tensor, "contiguous"): + indices = [ + linear_sum_assignment(c.split(sizes, -1)[i].contiguous().numpy()) + for i, c in enumerate(C) + ] + else: + indices = [ + linear_sum_assignment(c.split(sizes, -1)[i].numpy()) + for i, c in enumerate(C) + ] + return [ + ( + paddle.to_tensor(i, dtype=paddle.int64), + paddle.to_tensor(j, dtype=paddle.int64), + ) + for i, j in indices + ] diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/ops.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/ops.py new file mode 100644 index 0000000000..c58834943c --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/ops.py @@ -0,0 +1,1193 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr, in_dynamic_mode
+from paddle.common_ops_import import (
+    LayerHelper,
+    Variable,
+    check_type,
+    check_variable_and_dtype,
+)
+from paddle.regularizer import L2Decay
+
+try:
+    import paddle._legacy_C_ops as C_ops
+except:
+    import paddle._C_ops as C_ops
+
+try:
+    from paddle.framework import in_dynamic_or_pir_mode
+
+    HAVE_PIR = True
+except:
+    HAVE_PIR = False
+
+
+__all__ = [
+    "prior_box",
+    "generate_proposals",
+    "box_coder",
+    "multiclass_nms",
+    "distribute_fpn_proposals",
+    "matrix_nms",
+    "batch_norm",
+    "mish",
+    "silu",
+    "swish",
+    "identity",
+    "anchor_generator",
+]
+
+
+def identity(x):
+    return x
+
+
+def mish(x):
+    return F.mish(x) if hasattr(F, "mish") else x * F.tanh(F.softplus(x))
+
+
+def silu(x):
+    return F.silu(x)
+
+
+def swish(x):
+    return x * F.sigmoid(x)
+
+
+TRT_ACT_SPEC = {"swish": swish, "silu": swish}
+
+ACT_SPEC = {"mish": mish, "silu": silu}
+
+
+def get_act_fn(act=None, trt=False):
+    assert act is None or isinstance(
+        act, (str, dict)
+    ), "name of activation should be str, dict or None"
+    if not act:
+        return identity
+
+    if isinstance(act, dict):
+        name = act["name"]
+        act.pop("name")
+        kwargs = act
+    else:
+        name = act
+        kwargs = dict()
+
+    if trt and name in TRT_ACT_SPEC:
+        fn = TRT_ACT_SPEC[name]
+    elif name in ACT_SPEC:
+        fn = ACT_SPEC[name]
+    else:
+        fn = getattr(F, name)
+
+    return lambda x: fn(x, **kwargs)
+
+
+def batch_norm(
+    ch,
+    norm_type="bn",
+    norm_decay=0.0,
+    freeze_norm=False,
+    initializer=None,
+    data_format="NCHW",
+):
+
+    norm_lr = 0.0 if freeze_norm else 1.0
+    weight_attr = ParamAttr(
+        initializer=initializer,
+        learning_rate=norm_lr,
+        regularizer=L2Decay(norm_decay),
+        trainable=False if freeze_norm else True,
+    )
+    bias_attr = ParamAttr(
+        learning_rate=norm_lr,
+        regularizer=L2Decay(norm_decay),
+        trainable=False if freeze_norm else True,
+    )
+
+    if norm_type in ["sync_bn", "bn"]:
+        norm_layer = nn.BatchNorm2D(
+            ch, weight_attr=weight_attr, bias_attr=bias_attr, data_format=data_format
+        )
+
+    norm_params = norm_layer.parameters()
+    if freeze_norm:
+        for param in norm_params:
+            param.stop_gradient = True
+
+    return norm_layer
+
+
+@paddle.jit.not_to_static
+def anchor_generator(
+    input,
+    anchor_sizes=None,
+    aspect_ratios=None,
+    variance=[0.1, 0.1, 0.2, 0.2],
+    stride=None,
+    offset=0.5,
+):
+    """
+    **Anchor generator operator**
+    Generate anchors for the Faster R-CNN algorithm.
+    Each position of the input produces N anchors, N =
+    size(anchor_sizes) * size(aspect_ratios). Generated anchors are ordered
+    by looping over aspect_ratios first, then anchor_sizes.
+    Args:
+        input(Variable): 4-D Tensor with shape [N,C,H,W]. The input feature map.
+        anchor_sizes(float32|list|tuple, optional): The anchor sizes of generated
+            anchors, given in absolute pixels e.g. [64., 128., 256., 512.].
+            For instance, the anchor size of 64 means the area of this anchor
+            equals to 64**2. None by default.
+        aspect_ratios(float32|list|tuple, optional): The height / width ratios
+            of generated anchors, e.g. [0.5, 1.0, 2.0]. None by default.
+        variance(list|tuple, optional): The variances to be used in box
+            regression deltas. The data type is float32, [0.1, 0.1, 0.2, 0.2] by
+            default.
+        stride(list|tuple, optional): The anchors stride across width and height.
+            The data type is float32. e.g. [16.0, 16.0]. None by default.
+        offset(float32, optional): Prior boxes center offset. 0.5 by default.
+ Returns: + Tuple: + Anchors(Variable): The output anchors with a layout of [H, W, num_anchors, 4]. + H is the height of input, W is the width of input, + num_anchors is the box count of each position. + Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized. + + Variances(Variable): The expanded variances of anchors + with a layout of [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_anchors is the box count of each position. + Each variance is in (xcenter, ycenter, w, h) format. + Examples: + .. code-block:: python + import paddle.fluid as fluid + conv1 = fluid.data(name='conv1', shape=[None, 48, 16, 16], dtype='float32') + anchor, var = fluid.layers.anchor_generator( + input=conv1, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + """ + + def _is_list_or_tuple_(data): + return isinstance(data, list) or isinstance(data, tuple) + + if not _is_list_or_tuple_(anchor_sizes): + anchor_sizes = [anchor_sizes] + if not _is_list_or_tuple_(aspect_ratios): + aspect_ratios = [aspect_ratios] + if not (_is_list_or_tuple_(stride) and len(stride) == 2): + raise ValueError( + "stride should be a list or tuple ", + "with length 2, (stride_width, stride_height).", + ) + + anchor_sizes = list(map(float, anchor_sizes)) + aspect_ratios = list(map(float, aspect_ratios)) + stride = list(map(float, stride)) + + if in_dynamic_mode(): + attrs = ( + "anchor_sizes", + anchor_sizes, + "aspect_ratios", + aspect_ratios, + "variances", + variance, + "stride", + stride, + "offset", + offset, + ) + anchor, var = C_ops.anchor_generator(input, *attrs) + return anchor, var + + helper = LayerHelper("anchor_generator", **locals()) + dtype = helper.input_dtype() + attrs = { + "anchor_sizes": anchor_sizes, + "aspect_ratios": aspect_ratios, + "variances": variance, + "stride": stride, + "offset": offset, + } + + anchor = helper.create_variable_for_type_inference(dtype) + var = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="anchor_generator", + inputs={"Input": input}, + outputs={"Anchors": anchor, "Variances": var}, + attrs=attrs, + ) + anchor.stop_gradient = True + var.stop_gradient = True + return anchor, var + + +@paddle.jit.not_to_static +def distribute_fpn_proposals( + fpn_rois, + min_level, + max_level, + refer_level, + refer_scale, + pixel_offset=False, + rois_num=None, + name=None, +): + r""" + + **This op only takes LoDTensor as input.** In Feature Pyramid Networks + (FPN) models, it is needed to distribute all proposals into different FPN + level, with respect to scale of the proposals, the referring scale and the + referring level. Besides, to restore the order of proposals, we return an + array which indicates the original index of rois in current proposals. + To compute FPN level for each roi, the formula is given as follows: + + .. math:: + + roi\_scale &= \sqrt{BBoxArea(fpn\_roi)} + + level = floor(&\log(\\frac{roi\_scale}{refer\_scale}) + refer\_level) + + where BBoxArea is a function to compute the area of each roi. + + Args: + + fpn_rois(Variable): 2-D Tensor with shape [N, 4] and data type is + float32 or float64. The input fpn_rois. + min_level(int32): The lowest level of FPN layer where the proposals come + from. + max_level(int32): The highest level of FPN layer where the proposals + come from. + refer_level(int32): The referring level of FPN layer with specified scale. + refer_scale(int32): The referring scale of FPN layer with specified level. 
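+        pixel_offset(bool, optional): Whether there is a pixel offset. If True,
+            an offset of 1 is used when computing the RoI scale. False by
+            default.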
+ rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image. + The shape is [B] and data type is int32. B is the number of images. + If it is not None then return a list of 1-D Tensor. Each element + is the output RoIs' number of each image on the corresponding level + and the shape is [B]. None by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tuple: + + multi_rois(List) : A list of 2-D LoDTensor with shape [M, 4] + and data type of float32 and float64. The length is + max_level-min_level+1. The proposals in each FPN level. + + restore_ind(Variable): A 2-D Tensor with shape [N, 1], N is + the number of total rois. The data type is int32. It is + used to restore the order of fpn_rois. + + rois_num_per_level(List): A list of 1-D Tensor and each Tensor is + the RoIs' number in each image on the corresponding level. The shape + is [B] and data type of int32. B is the number of images + + + Examples: + .. code-block:: python + + import paddle + from ppdet.modeling import ops + paddle.enable_static() + fpn_rois = paddle.static.data( + name='data', shape=[None, 4], dtype='float32', lod_level=1) + multi_rois, restore_ind = ops.distribute_fpn_proposals( + fpn_rois=fpn_rois, + min_level=2, + max_level=5, + refer_level=4, + refer_scale=224) + """ + num_lvl = max_level - min_level + 1 + + if in_dynamic_mode(): + assert rois_num is not None, "rois_num should not be None in dygraph mode." + attrs = ( + "min_level", + min_level, + "max_level", + max_level, + "refer_level", + refer_level, + "refer_scale", + refer_scale, + "pixel_offset", + pixel_offset, + ) + multi_rois, restore_ind, rois_num_per_level = C_ops.distribute_fpn_proposals( + fpn_rois, rois_num, num_lvl, num_lvl, *attrs + ) + + return multi_rois, restore_ind, rois_num_per_level + + else: + check_variable_and_dtype( + fpn_rois, "fpn_rois", ["float32", "float64"], "distribute_fpn_proposals" + ) + helper = LayerHelper("distribute_fpn_proposals", **locals()) + dtype = helper.input_dtype("fpn_rois") + multi_rois = [ + helper.create_variable_for_type_inference(dtype) for i in range(num_lvl) + ] + + restore_ind = helper.create_variable_for_type_inference(dtype="int32") + + inputs = {"FpnRois": fpn_rois} + outputs = { + "MultiFpnRois": multi_rois, + "RestoreIndex": restore_ind, + } + + if rois_num is not None: + inputs["RoisNum"] = rois_num + rois_num_per_level = [ + helper.create_variable_for_type_inference(dtype="int32") + for i in range(num_lvl) + ] + outputs["MultiLevelRoIsNum"] = rois_num_per_level + else: + rois_num_per_level = None + + helper.append_op( + type="distribute_fpn_proposals", + inputs=inputs, + outputs=outputs, + attrs={ + "min_level": min_level, + "max_level": max_level, + "refer_level": refer_level, + "refer_scale": refer_scale, + "pixel_offset": pixel_offset, + }, + ) + return multi_rois, restore_ind, rois_num_per_level + + +@paddle.jit.not_to_static +def prior_box( + input, + image, + min_sizes, + max_sizes=None, + aspect_ratios=[1.0], + variance=[0.1, 0.1, 0.2, 0.2], + flip=False, + clip=False, + steps=[0.0, 0.0], + offset=0.5, + min_max_aspect_ratios_order=False, + name=None, +): + """ + + This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm. 
+ Each position of the input produce N prior boxes, N is determined by + the count of min_sizes, max_sizes and aspect_ratios, The size of the + box is in range(min_size, max_size) interval, which is generated in + sequence according to the aspect_ratios. + + Parameters: + input(Tensor): 4-D tensor(NCHW), the data type should be float32 or float64. + image(Tensor): 4-D tensor(NCHW), the input image data of PriorBoxOp, + the data type should be float32 or float64. + min_sizes(list|tuple|float): the min sizes of generated prior boxes. + max_sizes(list|tuple|None): the max sizes of generated prior boxes. + Default: None. + aspect_ratios(list|tuple|float): the aspect ratios of generated + prior boxes. Default: [1.]. + variance(list|tuple): the variances to be encoded in prior boxes. + Default:[0.1, 0.1, 0.2, 0.2]. + flip(bool): Whether to flip aspect ratios. Default:False. + clip(bool): Whether to clip out-of-boundary boxes. Default: False. + step(list|tuple): Prior boxes step across width and height, If + step[0] equals to 0.0 or step[1] equals to 0.0, the prior boxes step across + height or weight of the input will be automatically calculated. + Default: [0., 0.] + offset(float): Prior boxes center offset. Default: 0.5 + min_max_aspect_ratios_order(bool): If set True, the output prior box is + in order of [min, max, aspect_ratios], which is consistent with + Caffe. Please note, this order affects the weights order of + convolution layer followed by and does not affect the final + detection results. Default: False. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tuple: A tuple with two Variable (boxes, variances) + + boxes(Tensor): the output prior boxes of PriorBox. + 4-D tensor, the layout is [H, W, num_priors, 4]. + H is the height of input, W is the width of input, + num_priors is the total box count of each position of input. + + variances(Tensor): the expanded variances of PriorBox. + 4-D tensor, the layput is [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_priors is the total box count of each position of input + + Examples: + .. code-block:: python + + import paddle + from ppdet.modeling import ops + + paddle.enable_static() + input = paddle.static.data(name="input", shape=[None,3,6,9]) + image = paddle.static.data(name="image", shape=[None,3,9,12]) + box, var = ops.prior_box( + input=input, + image=image, + min_sizes=[100.], + clip=True, + flip=True) + """ + return paddle.vision.ops.prior_box( + input, + image, + min_sizes, + max_sizes, + aspect_ratios, + variance, + flip, + clip, + steps, + offset, + min_max_aspect_ratios_order, + name, + ) + + +@paddle.jit.not_to_static +def multiclass_nms( + bboxes, + scores, + score_threshold, + nms_top_k, + keep_top_k, + nms_threshold=0.3, + normalized=True, + nms_eta=1.0, + background_label=-1, + return_index=False, + return_rois_num=True, + rois_num=None, + name=None, +): + """ + This operator is to do multi-class non maximum suppression (NMS) on + boxes and scores. + In the NMS step, this operator greedily selects a subset of detection bounding + boxes that have high scores larger than score_threshold, if providing this + threshold, then selects the largest nms_top_k confidences scores if nms_top_k + is larger than -1. 
Then this operator pruns away boxes that have high IOU + (intersection over union) overlap with already selected boxes by adaptive + threshold NMS based on parameters of nms_threshold and nms_eta. + Aftern NMS step, at most keep_top_k number of total bboxes are to be kept + per image if keep_top_k is larger than -1. + Args: + bboxes (Tensor): Two types of bboxes are supported: + 1. (Tensor) A 3-D Tensor with shape + [N, M, 4 or 8 16 24 32] represents the + predicted locations of M bounding bboxes, + N is the batch size. Each bounding box has four + coordinate values and the layout is + [xmin, ymin, xmax, ymax], when box size equals to 4. + 2. (LoDTensor) A 3-D Tensor with shape [M, C, 4] + M is the number of bounding boxes, C is the + class number + scores (Tensor): Two types of scores are supported: + 1. (Tensor) A 3-D Tensor with shape [N, C, M] + represents the predicted confidence predictions. + N is the batch size, C is the class number, M is + number of bounding boxes. For each category there + are total M scores which corresponding M bounding + boxes. Please note, M is equal to the 2nd dimension + of BBoxes. + 2. (LoDTensor) A 2-D LoDTensor with shape [M, C]. + M is the number of bbox, C is the class number. + In this case, input BBoxes should be the second + case with shape [M, C, 4]. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all + categories will be considered. Default: 0 + score_threshold (float): Threshold to filter out bounding boxes with + low confidence score. If not provided, + consider all boxes. + nms_top_k (int): Maximum number of detections to be kept according to + the confidences after the filtering detections based + on score_threshold. + nms_threshold (float): The threshold to be used in NMS. Default: 0.3 + nms_eta (float): The threshold to be used in NMS. Default: 1.0 + keep_top_k (int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + normalized (bool): Whether detections are normalized. Default: True + return_index(bool): Whether return selected index. Default: False + rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image. + The shape is [B] and data type is int32. B is the number of images. + If it is not None then return a list of 1-D Tensor. Each element + is the output RoIs' number of each image on the corresponding level + and the shape is [B]. None by default. + name(str): Name of the multiclass nms op. Default: None. + Returns: + A tuple with two Variables: (Out, Index) if return_index is True, + otherwise, a tuple with one Variable(Out) is returned. + Out: A 2-D LoDTensor with shape [No, 6] represents the detections. + Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] + or A 2-D LoDTensor with shape [No, 10] represents the detections. + Each row has 10 values: [label, confidence, x1, y1, x2, y2, x3, y3, + x4, y4]. No is the total number of detections. + If all images have not detected results, all elements in LoD will be + 0, and output tensor is empty (None). + Index: Only return when return_index is True. A 2-D LoDTensor with + shape [No, 1] represents the selected index which type is Integer. + The index is the absolute value cross batches. No is the same number + as Out. If the index is used to gather other attribute such as age, + one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where + N is the batch size and M is the number of boxes. + Examples: + .. 
code-block:: python + + import paddle + from ppdet.modeling import ops + boxes = paddle.static.data(name='bboxes', shape=[81, 4], + dtype='float32', lod_level=1) + scores = paddle.static.data(name='scores', shape=[81], + dtype='float32', lod_level=1) + out, index = ops.multiclass_nms(bboxes=boxes, + scores=scores, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True) + """ + helper = LayerHelper("multiclass_nms3", **locals()) + + if HAVE_PIR and in_dynamic_or_pir_mode(): + # https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/ops/yaml/ops.yaml#L3175 + attrs = ( + score_threshold, + nms_top_k, + keep_top_k, + nms_threshold, + normalized, + nms_eta, + background_label, + ) + output, index, nms_rois_num = paddle._C_ops.multiclass_nms3( + bboxes, scores, rois_num, *attrs + ) + + if not return_index: + index = None + return output, nms_rois_num, index + + elif in_dynamic_mode(): + attrs = ( + "background_label", + background_label, + "score_threshold", + score_threshold, + "nms_top_k", + nms_top_k, + "nms_threshold", + nms_threshold, + "keep_top_k", + keep_top_k, + "nms_eta", + nms_eta, + "normalized", + normalized, + ) + output, index, nms_rois_num = C_ops.multiclass_nms3( + bboxes, scores, rois_num, *attrs + ) + if not return_index: + index = None + return output, nms_rois_num, index + + else: + output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) + index = helper.create_variable_for_type_inference(dtype="int32") + + inputs = {"BBoxes": bboxes, "Scores": scores} + outputs = {"Out": output, "Index": index} + + if rois_num is not None: + inputs["RoisNum"] = rois_num + + if return_rois_num: + nms_rois_num = helper.create_variable_for_type_inference(dtype="int32") + outputs["NmsRoisNum"] = nms_rois_num + + helper.append_op( + type="multiclass_nms3", + inputs=inputs, + attrs={ + "background_label": background_label, + "score_threshold": score_threshold, + "nms_top_k": nms_top_k, + "nms_threshold": nms_threshold, + "keep_top_k": keep_top_k, + "nms_eta": nms_eta, + "normalized": normalized, + }, + outputs=outputs, + ) + output.stop_gradient = True + index.stop_gradient = True + if not return_index: + index = None + if not return_rois_num: + nms_rois_num = None + + return output, nms_rois_num, index + + +@paddle.jit.not_to_static +def matrix_nms( + bboxes, + scores, + score_threshold, + post_threshold, + nms_top_k, + keep_top_k, + use_gaussian=False, + gaussian_sigma=2.0, + background_label=0, + normalized=True, + return_index=False, + return_rois_num=True, + name=None, +): + """ + **Matrix NMS** + This operator does matrix non maximum suppression (NMS). + First selects a subset of candidate bounding boxes that have higher scores + than score_threshold (if provided), then the top k candidate is selected if + nms_top_k is larger than -1. Score of the remaining candidate are then + decayed according to the Matrix NMS scheme. + Aftern NMS step, at most keep_top_k number of total bboxes are to be kept + per image if keep_top_k is larger than -1. + Args: + bboxes (Tensor): A 3-D Tensor with shape [N, M, 4] represents the + predicted locations of M bounding bboxes, + N is the batch size. Each bounding box has four + coordinate values and the layout is + [xmin, ymin, xmax, ymax], when box size equals to 4. + The data type is float32 or float64. + scores (Tensor): A 3-D Tensor with shape [N, C, M] + represents the predicted confidence predictions. 
+ N is the batch size, C is the class number, M is + number of bounding boxes. For each category there + are total M scores which corresponding M bounding + boxes. Please note, M is equal to the 2nd dimension + of BBoxes. The data type is float32 or float64. + score_threshold (float): Threshold to filter out bounding boxes with + low confidence score. + post_threshold (float): Threshold to filter out bounding boxes with + low confidence score AFTER decaying. + nms_top_k (int): Maximum number of detections to be kept according to + the confidences after the filtering detections based + on score_threshold. + keep_top_k (int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + use_gaussian (bool): Use Gaussian as the decay function. Default: False + gaussian_sigma (float): Sigma for Gaussian decay function. Default: 2.0 + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all + categories will be considered. Default: 0 + normalized (bool): Whether detections are normalized. Default: True + return_index(bool): Whether return selected index. Default: False + return_rois_num(bool): whether return rois_num. Default: True + name(str): Name of the matrix nms op. Default: None. + Returns: + A tuple with three Tensor: (Out, Index, RoisNum) if return_index is True, + otherwise, a tuple with two Tensor (Out, RoisNum) is returned. + Out (Tensor): A 2-D Tensor with shape [No, 6] containing the + detection results. + Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] + (After version 1.3, when no boxes detected, the lod is changed + from {0} to {1}) + Index (Tensor): A 2-D Tensor with shape [No, 1] containing the + selected indices, which are absolute values cross batches. + rois_num (Tensor): A 1-D Tensor with shape [N] containing + the number of detected boxes in each image. + Examples: + .. 
code-block:: python + import paddle + from ppdet.modeling import ops + boxes = paddle.static.data(name='bboxes', shape=[None,81, 4], + dtype='float32', lod_level=1) + scores = paddle.static.data(name='scores', shape=[None,81], + dtype='float32', lod_level=1) + out = ops.matrix_nms(bboxes=boxes, scores=scores, background_label=0, + score_threshold=0.5, post_threshold=0.1, + nms_top_k=400, keep_top_k=200, normalized=False) + """ + check_variable_and_dtype(bboxes, "BBoxes", ["float32", "float64"], "matrix_nms") + check_variable_and_dtype(scores, "Scores", ["float32", "float64"], "matrix_nms") + check_type(score_threshold, "score_threshold", float, "matrix_nms") + check_type(post_threshold, "post_threshold", float, "matrix_nms") + check_type(nms_top_k, "nums_top_k", int, "matrix_nms") + check_type(keep_top_k, "keep_top_k", int, "matrix_nms") + check_type(normalized, "normalized", bool, "matrix_nms") + check_type(use_gaussian, "use_gaussian", bool, "matrix_nms") + check_type(gaussian_sigma, "gaussian_sigma", float, "matrix_nms") + check_type(background_label, "background_label", int, "matrix_nms") + + if in_dynamic_mode(): + attrs = ( + "background_label", + background_label, + "score_threshold", + score_threshold, + "post_threshold", + post_threshold, + "nms_top_k", + nms_top_k, + "gaussian_sigma", + gaussian_sigma, + "use_gaussian", + use_gaussian, + "keep_top_k", + keep_top_k, + "normalized", + normalized, + ) + out, index, rois_num = C_ops.matrix_nms(bboxes, scores, *attrs) + if not return_index: + index = None + if not return_rois_num: + rois_num = None + return out, rois_num, index + else: + helper = LayerHelper("matrix_nms", **locals()) + output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) + index = helper.create_variable_for_type_inference(dtype="int32") + outputs = {"Out": output, "Index": index} + if return_rois_num: + rois_num = helper.create_variable_for_type_inference(dtype="int32") + outputs["RoisNum"] = rois_num + + helper.append_op( + type="matrix_nms", + inputs={"BBoxes": bboxes, "Scores": scores}, + attrs={ + "background_label": background_label, + "score_threshold": score_threshold, + "post_threshold": post_threshold, + "nms_top_k": nms_top_k, + "gaussian_sigma": gaussian_sigma, + "use_gaussian": use_gaussian, + "keep_top_k": keep_top_k, + "normalized": normalized, + }, + outputs=outputs, + ) + output.stop_gradient = True + + if not return_index: + index = None + if not return_rois_num: + rois_num = None + return output, rois_num, index + + +@paddle.jit.not_to_static +def box_coder( + prior_box, + prior_box_var, + target_box, + code_type="encode_center_size", + box_normalized=True, + axis=0, + name=None, +): + r""" + **Box Coder Layer** + Encode/Decode the target bounding box with the priorbox information. + + The Encoding schema described below: + .. math:: + ox = (tx - px) / pw / pxv + oy = (ty - py) / ph / pyv + ow = \log(\abs(tw / pw)) / pwv + oh = \log(\abs(th / ph)) / phv + The Decoding schema described below: + + .. math:: + + ox = (pw * pxv * tx * + px) - tw / 2 + oy = (ph * pyv * ty * + py) - th / 2 + ow = \exp(pwv * tw) * pw + tw / 2 + oh = \exp(phv * th) * ph + th / 2 + where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, + width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote + the priorbox's (anchor) center coordinates, width and height. `pxv`, + `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, + `ow`, `oh` denote the encoded/decoded coordinates, width and height. 
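+
+    Equivalently, decoding first recovers a center-size box:
+
+    .. math::
+
+        ocx = pxv * tx * pw + px
+        ocy = pyv * ty * ph + py
+        ow = \exp(pwv * tw) * pw
+        oh = \exp(phv * th) * ph
+
+    where (ocx, ocy) is the decoded center; the center-size box is then
+    converted back to corner coordinates.
+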
+ During Box Decoding, two modes for broadcast are supported. Say target + box has shape [N, M, 4], and the shape of prior box can be [N, 4] or + [M, 4]. Then prior box will broadcast to target box along the + assigned axis. + + Args: + prior_box(Tensor): Box list prior_box is a 2-D Tensor with shape + [M, 4] holds M boxes and data type is float32 or float64. Each box + is represented as [xmin, ymin, xmax, ymax], [xmin, ymin] is the + left top coordinate of the anchor box, if the input is image feature + map, they are close to the origin of the coordinate system. + [xmax, ymax] is the right bottom coordinate of the anchor box. + prior_box_var(List|Tensor|None): prior_box_var supports three types + of input. One is Tensor with shape [M, 4] which holds M group and + data type is float32 or float64. The second is list consist of + 4 elements shared by all boxes and data type is float32 or float64. + Other is None and not involved in calculation. + target_box(Tensor): This input can be a 2-D LoDTensor with shape + [N, 4] when code_type is 'encode_center_size'. This input also can + be a 3-D Tensor with shape [N, M, 4] when code_type is + 'decode_center_size'. Each box is represented as + [xmin, ymin, xmax, ymax]. The data type is float32 or float64. + code_type(str): The code type used with the target box. It can be + `encode_center_size` or `decode_center_size`. `encode_center_size` + by default. + box_normalized(bool): Whether treat the priorbox as a normalized box. + Set true by default. + axis(int): Which axis in PriorBox to broadcast for box decode, + for example, if axis is 0 and TargetBox has shape [N, M, 4] and + PriorBox has shape [M, 4], then PriorBox will broadcast to [N, M, 4] + for decoding. It is only valid when code type is + `decode_center_size`. Set 0 by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: + output_box(Tensor): When code_type is 'encode_center_size', the + output tensor of box_coder_op with shape [N, M, 4] representing the + result of N target boxes encoded with M Prior boxes and variances. + When code_type is 'decode_center_size', N represents the batch size + and M represents the number of decoded boxes. + + Examples: + + .. 
code-block:: python + + import paddle + from ppdet.modeling import ops + paddle.enable_static() + # For encode + prior_box_encode = paddle.static.data(name='prior_box_encode', + shape=[512, 4], + dtype='float32') + target_box_encode = paddle.static.data(name='target_box_encode', + shape=[81, 4], + dtype='float32') + output_encode = ops.box_coder(prior_box=prior_box_encode, + prior_box_var=[0.1,0.1,0.2,0.2], + target_box=target_box_encode, + code_type="encode_center_size") + # For decode + prior_box_decode = paddle.static.data(name='prior_box_decode', + shape=[512, 4], + dtype='float32') + target_box_decode = paddle.static.data(name='target_box_decode', + shape=[512, 81, 4], + dtype='float32') + output_decode = ops.box_coder(prior_box=prior_box_decode, + prior_box_var=[0.1,0.1,0.2,0.2], + target_box=target_box_decode, + code_type="decode_center_size", + box_normalized=False, + axis=1) + """ + check_variable_and_dtype( + prior_box, "prior_box", ["float32", "float64"], "box_coder" + ) + check_variable_and_dtype( + target_box, "target_box", ["float32", "float64"], "box_coder" + ) + + if in_dynamic_mode(): + if isinstance(prior_box_var, Variable): + output_box = C_ops.box_coder( + prior_box, + prior_box_var, + target_box, + "code_type", + code_type, + "box_normalized", + box_normalized, + "axis", + axis, + ) + + elif isinstance(prior_box_var, list): + output_box = C_ops.box_coder( + prior_box, + None, + target_box, + "code_type", + code_type, + "box_normalized", + box_normalized, + "axis", + axis, + "variance", + prior_box_var, + ) + else: + raise TypeError("Input variance of box_coder must be Variable or list") + return output_box + else: + helper = LayerHelper("box_coder", **locals()) + + output_box = helper.create_variable_for_type_inference(dtype=prior_box.dtype) + + inputs = {"PriorBox": prior_box, "TargetBox": target_box} + attrs = {"code_type": code_type, "box_normalized": box_normalized, "axis": axis} + if isinstance(prior_box_var, Variable): + inputs["PriorBoxVar"] = prior_box_var + elif isinstance(prior_box_var, list): + attrs["variance"] = prior_box_var + else: + raise TypeError("Input variance of box_coder must be Variable or list") + helper.append_op( + type="box_coder", + inputs=inputs, + attrs=attrs, + outputs={"OutputBox": output_box}, + ) + return output_box + + +@paddle.jit.not_to_static +def generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=6000, + post_nms_top_n=1000, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False, + return_rois_num=False, + name=None, +): + """ + **Generate proposal Faster-RCNN** + This operation proposes RoIs according to each box with their + probability to be a foreground object and + the box can be calculated by anchors. Bbox_deltais and scores + to be an object are the output of RPN. Final proposals + could be used to train detection net. + For generating proposals, this operation performs following steps: + 1. Transposes and resizes scores and bbox_deltas in size of + (H*W*A, 1) and (H*W*A, 4) + 2. Calculate box locations as proposals candidates. + 3. Clip boxes to image + 4. Remove predicted boxes with small area. + 5. Apply NMS to get final proposals as output. + Args: + scores(Tensor): A 4-D Tensor with shape [N, A, H, W] represents + the probability for each box to be an object. + N is batch size, A is number of anchors, H and W are height and + width of the feature map. The data type must be float32. 
+ bbox_deltas(Tensor): A 4-D Tensor with shape [N, 4*A, H, W] + represents the difference between predicted box location and + anchor location. The data type must be float32. + im_shape(Tensor): A 2-D Tensor with shape [N, 2] represents H, W, the + origin image size or input size. The data type can be float32 or + float64. + anchors(Tensor): A 4-D Tensor represents the anchors with a layout + of [H, W, A, 4]. H and W are height and width of the feature map, + num_anchors is the box count of each position. Each anchor is + in (xmin, ymin, xmax, ymax) format an unnormalized. The data type must be float32. + variances(Tensor): A 4-D Tensor. The expanded variances of anchors with a layout of + [H, W, num_priors, 4]. Each variance is in + (xcenter, ycenter, w, h) format. The data type must be float32. + pre_nms_top_n(float): Number of total bboxes to be kept per + image before NMS. The data type must be float32. `6000` by default. + post_nms_top_n(float): Number of total bboxes to be kept per + image after NMS. The data type must be float32. `1000` by default. + nms_thresh(float): Threshold in NMS. The data type must be float32. `0.5` by default. + min_size(float): Remove predicted boxes with either height or + width < min_size. The data type must be float32. `0.1` by default. + eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, + `adaptive_threshold = adaptive_threshold * eta` in each iteration. + return_rois_num(bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's + num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents + the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. + 'False' by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + tuple: + A tuple with format ``(rpn_rois, rpn_roi_probs)``. + - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - **rpn_roi_probs**: The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + + Examples: + .. code-block:: python + + import paddle + from ppdet.modeling import ops + paddle.enable_static() + scores = paddle.static.data(name='scores', shape=[None, 4, 5, 5], dtype='float32') + bbox_deltas = paddle.static.data(name='bbox_deltas', shape=[None, 16, 5, 5], dtype='float32') + im_shape = paddle.static.data(name='im_shape', shape=[None, 2], dtype='float32') + anchors = paddle.static.data(name='anchors', shape=[None, 5, 4, 4], dtype='float32') + variances = paddle.static.data(name='variances', shape=[None, 5, 10, 4], dtype='float32') + rois, roi_probs = ops.generate_proposals(scores, bbox_deltas, + im_shape, anchors, variances) + """ + if in_dynamic_mode(): + assert return_rois_num, "return_rois_num should be True in dygraph mode." 
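+        # Dynamic mode calls the fused C++ op directly; attributes are passed
+        # as a flat (name, value, ...) tuple.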
+ attrs = ( + "pre_nms_topN", + pre_nms_top_n, + "post_nms_topN", + post_nms_top_n, + "nms_thresh", + nms_thresh, + "min_size", + min_size, + "eta", + eta, + "pixel_offset", + pixel_offset, + ) + rpn_rois, rpn_roi_probs, rpn_rois_num = C_ops.generate_proposals_v2( + scores, bbox_deltas, im_shape, anchors, variances, *attrs + ) + if not return_rois_num: + rpn_rois_num = None + return rpn_rois, rpn_roi_probs, rpn_rois_num + + else: + helper = LayerHelper("generate_proposals_v2", **locals()) + + check_variable_and_dtype(scores, "scores", ["float32"], "generate_proposals_v2") + check_variable_and_dtype( + bbox_deltas, "bbox_deltas", ["float32"], "generate_proposals_v2" + ) + check_variable_and_dtype( + im_shape, "im_shape", ["float32", "float64"], "generate_proposals_v2" + ) + check_variable_and_dtype( + anchors, "anchors", ["float32"], "generate_proposals_v2" + ) + check_variable_and_dtype( + variances, "variances", ["float32"], "generate_proposals_v2" + ) + + rpn_rois = helper.create_variable_for_type_inference(dtype=bbox_deltas.dtype) + rpn_roi_probs = helper.create_variable_for_type_inference(dtype=scores.dtype) + outputs = { + "RpnRois": rpn_rois, + "RpnRoiProbs": rpn_roi_probs, + } + if return_rois_num: + rpn_rois_num = helper.create_variable_for_type_inference(dtype="int32") + rpn_rois_num.stop_gradient = True + outputs["RpnRoisNum"] = rpn_rois_num + + helper.append_op( + type="generate_proposals_v2", + inputs={ + "Scores": scores, + "BboxDeltas": bbox_deltas, + "ImShape": im_shape, + "Anchors": anchors, + "Variances": variances, + }, + attrs={ + "pre_nms_topN": pre_nms_top_n, + "post_nms_topN": post_nms_top_n, + "nms_thresh": nms_thresh, + "min_size": min_size, + "eta": eta, + "pixel_offset": pixel_offset, + }, + outputs=outputs, + ) + rpn_rois.stop_gradient = True + rpn_roi_probs.stop_gradient = True + if not return_rois_num: + rpn_rois_num = None + + return rpn_rois, rpn_roi_probs, rpn_rois_num + + +def sigmoid_cross_entropy_with_logits(input, label, ignore_index=-100, normalize=False): + output = F.binary_cross_entropy_with_logits(input, label, reduction="none") + mask_tensor = paddle.cast(label != ignore_index, "float32") + output = paddle.multiply(output, mask_tensor) + if normalize: + sum_valid_mask = paddle.sum(mask_tensor) + output = output / sum_valid_mask + return output + + +def smooth_l1(input, label, inside_weight=None, outside_weight=None, sigma=None): + input_new = paddle.multiply(input, inside_weight) + label_new = paddle.multiply(label, inside_weight) + delta = 1 / (sigma * sigma) + out = F.smooth_l1_loss(input_new, label_new, reduction="none", delta=delta) + out = paddle.multiply(out, outside_weight) + out = out / delta + out = paddle.reshape(out, shape=[out.shape[0], -1]) + out = paddle.sum(out, axis=1) + return out + + +def channel_shuffle(x, groups): + batch_size, num_channels, height, width = x.shape[0:4] + assert num_channels % groups == 0, "num_channels should be divisible by groups" + channels_per_group = num_channels // groups + x = paddle.reshape( + x=x, shape=[batch_size, groups, channels_per_group, height, width] + ) + x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4]) + x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width]) + return x + + +def get_static_shape(tensor): + shape = paddle.shape(tensor) + shape.stop_gradient = True + return shape diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/position_encoding.py 
b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/position_encoding.py new file mode 100644 index 0000000000..442b1b0cf2 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/position_encoding.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from __future__ import absolute_import, division, print_function + +import math + +import paddle +import paddle.nn as nn + + +class PositionEmbedding(nn.Layer): + def __init__( + self, + num_pos_feats=128, + temperature=10000, + normalize=True, + scale=2 * math.pi, + embed_type="sine", + num_embeddings=50, + offset=0.0, + eps=1e-6, + ): + super(PositionEmbedding, self).__init__() + assert embed_type in ["sine", "learned"] + + self.embed_type = embed_type + self.offset = offset + self.eps = eps + if self.embed_type == "sine": + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + elif self.embed_type == "learned": + self.row_embed = nn.Embedding(num_embeddings, num_pos_feats) + self.col_embed = nn.Embedding(num_embeddings, num_pos_feats) + else: + raise ValueError(f"{self.embed_type} is not supported.") + + def forward(self, mask): + """ + Args: + mask (Tensor): [B, H, W] + Returns: + pos (Tensor): [B, H, W, C] + """ + if self.embed_type == "sine": + y_embed = mask.cumsum(1) + x_embed = mask.cumsum(2) + if self.normalize: + y_embed = ( + (y_embed + self.offset) + / (y_embed[:, -1:, :] + self.eps) + * self.scale + ) + x_embed = ( + (x_embed + self.offset) + / (x_embed[:, :, -1:] + self.eps) + * self.scale + ) + + dim_t = 2 * (paddle.arange(self.num_pos_feats) // 2).astype("float32") + dim_t = self.temperature ** (dim_t / self.num_pos_feats) + + pos_x = x_embed.unsqueeze(-1) / dim_t + pos_y = y_embed.unsqueeze(-1) / dim_t + pos_x = paddle.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4 + ).flatten(3) + pos_y = paddle.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4 + ).flatten(3) + return paddle.concat((pos_y, pos_x), axis=3) + elif self.embed_type == "learned": + h, w = mask.shape[-2:] + i = paddle.arange(w) + j = paddle.arange(h) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + return paddle.concat( + [ + x_emb.unsqueeze(0).tile([h, 1, 1]), + y_emb.unsqueeze(1).tile([1, w, 1]), + ], + axis=-1, + ).unsqueeze(0) + else: + raise ValueError(f"not supported {self.embed_type}") diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/utils.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/utils.py new file mode 100644 index 0000000000..b16db76adc --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/modules/utils.py @@ -0,0 +1,544 @@ +# Copyright (c) 2025 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified from detrex (https://github.com/IDEA-Research/detrex) +# Copyright 2022 The IDEA Authors. All rights reserved. + +from __future__ import absolute_import, division, print_function + +import copy +import math + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = [ + "_get_clones", + "bbox_cxcywh_to_xyxy", + "bbox_xyxy_to_cxcywh", + "sigmoid_focal_loss", + "inverse_sigmoid", + "deformable_attention_core_func", + "varifocal_loss_with_logits", + "mal_loss_with_logits", +] + + +def _get_clones(module, N): + return nn.LayerList([copy.deepcopy(module) for _ in range(N)]) + + +def bbox_cxcywh_to_xyxy(x): + cxcy, wh = paddle.split(x, 2, axis=-1) + return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1) + + +def bbox_xyxy_to_cxcywh(x): + x1, y1, x2, y2 = x.split(4, axis=-1) + return paddle.concat([(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1) + + +def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0): + prob = F.sigmoid(logit) + ce_loss = F.binary_cross_entropy_with_logits(logit, label, reduction="none") + p_t = prob * label + (1 - prob) * (1 - label) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * label + (1 - alpha) * (1 - label) + loss = alpha_t * loss + return loss.mean(1).sum() / normalizer + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clip(min=0.0, max=1.0) + return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps)) + + +def deformable_attention_core_func( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, +): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, _, n_head, c = value.shape + _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape + + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, axis=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = ( + value_list[level] + .flatten(2) + .transpose([0, 2, 1]) + .reshape([bs * n_head, c, h, w]) + ) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = ( + sampling_grids[:, :, :, level].transpose([0, 2, 1, 3, 4]).flatten(0, 1) + ) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + 
padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape( + [bs * n_head, 1, Len_q, n_levels * n_points] + ) + output = ( + (paddle.stack(sampling_value_list, axis=-2).flatten(-2) * attention_weights) + .sum(-1) + .reshape([bs, n_head * c, Len_q]) + ) + + return output.transpose([0, 2, 1]) + + +def discrete_sample(x, grid): + """ + Args: + x (Tensor): [N, C, H, W] + grid (Tensor): [N, grid_H, grid_W, 2] + Returns: + output (Tensor): [N, C, grid_H, grid_W] + """ + N, C, H, W = x.shape + _, grid_H, grid_W, _ = grid.shape + spatial_shape = paddle.to_tensor([[W, H]], dtype=paddle.float32) + index = (grid * spatial_shape + 0.5).astype(paddle.int64).flatten(1, 2) + h_index = index[:, :, 1].clip(0, H - 1) + w_index = index[:, :, 0].clip(0, W - 1) + batch_index = paddle.arange(N).unsqueeze(-1).tile([1, grid_H * grid_W]) + output = x[batch_index, :, h_index, w_index] + output = output.transpose([0, 2, 1]).reshape([N, C, grid_H, grid_W]) + return output + + +def deformable_attention_core_func_v2( + value, + value_spatial_shapes, + sampling_locations, + attention_weights, + num_points_list, + sampling_method="default", +): + """ + Args: + value (Tensor): [batch_num, value_len, num_heads, head_dim] + value_spatial_shapes (Tensor|List): [n_levels, 2] + sampling_locations (Tensor): [batch_num, query_len, num_heads, total_num_points, 2] + attention_weights (Tensor): [batch_num, query_len, num_heads, total_num_points] + num_points_list (List): The number of sampling point corresponding to each level + sampling_method (str): default(grid_sample) or discrete(discrete_sample) + + Returns: + output (Tensor): [batch_num, query_len, num_heads * head_dim] + """ + assert sampling_method in ["default", "discrete"], NotImplementedError + batch_num, _, num_heads, head_dim = value.shape + query_len = sampling_locations.shape[1] + num_levels = len(num_points_list) + + value = value.transpose([0, 2, 3, 1]).flatten(0, 1) + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, axis=-1) + value_list = [ + value.reshape([-1, head_dim, h, w]) + for value, (h, w) in zip(value_list, value_spatial_shapes) + ] + + if sampling_method == "default": + sampling_grids = 2 * sampling_locations - 1 + else: + sampling_grids = sampling_locations + + sampling_grids = sampling_grids.transpose([0, 2, 1, 3, 4]).flatten(0, 1) + sampling_grids_list = sampling_grids.split(num_points_list, axis=-2) + + sampling_value_list = [] + for idx in range(num_levels): + # value_list[idx]: [batch_num * num_heads, head_dim, h, w] + # sampling_grids_list[idx]: [batch_num * num_heads, query_len, num_points, 2] + # _sampling_value: [batch_num * num_heads, head_dim, query_len, num_points] + if sampling_method == "default": + _sampling_value = F.grid_sample( + value_list[idx], + sampling_grids_list[idx], + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + else: + _sampling_value = discrete_sample(value_list[idx], sampling_grids_list[idx]) + sampling_value_list.append(_sampling_value) + + attn_weights = attention_weights.transpose([0, 2, 1, 3]) + attn_weights = attn_weights.flatten(0, 1).unsqueeze(1) + sampling_value = paddle.concat(sampling_value_list, axis=-1) + # attn_weights: [batch_num * num_heads, 1, query_len, total_num_points] + # sampling_value: [batch_num * num_heads, head_dim, query_len, 
+    # output: [batch_num * num_heads, head_dim, query_len]
+    output = (sampling_value * attn_weights).sum(-1)
+    output = output.reshape([batch_num, num_heads * head_dim, query_len])
+    return output.transpose([0, 2, 1])
+
+
+def get_valid_ratio(mask):
+    _, H, W = mask.shape
+    valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
+    valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
+    # [b, 2]
+    return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
+
+
+def get_denoising_training_group(
+    targets,
+    num_classes,
+    num_queries,
+    class_embed,
+    num_denoising=100,
+    label_noise_ratio=0.5,
+    box_noise_scale=1.0,
+):
+    if num_denoising <= 0:
+        return None, None, None, None
+    num_gts = [len(t) for t in targets["gt_class"]]
+    max_gt_num = max(num_gts)
+    if max_gt_num == 0:
+        return None, None, None, None
+
+    num_group = num_denoising // max_gt_num
+    num_group = 1 if num_group == 0 else num_group
+    # pad gt to the max gt number of the batch
+    bs = len(targets["gt_class"])
+    input_query_class = paddle.full([bs, max_gt_num], num_classes, dtype="int32")
+    input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
+    pad_gt_mask = paddle.zeros([bs, max_gt_num])
+    for i in range(bs):
+        num_gt = num_gts[i]
+        if num_gt > 0:
+            input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
+            input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
+            pad_gt_mask[i, :num_gt] = 1
+
+    input_query_class = input_query_class.tile([1, num_group])
+    input_query_bbox = input_query_bbox.tile([1, num_group, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, num_group])
+
+    dn_positive_idx = paddle.nonzero(pad_gt_mask)[:, 1]
+    dn_positive_idx = paddle.split(dn_positive_idx, [n * num_group for n in num_gts])
+    # total denoising queries
+    num_denoising = int(max_gt_num * num_group)
+
+    if label_noise_ratio > 0:
+        input_query_class = paddle.assign(input_query_class.flatten())
+        pad_gt_mask = paddle.assign(pad_gt_mask.flatten())
+        # flip labels with probability label_noise_ratio * 0.5; cast the mask from
+        # bool to float because dtype promotion between bool and float is not
+        # supported in static mode.
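+        # For example, with label_noise_ratio=0.5, roughly a quarter of the
+        # non-padding GT labels are re-sampled uniformly from [0, num_classes).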
+ mask = paddle.cast( + paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5), + paddle.float32, + ) + chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1) + # randomly put a new one here + new_label = paddle.randint_like( + chosen_idx, 0, num_classes, dtype=input_query_class.dtype + ) + input_query_class.scatter_(chosen_idx, new_label) + input_query_class.reshape_([bs, num_denoising]) + pad_gt_mask.reshape_([bs, num_denoising]) + + if box_noise_scale > 0: + diff = ( + paddle.concat( + [input_query_bbox[..., 2:] * 0.5, input_query_bbox[..., 2:]], axis=-1 + ) + * box_noise_scale + ) + diff *= paddle.rand(input_query_bbox.shape) * 2.0 - 1.0 + input_query_bbox += diff + input_query_bbox = inverse_sigmoid(input_query_bbox) + + class_embed = paddle.concat([class_embed, paddle.zeros([1, class_embed.shape[-1]])]) + input_query_class = paddle.gather( + class_embed, input_query_class.flatten(), axis=0 + ).reshape([bs, num_denoising, -1]) + + tgt_size = num_denoising + num_queries + attn_mask = paddle.ones([tgt_size, tgt_size]) < 0 + # match query cannot see the reconstruction + attn_mask[num_denoising:, :num_denoising] = True + # reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[ + max_gt_num * i : max_gt_num * (i + 1), + max_gt_num * (i + 1) : num_denoising, + ] = True + if i == num_group - 1: + attn_mask[max_gt_num * i : max_gt_num * (i + 1), : max_gt_num * i] = True + else: + attn_mask[ + max_gt_num * i : max_gt_num * (i + 1), + max_gt_num * (i + 1) : num_denoising, + ] = True + attn_mask[max_gt_num * i : max_gt_num * (i + 1), : max_gt_num * i] = True + attn_mask = ~attn_mask + dn_meta = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": num_group, + "dn_num_split": [num_denoising, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, dn_meta + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + if num_denoising <= 0: + return None, None, None, None + # listcomp is not well-supported in SOT mode for now. + num_gts = [] + for t in targets["gt_class"]: + num_gts.append(len(t)) + max_gt_num = max(num_gts) + if max_gt_num == 0: + return None, None, None, None + + num_group = num_denoising // max_gt_num + num_group = 1 if num_group == 0 else num_group + # pad gt to max_num of a batch + bs = len(targets["gt_class"]) + input_query_class = paddle.full([bs, max_gt_num], num_classes, dtype="int32") + input_query_bbox = paddle.zeros([bs, max_gt_num, 4]) + pad_gt_mask = paddle.zeros([bs, max_gt_num]) + for i in range(bs): + num_gt = num_gts[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1) + input_query_bbox[i, :num_gt] = targets["gt_bbox"][i] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. 
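+    # Per-group layout after tiling: [pos_0, ..., pos_{max_gt_num-1}, neg_0, ..., neg_{max_gt_num-1}],
+    # so each group contributes 2 * max_gt_num queries, positives first.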
+ input_query_class = input_query_class.tile([1, 2 * num_group]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) + # positive and negative mask + negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1]) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1] + dn_positive_idx = paddle.split(dn_positive_idx, [n * num_group for n in num_gts]) + # total denoising queries + num_denoising = int(max_gt_num * 2 * num_group) + + if label_noise_ratio > 0: + input_query_class = paddle.assign(input_query_class.flatten()) + pad_gt_mask = paddle.assign(pad_gt_mask.flatten()) + # half of bbox prob + mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5) + chosen_idx = paddle.nonzero(mask.cast(pad_gt_mask.dtype) * pad_gt_mask).squeeze( + -1 + ) + # randomly put a new one here + new_label = paddle.randint_like( + chosen_idx, 0, num_classes, dtype=input_query_class.dtype + ) + input_query_class.scatter_(chosen_idx, new_label) + input_query_class.reshape_([bs, num_denoising]) + pad_gt_mask.reshape_([bs, num_denoising]) + + if box_noise_scale > 0: + known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox) + + diff = paddle.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + + rand_sign = paddle.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = paddle.rand(input_query_bbox.shape) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * ( + 1 - negative_gt_mask + ) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + class_embed = paddle.concat([class_embed, paddle.zeros([1, class_embed.shape[-1]])]) + input_query_class = paddle.gather( + class_embed, input_query_class.flatten(), axis=0 + ).reshape([bs, num_denoising, -1]) + + tgt_size = num_denoising + num_queries + attn_mask = paddle.ones([tgt_size, tgt_size]) < 0 + # match query cannot see the reconstruction + attn_mask[num_denoising:, :num_denoising] = True + # reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True + if i == num_group - 1: + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2 + ] = True + else: + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i + ] = True + attn_mask = ~attn_mask + dn_meta = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": num_group, + "dn_num_split": [num_denoising, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, dn_meta + + +def get_sine_pos_embed( + pos_tensor, num_pos_feats=128, temperature=10000, exchange_xy=True +): + """generate sine position embedding from a position tensor + + Args: + pos_tensor (Tensor): Shape as `(None, n)`. + num_pos_feats (int): projected shape for each float in the tensor. Default: 128 + temperature (int): The temperature used for scaling + the position embedding. Default: 10000. 
+        exchange_xy (bool, optional): Whether to exchange pos x and pos y.
+            For example, if the input tensor is `[x, y]`, the result will
+            be `[pos(y), pos(x)]`. Default: True.
+
+    Returns:
+        Tensor: The position embedding with shape `(None, n * num_pos_feats)`.
+    """
+    scale = 2.0 * math.pi
+    dim_t = 2.0 * paddle.floor_divide(paddle.arange(num_pos_feats), paddle.to_tensor(2))
+    dim_t = scale / temperature ** (dim_t / num_pos_feats)
+
+    def sine_func(x):
+        x *= dim_t
+        return paddle.stack((x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten(
+            2
+        )
+
+    pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)]
+    if exchange_xy:
+        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
+    pos_res = paddle.concat(pos_res, axis=2)
+    return pos_res
+
+
+def mask_to_box_coordinate(mask, normalize=False, format="xyxy", dtype="float32"):
+    """
+    Compute the bounding boxes around the provided masks.
+
+    Args:
+        mask (Tensor, bool): [b, c, h, w]
+
+    Returns:
+        bbox (Tensor): [b, c, 4]
+    """
+    assert mask.ndim == 4
+    assert format in ["xyxy", "xywh"]
+
+    h, w = mask.shape[-2:]
+    y, x = paddle.meshgrid(
+        paddle.arange(end=h, dtype=dtype), paddle.arange(end=w, dtype=dtype)
+    )
+
+    x_mask = x * mask.astype(x.dtype)
+    x_max = x_mask.flatten(-2).max(-1) + 1
+    x_min = (
+        paddle.where(mask.astype(bool), x_mask, paddle.to_tensor(1e8))
+        .flatten(-2)
+        .min(-1)
+    )
+
+    y_mask = y * mask.astype(y.dtype)
+    y_max = y_mask.flatten(-2).max(-1) + 1
+    y_min = (
+        paddle.where(mask.astype(bool), y_mask, paddle.to_tensor(1e8))
+        .flatten(-2)
+        .min(-1)
+    )
+    out_bbox = paddle.stack([x_min, y_min, x_max, y_max], axis=-1)
+    mask = mask.any(axis=[2, 3]).unsqueeze(2)
+    out_bbox = out_bbox * mask.astype(out_bbox.dtype)
+    if normalize:
+        out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype)
+
+    return out_bbox if format == "xyxy" else bbox_xyxy_to_cxcywh(out_bbox)
+
+
+def varifocal_loss_with_logits(
+    pred_logits, gt_score, label, normalizer=1.0, alpha=0.75, gamma=2.0
+):
+    pred_score = F.sigmoid(pred_logits)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(
+        pred_logits, gt_score, weight=weight, reduction="none"
+    )
+    return loss.mean(1).sum() / normalizer
+
+
+def mal_loss_with_logits(
+    pred_logits, gt_score, label, normalizer=1.0, alpha=1.0, gamma=1.5
+):
+    pred_score = F.sigmoid(pred_logits)
+    gt_score = gt_score.pow(gamma)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + label
+    loss = F.binary_cross_entropy_with_logits(
+        pred_logits, gt_score, weight=weight, reduction="none"
+    )
+    return loss.mean(1).sum() / normalizer
diff --git a/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/rtdetr_transformer.py b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/rtdetr_transformer.py
new file mode 100644
index 0000000000..4eeb36c487
--- /dev/null
+++ b/paddlex/inference/models/object_detection/modeling/rtdetrl_modules/rtdetr_transformer.py
@@ -0,0 +1,646 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR)
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Modified from detrex (https://github.com/IDEA-Research/detrex)
+# Copyright 2022 The IDEA Authors. All rights reserved.
+
+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.regularizer import L2Decay
+
+from .detr_head import MLP
+from .modules.deformable_transformer import MSDeformableAttention
+from .modules.detr_ops import _get_clones, inverse_sigmoid
+from .modules.initializer import (
+    bias_init_with_prob,
+    constant_,
+    linear_init_,
+    xavier_uniform_,
+)
+from .modules.layers import MultiHeadAttention
+from .modules.utils import get_contrastive_denoising_training_group
+
+__all__ = ["RTDETRTransformer"]
+
+
+class PPMSDeformableAttention(MSDeformableAttention):
+    def forward(
+        self,
+        query,
+        reference_points,
+        value,
+        value_spatial_shapes,
+        value_level_start_index,
+        value_mask=None,
+    ):
+        """
+        Args:
+            query (Tensor): [bs, query_length, C]
+            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area
+            value (Tensor): [bs, value_length, C]
+            value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
+            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
+
+        Returns:
+            output (Tensor): [bs, Length_{query}, C]
+        """
+        bs, Len_q = query.shape[:2]
+        Len_v = value.shape[1]
+
+        value = self.value_proj(value)
+        if value_mask is not None:
+            value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
+            value *= value_mask
+        value = value.reshape([bs, Len_v, self.num_heads, self.head_dim])
+
+        sampling_offsets = self.sampling_offsets(query).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2]
+        )
+        attention_weights = self.attention_weights(query).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels * self.num_points]
+        )
+        attention_weights = F.softmax(attention_weights).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels, self.num_points]
+        )
+
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = paddle.to_tensor(value_spatial_shapes)
+            offset_normalizer = offset_normalizer.flip([1]).reshape(
+                [1, 1, 1, self.num_levels, 1, 2]
+            )
+            sampling_locations = (
+                reference_points.reshape([bs, Len_q, 1, self.num_levels, 1, 2])
+                + sampling_offsets / offset_normalizer
+            )
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :2]
+                + sampling_offsets
+                / self.num_points
+                * reference_points[:, :, None, :, None, 2:]
+                * 0.5
+            )
+        else:
+            raise ValueError(
+                "Last dim of reference_points must be 2 or 4, but got {} instead.".format(
+                    reference_points.shape[-1]
+                )
+            )
+
+        if not isinstance(query, paddle.Tensor):
+            from .modules.utils import deformable_attention_core_func
+
+            output = deformable_attention_core_func(
+                value,
+                value_spatial_shapes,
+                value_level_start_index,
+                sampling_locations,
+                attention_weights,
+            )
+        else:
+            value_spatial_shapes = paddle.to_tensor(value_spatial_shapes)
+            value_level_start_index = paddle.to_tensor(value_level_start_index)
+            output = 
self.ms_deformable_attn_core( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + output = self.output_proj(output) + + return output + + +class TransformerDecoderLayer(nn.Layer): + def __init__( + self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None, + ): + super(TransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ) + + # cross attention + self.cross_attn = PPMSDeformableAttention( + d_model, n_head, n_levels, n_points, 1.0 + ) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, bias_attr) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, bias_attr) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward( + self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask=None, + memory_mask=None, + query_pos_embed=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + if attn_mask is not None: + attn_mask = paddle.where( + attn_mask.astype("bool"), + paddle.zeros(attn_mask.shape, tgt.dtype), + paddle.full(attn_mask.shape, float("-inf"), tgt.dtype), + ) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + memory_mask, + ) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt2 = self.forward_ffn(tgt) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class TransformerDecoder(nn.Layer): + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward( + self, + tgt, + ref_points_unact, + memory, + memory_spatial_shapes, + memory_level_start_index, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None, + query_pos_head_inv_sig=False, + ): + output = tgt + dec_out_bboxes = [] + dec_out_logits = [] + ref_points_detach = F.sigmoid(ref_points_unact) + for i, layer in enumerate(self.layers): + ref_points_input = 
ref_points_detach.unsqueeze(2) + if not query_pos_head_inv_sig: + query_pos_embed = query_pos_head(ref_points_detach) + else: + query_pos_embed = query_pos_head(inverse_sigmoid(ref_points_detach)) + + output = layer( + output, + ref_points_input, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask, + memory_mask, + query_pos_embed, + ) + + inter_ref_bbox = F.sigmoid( + bbox_head[i](output) + inverse_sigmoid(ref_points_detach) + ) + + if self.training: + dec_out_logits.append(score_head[i](output)) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append( + F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)) + ) + elif i == self.eval_idx: + dec_out_logits.append(score_head[i](output)) + dec_out_bboxes.append(inter_ref_bbox) + break + + ref_points = inter_ref_bbox + ref_points_detach = ( + inter_ref_bbox.detach() if self.training else inter_ref_bbox + ) + + return paddle.stack(dec_out_bboxes), paddle.stack(dec_out_logits) + + +class RTDETRTransformer(nn.Layer): + __shared__ = ["num_classes", "hidden_dim", "eval_size"] + + def __init__( + self, + num_classes=80, + hidden_dim=256, + num_queries=300, + position_embed_type="sine", + backbone_feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_decoder_points=4, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=True, + query_pos_head_inv_sig=False, + eval_size=None, + eval_idx=-1, + eps=1e-2, + ): + super(RTDETRTransformer, self).__init__() + assert position_embed_type in [ + "sine", + "learned", + ], f"ValueError: position_embed_type not supported {position_embed_type}!" + assert len(backbone_feat_channels) <= num_levels + assert len(feat_strides) == len(backbone_feat_channels) + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.eval_size = eval_size + + # backbone feature projection + self._build_input_proj_layer(backbone_feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_levels, + num_decoder_points, + ) + self.decoder = TransformerDecoder( + hidden_dim, decoder_layer, num_decoder_layers, eval_idx + ) + + # denoising part + self.denoising_class_embed = nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal()), + ) + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) + self.query_pos_head_inv_sig = query_pos_head_inv_sig + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ), + ) + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + + # decoder head + 
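+        # One classification head and one 3-layer MLP box head per decoder layer,
+        # so every intermediate decoder output can be supervised during training.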
self.dec_score_head = nn.LayerList( + [nn.Linear(hidden_dim, num_classes) for _ in range(num_decoder_layers)] + ) + self.dec_bbox_head = nn.LayerList( + [ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(num_decoder_layers) + ] + ) + + self._reset_parameters() + + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) + linear_init_(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight) + constant_(self.enc_bbox_head.layers[-1].bias) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + linear_init_(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight) + constant_(reg_.layers[-1].bias) + + linear_init_(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + + # init encoder output anchors and valid_mask + if self.eval_size: + self.anchors, self.valid_mask = self._generate_anchors() + + @classmethod + def from_config(cls, cfg, input_shape): + return {"backbone_feat_channels": [i.channels for i in input_shape]} + + def _build_input_proj_layer(self, backbone_feat_channels): + self.input_proj = nn.LayerList() + for in_channels in backbone_feat_channels: + self.input_proj.append( + nn.Sequential( + ( + "conv", + nn.Conv2D( + in_channels, self.hidden_dim, kernel_size=1, bias_attr=False + ), + ), + ( + "norm", + nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ), + ), + ) + ) + in_channels = backbone_feat_channels[-1] + for _ in range(self.num_levels - len(backbone_feat_channels)): + self.input_proj.append( + nn.Sequential( + ( + "conv", + nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False, + ), + ), + ( + "norm", + nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + ), + ), + ) + ) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + level_start_index = [ + 0, + ] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).transpose([0, 2, 1])) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [l], start index of each level + level_start_index.append(h * w + level_start_index[-1]) + + # [b, l, c] + feat_flatten = paddle.concat(feat_flatten, 1) + level_start_index.pop() + return (feat_flatten, spatial_shapes, level_start_index) + + def forward(self, feats, pad_mask=None, gt_meta=None, is_teacher=False): + # input projection and embedding + (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats) + + # prepare denoising training + if self.training: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = ( + 
get_contrastive_denoising_training_group( + gt_meta, + self.num_classes, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + ) + ) + else: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = ( + None, + None, + None, + None, + ) + + target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = ( + self._get_decoder_input( + memory, + spatial_shapes, + denoising_class, + denoising_bbox_unact, + is_teacher, + ) + ) + + # decoder + out_bboxes, out_logits = self.decoder( + target, + init_ref_points_unact, + memory, + spatial_shapes, + level_start_index, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask, + memory_mask=None, + query_pos_head_inv_sig=self.query_pos_head_inv_sig, + ) + return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta) + + def _generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype="float32"): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.eval_size[0] / s), int(self.eval_size[1] / s)] + for s in self.feat_strides + ] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = paddle.meshgrid( + paddle.arange(end=h, dtype=dtype), paddle.arange(end=w, dtype=dtype) + ) + grid_xy = paddle.stack([grid_x, grid_y], -1) + + valid_WH = paddle.to_tensor([h, w]).astype(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl) + anchors.append(paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + + anchors = paddle.concat(anchors, 1) + valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all( + -1, keepdim=True + ) + anchors = paddle.log(anchors / (1 - anchors)) + anchors = paddle.where(valid_mask, anchors, paddle.to_tensor(float("inf"))) + return anchors, valid_mask + + def _get_decoder_input( + self, + memory, + spatial_shapes, + denoising_class=None, + denoising_bbox_unact=None, + is_teacher=False, + ): + bs, _, _ = memory.shape + # prepare input for decoder + if self.training or self.eval_size is None or is_teacher: + anchors, valid_mask = self._generate_anchors(spatial_shapes) + else: + anchors, valid_mask = self.anchors, self.valid_mask + memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.0)) + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = paddle.topk(enc_outputs_class.max(-1), self.num_queries, axis=1) + # extract region proposal boxes + batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + + reference_points_unact = paddle.gather_nd( + enc_outputs_coord_unact, topk_ind + ) # unsigmoided. 
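+        # The gathered top-k boxes stay unsigmoided so they can serve directly as
+        # the decoder's initial reference points; the sigmoid below converts them
+        # into the normalized proposals returned alongside their logits.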
+ enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = paddle.concat( + [denoising_bbox_unact, reference_points_unact], 1 + ) + if self.training: + reference_points_unact = reference_points_unact.detach() + enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind) + + # extract region features + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + else: + target = paddle.gather_nd(output_memory, topk_ind) + if self.training: + target = target.detach() + if denoising_class is not None: + target = paddle.concat([denoising_class, target], 1) + + return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits diff --git a/paddlex/inference/models/object_detection/predictor.py b/paddlex/inference/models/object_detection/predictor.py index c049cc83c7..631ea0e59e 100644 --- a/paddlex/inference/models/object_detection/predictor.py +++ b/paddlex/inference/models/object_detection/predictor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import numpy as np from ....modules.object_detection.model_list import MODELS +from ....utils.device import TemporaryDeviceChanger from ....utils.func_register import FuncRegister from ...common.batch_sampler import ImageBatchSampler from ..base import BasePredictor @@ -104,6 +105,7 @@ def __init__( "small", ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}" + self.device = kwargs.get("device", None) self.img_size = img_size self.threshold = threshold self.layout_nms = layout_nms @@ -143,7 +145,18 @@ def _build(self) -> Tuple: if self._use_static_model: infer = self.create_static_infer() else: - if self.model_name not in []: + if self.model_name == "RT-DETR-L": + from .modeling import RTDETR + + with TemporaryDeviceChanger(self.device): + infer = RTDETR.from_pretrained( + self.model_dir, + use_safetensors=True, + convert_from_hf=True, + dtype="float32", + ) + infer.eval() + else: raise RuntimeError( f"There is no dynamic graph implementation for model {repr(self.model_name)}." ) @@ -237,7 +250,8 @@ def process( batch_inputs = self.pre_ops[-1](datas) # do infer - batch_preds = self.infer(batch_inputs) + with TemporaryDeviceChanger(self.device): + batch_preds = self.infer(batch_inputs) # process a batch of predictions into a list of single image result preds_list = self._format_output(batch_preds)