diff --git a/mmdet/engine/hooks/visualization_hook.py b/mmdet/engine/hooks/visualization_hook.py index 3408186b6ef..90df932f9af 100644 --- a/mmdet/engine/hooks/visualization_hook.py +++ b/mmdet/engine/hooks/visualization_hook.py @@ -390,7 +390,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, gt_bboxes = gt_instances.get('bboxes', None) if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes): gt_instances.bboxes = gt_bboxes.tensor - print(gt_labels, tokens_positive, gt_bboxes, img_path) + # print(gt_labels, tokens_positive, gt_bboxes, img_path) pred_instances = data_sample.pred_instances pred_instances = pred_instances[ pred_instances.scores > self.score_thr] @@ -416,8 +416,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, self._visualizer.set_image(img) for label, bbox, color in zip(gt_labels, gt_bboxes, colors): - self._visualizer.draw_bboxes( - bbox, edge_colors=color, face_colors=color, alpha=0.3) + # self._visualizer.draw_bboxes( + # bbox, edge_colors=color, face_colors=color, alpha=0.3) self._visualizer.draw_bboxes( bbox, edge_colors=color, alpha=1) @@ -460,11 +460,11 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, for label, bbox, color in zip(pred_labels, pred_bboxes, colors): - self._visualizer.draw_bboxes( - bbox, edge_colors=color, face_colors=color, alpha=0.3) + # self._visualizer.draw_bboxes( + # bbox, edge_colors=color, face_colors=color, alpha=0.3) self._visualizer.draw_bboxes( bbox, edge_colors=color, alpha=1) - print(pred_labels, pred_bboxes, pred_scores, colors) + # print(pred_labels, pred_bboxes, pred_scores, colors) areas = (pred_bboxes[:, 3] - pred_bboxes[:, 1]) * ( pred_bboxes[:, 2] - pred_bboxes[:, 0]) scales = _get_adaptive_scales(areas) diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index 8088322546f..468acc17aeb 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -18,6 +18,7 @@ from ..layers import inverse_sigmoid from .atss_vlfusion_head import convert_grounding_to_cls_scores from .dino_head import DINOHead +import torch.nn.functional as F class ContrastiveEmbed(nn.Module): @@ -60,7 +61,7 @@ def __init__(self, torch.Tensor([bias_value]), requires_grad=True) def forward(self, visual_feat: Tensor, text_feat: Tensor, - text_token_mask: Tensor) -> Tensor: + text_token_mask: Tensor, need_expand=True) -> Tensor: """Forward function. Args: @@ -79,13 +80,15 @@ def forward(self, visual_feat: Tensor, text_feat: Tensor, res = res / math.sqrt(visual_feat.shape[-1]) if self.bias is not None: res = res + self.bias - res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) - - new_res = torch.full((*res.shape[:-1], self.max_text_len), - float('-inf'), - device=res.device) - new_res[..., :res.shape[-1]] = res - + if need_expand: + res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) + new_res = torch.full((*res.shape[:-1], self.max_text_len), + float('-inf'), + device=res.device) + new_res[..., :res.shape[-1]] = res + else: + res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) + new_res = res return new_res @@ -190,10 +193,16 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, # Major changes. The labels are 0-1 binary labels for each bbox # and text tokens. 
- labels = gt_bboxes.new_full((num_bboxes, self.max_text_len), - 0, - dtype=torch.float32) - labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + if 'positive_maps' in gt_instances: + labels = gt_bboxes.new_full((num_bboxes, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + else: + labels = gt_bboxes.new_full((num_bboxes,), + cls_score.size(1), + dtype=torch.long) + labels[pos_inds] = gt_instances.labels[pos_assigned_gt_inds] label_weights = gt_bboxes.new_ones(num_bboxes) # bbox targets @@ -211,11 +220,12 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, neg_inds) def forward( - self, - hidden_states: Tensor, - references: List[Tensor], - memory_text: Tensor, - text_token_mask: Tensor, + self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + need_expand=True ) -> Tuple[Tensor]: """Forward function. @@ -257,7 +267,7 @@ def forward( hidden_state = hidden_states[layer_id] outputs_class = self.cls_branches[layer_id](hidden_state, memory_text, - text_token_mask) + text_token_mask, need_expand) tmp_reg_preds = self.reg_branches[layer_id](hidden_state) if reference.shape[-1] == 4: # When `layer` is 0 and `as_two_stage` of the detector @@ -319,12 +329,17 @@ def predict(self, batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] - batch_token_positive_maps = [ - data_samples.token_positive_map - for data_samples in batch_data_samples - ] - outs = self(hidden_states, references, memory_text, text_token_mask) + need_expand = True + batch_token_positive_maps = [] + for data_samples in batch_data_samples: + if 'token_positive_map' in data_samples: + batch_token_positive_maps.append(data_samples.token_positive_map) + else: + batch_token_positive_maps.append(None) + need_expand = False + + outs = self(hidden_states, references, memory_text, text_token_mask, need_expand=need_expand) predictions = self.predict_by_feat( *outs, @@ -427,11 +442,13 @@ def _predict_by_feat_single(self, bbox_index = indexes // num_classes bbox_pred = bbox_pred[bbox_index] else: + # TODO: REC cls_score = cls_score.sigmoid() - scores, _ = cls_score.max(-1) - scores, indexes = scores.topk(max_per_img) - bbox_pred = bbox_pred[indexes] - det_labels = scores.new_zeros(scores.shape, dtype=torch.long) + scores, indexes = cls_score.view(-1).topk(max_per_img) + num_classes = cls_score.shape[-1] + det_labels = indexes % num_classes + bbox_index = indexes // num_classes + bbox_pred = bbox_pred[bbox_index] det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] @@ -492,7 +509,12 @@ def loss(self, hidden_states: Tensor, references: List[Tensor], batch_img_metas.append(data_sample.metainfo) batch_gt_instances.append(data_sample.gt_instances) - outs = self(hidden_states, references, memory_text, text_token_mask) + if 'tokens_positive' in batch_data_samples[0]: + need_expand = True + else: + need_expand = False + + outs = self(hidden_states, references, memory_text, text_token_mask, need_expand) self.text_masks = text_token_mask loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, batch_gt_instances, batch_img_metas, dn_meta) @@ -539,22 +561,28 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, # ===== this change ===== # Loss is not computed for the padded regions of the text. 
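+        # Samples with 'positive_maps' (OD/VG data) keep the original
+        # behaviour: the text mask is padded to max_text_len and the binary
+        # token-level maps are used directly. REC samples only provide
+        # integer phrase labels, so below they are one-hot encoded against
+        # the un-padded text length instead.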
assert (self.text_masks.dim() == 2) - text_masks = self.text_masks.new_zeros( - (self.text_masks.size(0), self.max_text_len)) - text_masks[:, :self.text_masks.size(1)] = self.text_masks + if 'positive_maps' in batch_gt_instances[0]: + text_masks = self.text_masks.new_zeros( + (self.text_masks.size(0), self.max_text_len)) + text_masks[:, :self.text_masks.size(1)] = self.text_masks + else: + text_masks = self.text_masks + num_classes = cls_scores.size(-1) + labels = F.one_hot(labels, num_classes=num_classes + 1) + labels = labels[..., :num_classes] text_mask = (text_masks > 0).unsqueeze(1) text_mask = text_mask.repeat(1, cls_scores.size(1), 1) cls_scores = torch.masked_select(cls_scores, text_mask).contiguous() labels = torch.masked_select(labels, text_mask) label_weights = label_weights[..., - None].repeat(1, 1, text_mask.size(-1)) + None].repeat(1, 1, text_mask.size(-1)) label_weights = torch.masked_select(label_weights, text_mask) # classification loss # construct weighted avg_factor to match with the official DETR repo cls_avg_factor = num_total_pos * 1.0 + \ - num_total_neg * self.bg_cls_weight + num_total_neg * self.bg_cls_weight if self.sync_cls_avg_factor: cls_avg_factor = reduce_mean( cls_scores.new_tensor([cls_avg_factor])) @@ -566,6 +594,9 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, else: loss_cls = self.loss_cls( cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + if torch.isnan(loss_cls): + print(f'has nan of loss_cls') + loss_cls = cls_scores.sum() * 0 # Compute the average number of gt boxes across all gpus, for # normalization purposes @@ -578,7 +609,7 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, img_h, img_w, = img_meta['img_shape'] factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).repeat( - bbox_pred.size(0), 1) + bbox_pred.size(0), 1) factors.append(factor) factors = torch.cat(factors, 0) @@ -592,10 +623,15 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, # regression IoU loss, defaultly GIoU loss loss_iou = self.loss_iou( bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) - + if torch.isnan(loss_iou): + print(f'has nan of loss_iou') + loss_iou = bboxes.sum() * 0 # regression L1 loss loss_bbox = self.loss_bbox( bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + if torch.isnan(loss_bbox): + print(f'has nan of loss_bbox') + loss_bbox = bbox_preds.sum() * 0 return loss_cls, loss_bbox, loss_iou def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, @@ -637,15 +673,23 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, # ===== this change ===== # Loss is not computed for the padded regions of the text. 
         assert (self.text_masks.dim() == 2)
-        text_masks = self.text_masks.new_zeros(
-            (self.text_masks.size(0), self.max_text_len))
-        text_masks[:, :self.text_masks.size(1)] = self.text_masks
+        if 'positive_maps' in batch_gt_instances[0]:
+            text_masks = self.text_masks.new_zeros(
+                (self.text_masks.size(0), self.max_text_len))
+            text_masks[:, :self.text_masks.size(1)] = self.text_masks
+        else:
+            text_masks = self.text_masks
+            num_classes = dn_cls_scores.size(-1)
+            # Temporary workaround: _get_dn_targets_single cannot access dn_cls_scores
+            labels[labels == self.max_text_len] = num_classes
+            labels = F.one_hot(labels, num_classes=num_classes + 1)
+            labels = labels[..., :num_classes]
         text_mask = (text_masks > 0).unsqueeze(1)
         text_mask = text_mask.repeat(1, dn_cls_scores.size(1), 1)
         cls_scores = torch.masked_select(dn_cls_scores, text_mask).contiguous()
+
         labels = torch.masked_select(labels, text_mask)
-        label_weights = label_weights[...,
-                                      None].repeat(1, 1, text_mask.size(-1))
+        label_weights = label_weights[..., None].repeat(1, 1, text_mask.size(-1))
         label_weights = torch.masked_select(label_weights, text_mask)
         # =======================
@@ -667,6 +711,9 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                 labels,
                 label_weights,
                 avg_factor=cls_avg_factor)
+            if torch.isnan(loss_cls):
+                print(f'has nan of dn loss_cls')
+                loss_cls = cls_scores.sum() * 0
         else:
             loss_cls = torch.zeros(
                 1, dtype=cls_scores.dtype, device=cls_scores.device)
@@ -682,7 +729,7 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
             img_h, img_w = img_meta['img_shape']
             factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0).repeat(
-                                               bbox_pred.size(0), 1)
+                bbox_pred.size(0), 1)
             factors.append(factor)
         factors = torch.cat(factors)

@@ -696,15 +743,21 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
         # regression IoU loss, defaultly GIoU loss
         loss_iou = self.loss_iou(
             bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
+        if torch.isnan(loss_iou):
+            print(f'has nan of dn loss_iou')
+            loss_iou = bboxes.sum() * 0
         # regression L1 loss
         loss_bbox = self.loss_bbox(
             bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
+        if torch.isnan(loss_bbox):
+            print(f'has nan of dn loss_bbox')
+            loss_bbox = bbox_preds.sum() * 0
         return loss_cls, loss_bbox, loss_iou

     def _get_dn_targets_single(self, gt_instances: InstanceData,
                                img_meta: dict, dn_meta: Dict[str,
-                                                             int]) -> tuple:
+                               int]) -> tuple:
         """Get targets in denoising part for one image.
Args: @@ -749,10 +802,17 @@ def _get_dn_targets_single(self, gt_instances: InstanceData, neg_inds = pos_inds + num_queries_each_group // 2 # label targets # this change - labels = gt_bboxes.new_full((num_denoising_queries, self.max_text_len), - 0, - dtype=torch.float32) - labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + + if 'positive_maps' in gt_instances: + labels = gt_bboxes.new_full((num_denoising_queries, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + else: + labels = gt_bboxes.new_full((num_denoising_queries,), + self.max_text_len, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] label_weights = gt_bboxes.new_ones(num_denoising_queries) # bbox targets diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py index 4ec9d14e634..fabdb2979fa 100644 --- a/mmdet/models/detectors/grounding_dino.py +++ b/mmdet/models/detectors/grounding_dino.py @@ -329,8 +329,8 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor, # for text encoder memory_text=text_dict['embedded'], text_attention_mask=~text_token_mask, - position_ids=text_dict['position_ids'], - text_self_attention_masks=text_dict['masks']) + position_ids=text_dict.get('position_ids', None), + text_self_attention_masks=text_dict.get('masks', None)) encoder_outputs_dict = dict( memory=memory, memory_mask=feat_mask, @@ -353,13 +353,15 @@ def pre_decoder( output_memory, output_proposals = self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) + if ('tokens_positive' in batch_data_samples[0] and batch_data_samples[0].tokens_positive !=-1) \ + or 'token_positive_map' in batch_data_samples[0]: + need_expand = True + else: + need_expand = False enc_outputs_class = self.bbox_head.cls_branches[ - self.decoder.num_layers](output_memory, memory_text, - text_token_mask) - cls_out_features = self.bbox_head.cls_branches[ - self.decoder.num_layers].max_text_len + self.decoder.num_layers](output_memory, memory_text, text_token_mask, need_expand) enc_outputs_coord_unact = self.bbox_head.reg_branches[ - self.decoder.num_layers](output_memory) + output_proposals + self.decoder.num_layers](output_memory) + output_proposals # NOTE The DINO selects top-k proposals according to scores of # multi-class classification, while DeformDETR, where the input @@ -370,7 +372,7 @@ def pre_decoder( topk_score = torch.gather( enc_outputs_class, 1, - topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_indices.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])) topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)) diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index efb0f46bad6..ad1156fde64 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -14,6 +14,7 @@ HFBertModel = None from mmdet.registry import MODELS +from mmdet.models.utils import align_tensor def generate_masks_with_special_tokens_and_transfer_map( @@ -71,6 +72,16 @@ def generate_masks_with_special_tokens_and_transfer_map( return attention_mask, position_ids.to(torch.long) +def split_tensor(tensor, num_levels): + level_targets = [] + start = 0 + for n in num_levels: + end = start + n + level_targets.append(tensor[start:end]) + start = end + return level_targets + + @MODELS.register_module() class BertModel(BaseModel): """BERT model for language embedding only encoder. 
@@ -105,11 +116,13 @@ def __init__(self, add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False, + reduce_type: str = 'avg', # avg start **kwargs) -> None: super().__init__(**kwargs) self.max_tokens = max_tokens self.pad_to_max = pad_to_max + self.reduce_type = reduce_type if AutoTokenizer is None: raise RuntimeError( @@ -134,9 +147,17 @@ def __init__(self, self.special_tokens = self.tokenizer.convert_tokens_to_ids( special_tokens_list) - def forward(self, captions: Sequence[str], **kwargs) -> dict: + def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: """Forward function.""" device = next(self.language_backbone.parameters()).device + + if task == 'REC': + batch_len_captions = [len(item) for item in captions] + if isinstance(captions, tuple): + captions=list(captions) + if isinstance(captions[0], (list, tuple)): + captions = [item for sublist in captions for item in sublist] + tokenized = self.tokenizer.batch_encode_plus( captions, max_length=self.max_tokens, @@ -145,12 +166,11 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict: return_tensors='pt', truncation=True).to(device) input_ids = tokenized.input_ids - if self.use_sub_sentence_represent: + if self.use_sub_sentence_represent and task == 'VG': attention_mask, position_ids = \ generate_masks_with_special_tokens_and_transfer_map( tokenized, self.special_tokens) token_type_ids = tokenized['token_type_ids'] - else: attention_mask = tokenized.attention_mask position_ids = None @@ -163,10 +183,27 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict: 'token_type_ids': token_type_ids } language_dict_features = self.language_backbone(tokenizer_input) - if self.use_sub_sentence_represent: + if self.use_sub_sentence_represent and task == 'VG': language_dict_features['position_ids'] = position_ids language_dict_features[ 'text_token_mask'] = tokenized.attention_mask.bool() + else: + embedded = language_dict_features['embedded'] + if self.reduce_type == 'start': + end_token_idx = 0 + embedded = embedded[torch.arange(embedded.shape[0]), end_token_idx] + else: + embedded = embedded * tokenized.attention_mask[..., None].float() + embedded = embedded.sum(1) / tokenized.attention_mask.float().sum(-1)[..., None] + + embedded = split_tensor(embedded, batch_len_captions) + embedded = align_tensor(embedded) + attention_mask = split_tensor(embedded.new_ones((len(tokenized.attention_mask))).bool(), batch_len_captions) + attention_mask = align_tensor(attention_mask) + del language_dict_features['masks'] + del language_dict_features['hidden'] + language_dict_features['embedded'] = embedded + language_dict_features['text_token_mask'] = attention_mask return language_dict_features diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py index 3c285768f36..43e4cf3a8ed 100644 --- a/mmdet/models/layers/transformer/grounding_dino_layers.py +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -236,12 +236,27 @@ def forward(self, if self.text_layers: text_num_heads = self.text_layers[ layer_id].self_attn_cfg.num_heads + if text_self_attention_masks is None: + # rec + # l_key_padding_mask = text_attention_mask + # text_self_attention_masks1=None + + l_key_padding_mask = None + text_self_attention_masks1 = \ + torch.eye(text_attention_mask.shape[1], + device=memory_text.device).bool().unsqueeze(0).repeat( + bs, 1, 1) + else: + # phrase grounding + l_key_padding_mask = None + text_self_attention_masks1 
= text_self_attention_masks memory_text = self.text_layers[layer_id]( query=memory_text, query_pos=(pos_text if pos_text is not None else None), - attn_mask=~text_self_attention_masks.repeat( - text_num_heads, 1, 1), # note we use ~ for mask here - key_padding_mask=None, + attn_mask=~text_self_attention_masks1.repeat( + text_num_heads, 1, 1) if text_self_attention_masks1 is not None else None, + # note we use ~ for mask here + key_padding_mask=l_key_padding_mask, ) output = layer( query=output, diff --git a/mmdet/models/task_modules/assigners/hungarian_assigner.py b/mmdet/models/task_modules/assigners/hungarian_assigner.py index a6745a36cdc..64afa37e9a9 100644 --- a/mmdet/models/task_modules/assigners/hungarian_assigner.py +++ b/mmdet/models/task_modules/assigners/hungarian_assigner.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import List, Optional, Union - +import numpy as np import torch from mmengine import ConfigDict from mmengine.structures import InstanceData @@ -128,6 +128,11 @@ def assign(self, raise ImportError('Please run "pip install scipy" ' 'to install scipy first.') + has_nan = np.isnan(cost).any() + if has_nan: + print(f' has nan {cost}, replace to 10000000.0') + cost[np.isnan(cost)] = 10000000.0 + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) matched_row_inds = torch.from_numpy(matched_row_inds).to(device) matched_col_inds = torch.from_numpy(matched_col_inds).to(device) diff --git a/mmdet/models/task_modules/assigners/match_cost.py b/mmdet/models/task_modules/assigners/match_cost.py index 5fc62f01f29..05586d110e5 100644 --- a/mmdet/models/task_modules/assigners/match_cost.py +++ b/mmdet/models/task_modules/assigners/match_cost.py @@ -334,6 +334,25 @@ def __call__(self, @TASK_UTILS.register_module() class BinaryFocalLossCost(FocalLossCost): + def _default_focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_queries, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). 
+ + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: """ Args: @@ -378,8 +397,12 @@ def __call__(self, text_token_mask = torch.nonzero( gt_instances.text_token_mask[0]).squeeze(-1) pred_scores = pred_instances.scores[:, text_token_mask] - gt_labels = gt_instances.positive_maps[:, text_token_mask] - return self._focal_loss_cost(pred_scores, gt_labels) + if 'positive_maps' in gt_instances: + gt_labels = gt_instances.positive_maps[:, text_token_mask] + return self._focal_loss_cost(pred_scores, gt_labels) + else: + gt_labels = gt_instances.labels + return self._default_focal_loss_cost(pred_scores, gt_labels) @TASK_UTILS.register_module() diff --git a/projects/mm_gdino_clip/__init__.py b/projects/mm_gdino_clip/__init__.py new file mode 100644 index 00000000000..4a57c00811f --- /dev/null +++ b/projects/mm_gdino_clip/__init__.py @@ -0,0 +1,7 @@ +from .odvgrec import ODVGRECDataset +from .text_transformers import RandomSamplingNegPosV2 +from .batch_sampler import MultiTaskAspectRatioBatchSampler +from .grounding_dino import GroundingDINOV2 +from .concat_dataset import CustomConcatDataset + +__all__ = ['ODVGRECDataset', 'RandomSamplingNegPosV2', 'MultiTaskAspectRatioBatchSampler', 'GroundingDINOV2', 'CustomConcatDataset'] diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py new file mode 100644 index 00000000000..2124961cf75 --- /dev/null +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
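+# Batch sampler that buckets sample indices by aspect ratio (w < h vs.
+# w >= h) and by task (REC vs. VG/OD) so that every yielded batch contains
+# a single task. OD samples are currently always re-flagged as REC (see the
+# TODO on `od_to_rec_prob`), and per-task leftovers are padded to full
+# batches at the end of an epoch. Requires `drop_last=True`.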
+from typing import Sequence + +from torch.utils.data import BatchSampler, Sampler +from mmdet.registry import DATA_SAMPLERS +import numpy as np + + +@DATA_SAMPLERS.register_module() +class MultiTaskAspectRatioBatchSampler(BatchSampler): + def __init__(self, + sampler: Sampler, + batch_size: int, + drop_last: bool = True, + od_to_rec_prob=0.7) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + # two groups for w < h and w >= h and two task + self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + self.od_to_rec_prob = od_to_rec_prob + assert drop_last is True + + def __iter__(self) -> Sequence[int]: + batch_count = 0 + total_count = len(self.sampler) // self.batch_size + for idx in self.sampler: + wh_mode = self.sampler.dataset.get_wh_mode(idx) + dataset_mode, height, width = wh_mode + bucket_id = 0 if width < height else 1 + + od_to_rec_flag = False + if dataset_mode == 'OD': + # TODO + # if np.random.random() >= 1 - self.od_to_rec_prob: + # dataset_mode = 'REC' + dataset_mode = 'REC' + od_to_rec_flag = True + else: + od_to_rec_flag = False + + # REC: 0 2 + # VG and OD: 1 3 + if dataset_mode == 'REC': + bucket_id = bucket_id * 2 + else: + bucket_id = bucket_id * 2 + 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append([idx, od_to_rec_flag]) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + batch_count += 1 + del bucket[:] + + # yield the rest data and reset the bucket + left_rec_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[2] + left_vg_data = self._aspect_ratio_buckets[1] + self._aspect_ratio_buckets[3] + self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] + + while len(left_rec_data) > 0: + if len(left_rec_data) > self.batch_size: + yield left_rec_data[:self.batch_size] + batch_count += 1 + left_rec_data = left_rec_data[self.batch_size:] + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] + else: + break + + while len(left_vg_data) > 0: + if len(left_vg_data) > self.batch_size: + yield left_vg_data[:self.batch_size] + batch_count += 1 + left_vg_data = left_vg_data[self.batch_size:] + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] + else: + break + + if 0 < len(left_rec_data) < self.batch_size: + left_rec_data.extend([left_rec_data[-1]] * (self.batch_size - len(left_rec_data))) + + if 0 < len(left_vg_data) < self.batch_size: + left_vg_data.extend([left_vg_data[-1]] * (self.batch_size - len(left_vg_data))) + + all_left_data = left_rec_data + left_vg_data + while len(all_left_data) > 0: + yield all_left_data[:self.batch_size] + batch_count += 1 + all_left_data = all_left_data[self.batch_size:] + if batch_count >= total_count: + all_left_data = [] + + def __len__(self) -> int: + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/projects/mm_gdino_clip/browse_grounding_dataset.py b/projects/mm_gdino_clip/browse_grounding_dataset.py new file mode 100644 index 00000000000..2fcf4b51cc2 --- /dev/null +++ 
b/projects/mm_gdino_clip/browse_grounding_dataset.py @@ -0,0 +1,221 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import numpy as np +from mmcv.image import imwrite +from mmengine.config import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import ProgressBar + +from mmdet.registry import DATASETS, VISUALIZERS +from mmdet.structures.bbox import BaseBoxes + + +# configs/grounding_dino_swin-t_pretrain_obj365_goldg.py -o aa --not-show --shuffle +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + '-o', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument('--show-num', '-n', type=int, default=30) + parser.add_argument('--shuffle', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=0, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def draw_all_character(visualizer, characters, w): + start_index = 2 + y_index = 5 + for char in characters: + if isinstance(char, str): + visualizer.draw_texts( + str(char), + positions=np.array([start_index, y_index]), + colors=(0, 0, 0), + font_families='monospace') + start_index += len(char) * 8 + else: + visualizer.draw_texts( + str(char[0]), + positions=np.array([start_index, y_index]), + colors=char[1], + font_families='monospace') + start_index += len(char[0]) * 8 + + if start_index > w - 10: + start_index = 2 + y_index += 15 + + drawn_text = visualizer.get_image() + return drawn_text + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + assert args.show_num > 0 + + # register all modules in mmdet into the registries + init_default_scope(cfg.get('default_scope', 'mmdet')) + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.metainfo + + dataset_index = list(range(len(dataset))) + if args.shuffle: + import random + random.shuffle(dataset_index) + + progress_bar = ProgressBar(len(dataset)) + for i in dataset_index[:args.show_num]: + item = dataset[i] + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_samples'].numpy() + gt_instances = data_sample.gt_instances + + gt_labels = gt_instances.labels + + base_name = osp.basename(item['data_samples'].img_path) + name, extension = osp.splitext(base_name) + + img = img[..., [2, 1, 0]] # bgr to rgb + gt_bboxes = gt_instances.get('bboxes', None) + if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes): + gt_instances.bboxes = gt_bboxes.tensor + + dataset_mode = data_sample.dataset_mode + print(base_name, dataset_mode, data_sample.text) + + out_file = osp.join(args.output_dir, dataset_mode + '_' + name + '_' + str(i) + + 
extension) if args.output_dir is not None else None + + if dataset_mode == 'VG': + tokens_positive = data_sample.tokens_positive + + max_label = int(max(gt_labels) if len(gt_labels) > 0 else 0) + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in gt_labels] + + visualizer.set_image(img) + + for label, bbox, color in zip(gt_labels, gt_bboxes, colors): + visualizer.draw_bboxes( + bbox, edge_colors=color, face_colors=color, alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=color, alpha=1) + + drawn_img = visualizer.get_image() + + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + gt_tokens_positive = [ + tokens_positive[label] for label in gt_labels + ] + split_by_character = [char for char in data_sample.text] + characters = [] + start_index = 0 + end_index = 0 + for w in split_by_character: + end_index += len(w) + is_find = False + for i, positive in enumerate(gt_tokens_positive): + for p in positive: + if start_index >= p[0] and end_index <= p[1]: + characters.append([w, colors[i]]) + is_find = True + break + if is_find: + break + if not is_find: + characters.append([w, (0, 0, 0)]) + start_index = end_index + + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + elif dataset_mode == 'OD': + tokens_positive = data_sample.tokens_positive + gt_labels = gt_instances.labels + text = data_sample.text + label_names = [] + for label in gt_labels: + label_names.append(text[ + tokens_positive[label][0][0]:tokens_positive[label][0][1]]) + gt_instances.label_names = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + base_name, + img, + data_sample, + draw_pred=False, + show=False, + wait_time=0, + out_file=None) + drawn_img = visualizer.get_image() + + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + characters = [char for char in text] + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + else: + gt_labels = gt_instances.labels + text = data_sample.text + label_names = [] + for label in gt_labels: + label_names.append(text[label]) + gt_instances.label_names = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + base_name, + img, + data_sample, + draw_pred=False, + show=False, + wait_time=0, + out_file=None) + drawn_img = visualizer.get_image() + + if not args.not_show: + visualizer.show( + drawn_img, win_name=base_name, wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/projects/mm_gdino_clip/browse_grounding_raw.py b/projects/mm_gdino_clip/browse_grounding_raw.py new file mode 100644 index 00000000000..ab57e6a5d48 --- /dev/null +++ b/projects/mm_gdino_clip/browse_grounding_raw.py @@ -0,0 +1,291 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
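+# Stand-alone visualization script: reads an ODVG-style JSONL annotation
+# file directly (without building a dataset) and draws the 'detection',
+# 'grounding' or 'referring' entries of each item.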
+import argparse +import json +import os.path as osp + +import cv2 +import numpy as np +from mmcv.image import imfrombytes, imwrite +from mmengine.fileio import get +from mmengine.structures import InstanceData +from mmengine.utils import mkdir_or_exist + +from mmdet.structures import DetDataSample +from mmdet.visualization import DetLocalVisualizer +from mmdet.visualization.palette import _get_adaptive_scales + +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + + +# /home/PJLAB/huanghaian/dataset/coco2014/ mdetr_annotations/finetune_refcocog_train_ref.json train2014 --not-show --shuffle -o rex +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('data_root') + parser.add_argument('ann_file') + parser.add_argument('img_prefix') + parser.add_argument('--label-map-file', '-m', default=None) + parser.add_argument( + '--output-dir', + '-o', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument('--show-num', '-n', type=int, default=30) + parser.add_argument('--shuffle', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=0, + help='the interval of show (s)') + args = parser.parse_args() + return args + + +def draw_all_character(visualizer, characters, w): + start_index = 2 + y_index = 5 + for char in characters: + if isinstance(char, str): + visualizer.draw_texts( + str(char), + positions=np.array([start_index, y_index]), + colors=(0, 0, 0), + font_families='monospace') + start_index += len(char) * 8 + else: + visualizer.draw_texts( + str(char[0]), + positions=np.array([start_index, y_index]), + colors=char[1], + font_families='monospace') + start_index += len(char[0]) * 8 + + if start_index > w - 10: + start_index = 2 + y_index += 15 + + drawn_text = visualizer.get_image() + return drawn_text + + +def main(): + args = parse_args() + assert args.show_num > 0 + + local_path = osp.join(args.data_root, args.ann_file) + with open(local_path, 'r') as f: + data_list = [json.loads(line) for line in f] + + dataset_index = list(range(len(data_list))) + if args.shuffle: + import random + random.shuffle(dataset_index) + + if args.label_map_file is not None: + label_map_file = osp.join(args.data_root, args.label_map_file) + with open(label_map_file, 'r') as file: + label_map = json.load(file) + + visualizer = DetLocalVisualizer() + + for i in dataset_index[:args.show_num]: + item = data_list[i] + + img_path = osp.join(args.data_root, args.img_prefix, item['filename']) + if backend_args is not None: + img_bytes = get(img_path, backend_args) + img = imfrombytes(img_bytes, flag='color') + else: + img = cv2.imread(img_path) + img = img[..., [2, 1, 0]] # bgr to rgb + + base_name, extension = osp.splitext(item['filename']) + + out_file = osp.join(args.output_dir, base_name + '_' + str(i) + + extension) if args.output_dir is not None else None + + if args.output_dir is not None: + mkdir_or_exist(args.output_dir) + + if 'detection' in item: + anno = item['detection'] + + instances = [obj for obj in anno['instances']] + bboxes = [obj['bbox'] for obj in instances] + bbox_labels = [int(obj['label']) for obj in instances] + label_names = [label_map[str(label)] for label in bbox_labels] + + data_sample = DetDataSample() + gt_instances = 
InstanceData() + if len(instances) > 0 and 'score' in instances[0]: + score = [obj['score'] for obj in instances] + gt_instances['scores'] = np.array(score) + + gt_instances['bboxes'] = np.array(bboxes).reshape(-1, 4) + gt_instances['labels'] = np.array(bbox_labels) + gt_instances['label_names'] = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + draw_pred=False, + show=not args.not_show, + wait_time=args.show_interval, + out_file=out_file) + elif 'grounding' in item: + anno = item['grounding'] + text = anno['caption'] + regions = anno['regions'] + + max_label = len(regions) if len(regions) > 0 else 0 + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in range(max_label)] + + visualizer.set_image(img) + + gt_tokens_positive = [] + for i, region in enumerate(regions): + bbox = region['bbox'] + bbox = np.array(bbox).reshape(-1, 4) + tokens_positive = region['tokens_positive'] + gt_tokens_positive.append(tokens_positive) + visualizer.draw_bboxes( + bbox, + edge_colors=colors[i], + face_colors=colors[i], + alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=colors[i], alpha=1) + + if 'score' in region: + areas = (bbox[:, 3] - bbox[:, 1]) * ( + bbox[:, 2] - bbox[:, 0]) + scales = _get_adaptive_scales(areas) + score = region['score'][0] + score = [str(s) for s in score] + font_sizes = [ + int(13 * scales[i]) for i in range(len(scales)) + ] + visualizer.draw_texts( + score, + bbox[:, :2].astype(np.int32), + colors=(255, 255, 255), + font_sizes=font_sizes, + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }] * len(bbox)) + + drawn_img = visualizer.get_image() + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + split_by_character = [char for char in text] + characters = [] + start_index = 0 + end_index = 0 + for w in split_by_character: + end_index += len(w) + is_find = False + for i, positive in enumerate(gt_tokens_positive): + for p in positive: + if start_index >= p[0] and end_index <= p[1]: + characters.append([w, colors[i]]) + is_find = True + break + if is_find: + break + if not is_find: + characters.append([w, (0, 0, 0)]) + start_index = end_index + + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + + if not args.not_show: + visualizer.show( + drawn_img, + win_name=base_name, + wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + elif 'referring' in item: + referring = item['referring']['instances'] + + max_label = len(referring) if len(referring) > 0 else 0 + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in range(max_label)] + + visualizer.set_image(img) + phrases = [] + for i, ref in enumerate(referring): + bbox = ref['bbox'] + if isinstance(ref['exp'], list): + phrases.append(' / '.join(ref['exp'])) + else: + phrases.append(ref['exp']) + bbox = np.array(bbox).reshape(-1, 4) + + # visualizer.draw_bboxes( + # bbox, + # edge_colors=colors[i], + # face_colors=colors[i], + # alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=colors[i], alpha=1) + drawn_img = 
visualizer.get_image() + + new_image = np.ones((len(phrases) * 20 + 100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + start_index = 2 + y_index = 5 + + chunk_size = max(min(img.shape[1] - 400, 70), 50) + for i, p in enumerate(phrases): + if not isinstance(p, list): + p = [p] + + for _p in p: + chunk_p = [ + _p[i:i + chunk_size] for i in range(0, len(_p), chunk_size) + ] + for cp in chunk_p: + visualizer.draw_texts( + cp, + positions=np.array([start_index, y_index]), + colors=colors[i], + font_families='monospace') + y_index += 15 + + drawn_text = visualizer.get_image() + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + + if not args.not_show: + visualizer.show( + drawn_img, + win_name=base_name, + wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + +if __name__ == '__main__': + main() diff --git a/projects/mm_gdino_clip/concat_dataset.py b/projects/mm_gdino_clip/concat_dataset.py new file mode 100644 index 00000000000..ddb27d98685 --- /dev/null +++ b/projects/mm_gdino_clip/concat_dataset.py @@ -0,0 +1,38 @@ +from mmdet.datasets import ConcatDataset as _ConcatDataset +from mmdet.registry import DATASETS +from mmengine.logging import print_log +import logging + + +@DATASETS.register_module() +class CustomConcatDataset(_ConcatDataset): + + def __getitem__(self, idx: list): + if not self._fully_initialized: + print_log( + 'Please call `full_init` method manually to ' + 'accelerate the speed.', + logger='current', + level=logging.WARNING) + self.full_init() + + od_to_rec_flag = idx[1] + + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx[0]) + + if od_to_rec_flag: + for _ in range(30): + data_info = self.datasets[dataset_idx].get_data_info(sample_idx) + assert data_info['dataset_mode'] == 'OD' + data_info['dataset_mode'] = 'REC' + data = self.datasets[dataset_idx].pipeline(data_info) + if data is None: + sample_idx = self.datasets[dataset_idx]._rand_another() + continue + return data + else: + return self.datasets[dataset_idx][sample_idx] + + def get_wh_mode(self, idx): + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) + return self.datasets[dataset_idx].get_wh_mode(sample_idx) diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365.py new file mode 100644 index 00000000000..bec0933aea7 --- /dev/null +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365.py @@ -0,0 +1,278 @@ +_base_ = [ + '../../../configs/_base_/datasets/coco_detection.py', + '../../../configs/_base_/schedules/schedule_1x.py', '../../../configs/_base_/default_runtime.py' +] +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa +lang_model_name = 'bert-base-uncased' + +custom_imports = dict( + imports=['projects.mm_gdino_clip'], allow_failed_imports=False) + +model = dict( + type='GroundingDINOV2', + num_queries=900, + with_box_refine=True, + as_two_stage=True, + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=False, + ), + language_model=dict( + type='BertModel', + name=lang_model_name, + max_tokens=256, + pad_to_max=False, + use_sub_sentence_represent=True, + special_tokens_list=['[CLS]', '[SEP]', '.', '?'], + add_pooling_layer=False, + use_checkpoint=False, # change this + ), + backbone=dict( + type='SwinTransformer', + 
embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=True, + frozen_stages=-1, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + neck=dict( + type='ChannelMapper', + in_channels=[192, 384, 768], + kernel_size=1, + out_channels=256, + act_cfg=None, + bias=True, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + encoder=dict( + num_layers=6, + num_cp=6, + # visual layer config + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + # text layer config + text_layer_cfg=dict( + self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)), + # fusion layer config + fusion_layer_cfg=dict( + v_dim=256, + l_dim=256, + embed_dim=1024, + num_heads=4, + init_values=1e-4), + ), + decoder=dict( + num_layers=6, + return_intermediate=True, + layer_cfg=dict( + # query self attention layer + self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to text + cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to image + cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + post_norm_cfg=None), + positional_encoding=dict( + num_feats=128, normalize=True, offset=0.0, temperature=20), + bbox_head=dict( + type='GroundingDINOHead', + num_classes=256, + sync_cls_avg_factor=True, + contrastive_cfg=dict(max_text_len=256, log_scale='auto', bias=True), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), # 2.0 in DeformDETR + loss_bbox=dict(type='L1Loss', loss_weight=5.0)), + dn_cfg=dict( # TODO: Move to model.train_cfg ? 
+ label_noise_scale=0.5, + box_noise_scale=1.0, # 0.4 for DN-DETR + group_cfg=dict(dynamic=True, num_groups=None, + num_dn_queries=100)), # TODO: half num_dn_queries + # training and testing settings + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='BinaryFocalLossCost', weight=2.0), + dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), + dict(type='IoUCost', iou_mode='giou', weight=2.0) + ])), + test_cfg=dict(max_per_img=300)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + # dict( + # type='FixScaleResize', + # scale=(400, 400), + # keep_ratio=True, + # backend='pillow'), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPosV2', + tokenizer_name=lang_model_name, + num_sample_negative=85, + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] + +test_pipeline = [ + dict( + type='LoadImageFromFile', backend_args=None, + imdecode_backend='pillow'), + dict( + type='FixScaleResize', + scale=(800, 1333), + keep_ratio=True, + backend='pillow'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'custom_entities', + 'tokens_positive')) +] + +dataset_type = 'ODVGRECDataset' + +o365_data_root = 'obj365v1_200/' +obj365_od_dataset = dict( + type=dataset_type, + data_root=o365_data_root, + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +rec_data_root = 'data/coco/' +rec_rec_dataset = dict( + type=dataset_type, + data_root=rec_data_root, + ann_file='mdetr_annotations/finetune_refcocog_train_ref.json', + data_prefix=dict(img='train2014/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +flickr30k_vg_data_root = 'flickr30k_200/' +flickr30k_vg_dataset = dict( + type=dataset_type, + data_root=flickr30k_vg_data_root, + ann_file='final_flickr_separateGT_train_vg.json', + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + _delete_=True, + batch_size=4, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', 
shuffle=True), + batch_sampler=dict(type='MultiTaskAspectRatioBatchSampler', od_to_rec_prob=0.7), + dataset=dict(type='CustomConcatDataset', datasets=[obj365_od_dataset])) + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader + +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0004, + weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'backbone': dict(lr_mult=0.1), + 'language_model': dict(lr_mult=0.1), + })) + +# learning policy +max_epochs = 30 +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[19, 26], + gamma=0.1) +] + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (16 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) + +default_hooks = dict(visualization=dict(type='GroundingVisualizationHook')) diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py new file mode 100644 index 00000000000..16ca1491389 --- /dev/null +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py @@ -0,0 +1,117 @@ +_base_ = 'grounding_dino_swin-t_pretrain_obj365.py' + +o365v1_od_dataset = dict( + type='ODVGRECDataset', + data_root='data/objects365v1/', + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None, +) + +flickr30k_dataset = dict( + type='ODVGRECDataset', + data_root='data/flickr30k/', + ann_file='flickr30k_separateGT_train_rec.json', + label_map_file=None, + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +gqa_dataset = dict( + type='ODVGRECDataset', + data_root='data/gqa/', + ann_file='gqa_rec.json', + label_map_file=None, + data_prefix=dict(img='images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +v3d_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + 
keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPosV2', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/V3Det/annotations/v3det_2023_v1_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] +v3det_dataset = dict( + type='ODVGDataset', + data_root='data/V3Det/', + ann_file='annotations/v3det_2023_v1_train_od.json', + label_map_file='annotations/v3det_2023_v1_label_map.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, # change this + pipeline=v3d_train_pipeline, + return_classes=True, + backend_args=None) + +grit_dataset = dict( + type='ODVGDataset', + data_root='grit_processed/', + ann_file='grit20m_rec.json', + label_map_file=None, + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + sampler=dict( + _delete_=True, + type='CustomSampleSizeSampler', + dataset_size=[-1, -1, -1, -1, 500000]), + dataset=dict(datasets=[ + o365v1_od_dataset, flickr30k_dataset, gqa_dataset, v3det_dataset, + grit_dataset + ])) diff --git a/projects/mm_gdino_clip/grounding_dino.py b/projects/mm_gdino_clip/grounding_dino.py new file mode 100644 index 00000000000..f42354e0514 --- /dev/null +++ b/projects/mm_gdino_clip/grounding_dino.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import re +import warnings +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType +from mmdet.models.detectors import GroundingDINO + +task_map = {'REC': 0, 'VG': 1} + + +@MODELS.register_module() +class GroundingDINOV2(GroundingDINO): + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + tasks = [data_samples.dataset_mode for data_samples in batch_data_samples] + tasks = [task_map[task] for task in tasks] + assert len(set(tasks)) == 1, 'Only support one task in one batch, but got {}'.format(tasks) + + if tasks[0] == 1: + # VG + return super().loss(batch_inputs, batch_data_samples) + else: + # REC + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + text_dict = self.language_model(text_prompts, task='REC') + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + + for i, data_samples in enumerate(batch_data_samples): + # for calc BinaryFocalLossCost + text_token_mask = text_dict['text_token_mask'][i] + data_samples.gt_instances.text_token_mask = \ + text_token_mask.unsqueeze(0).repeat( + len(data_samples.gt_instances), 1) + + visual_features = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(visual_features, text_dict, + batch_data_samples) + + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + return losses + + def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): + # only od eval for now + text_prompts = [data_samples.text for data_samples in batch_data_samples] + text_prompts = text_prompts[0] + + visual_feats = 
self.extract_feat(batch_inputs) + + text_dict = self.language_model([text_prompts], task='REC') + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map( + text_dict['embedded']) + head_inputs_dict = self.forward_transformer( + visual_feats, text_dict, batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + + for data_sample, pred_instances in zip( + batch_data_samples, results_list): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + label_names.append(text_prompts[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples diff --git a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py new file mode 100644 index 00000000000..d0be2b27a8a --- /dev/null +++ b/projects/mm_gdino_clip/odvgrec.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import List, Optional + +from mmengine.fileio import get_local_path + +from mmdet.registry import DATASETS +from mmdet.datasets import BaseDetDataset + + +@DATASETS.register_module() +class ODVGRECDataset(BaseDetDataset): + """object detection and visual grounding dataset.""" + + def __init__(self, + *args, + data_root: str = '', + label_map_file: Optional[str] = None, + need_text: bool = True, + **kwargs) -> None: + self.dataset_mode = 'VG' + self.need_text = need_text + if label_map_file: + label_map_file = osp.join(data_root, label_map_file) + with open(label_map_file, 'r') as file: + self.label_map = json.load(file) + self.dataset_mode = 'OD' + super().__init__(*args, data_root=data_root, **kwargs) + assert self.return_classes is True + + def load_data_list(self) -> List[dict]: + self.image_to_exp = {} + self.wh_modes=[] + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + data_list = [json.loads(line) for line in f] + + out_data_list = [] + for data in data_list: + data_info = {} + img_path = osp.join(self.data_prefix['img'], data['filename']) + data_info['img_path'] = img_path + data_info['height'] = data['height'] + data_info['width'] = data['width'] + + if 'referring' in data: + self.dataset_mode = 'REC' + + if self.dataset_mode == 'OD': + if self.need_text: + data_info['text'] = self.label_map + anno = data.get('detection', {}) + instances = [obj for obj in anno.get('instances', [])] + bboxes = [obj['bbox'] for obj in instances] + bbox_labels = [str(obj['label']) for obj in instances] + + instances = [] + for bbox, label in zip(bboxes, bbox_labels): + instance = {} + x1, y1, x2, y2 = bbox + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = int(label) + instances.append(instance) + data_info['instances'] = instances + data_info['dataset_mode'] = self.dataset_mode + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) + out_data_list.append(data_info) + elif self.dataset_mode == 'REC': + anno = data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + bboxes = [obj['bbox'] for obj in instances] + bbox_exp = [obj['exp'] for obj in instances] + + self.image_to_exp[img_path] = 
bbox_exp + + bbox_labels = list(range(len(bboxes))) + + phrases = {} + instances = [] + i = 0 + for bbox, exp, label in zip(bboxes, bbox_exp, bbox_labels): + if not isinstance(bbox[0], list): + bbox = [bbox] + for b in bbox: + instance = {} + x1, y1, x2, y2 = b + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = b + instance['bbox_label'] = int(label) + instances.append(instance) + phrases[i] = exp + i += 1 + + data_info['instances'] = instances + data_info['dataset_mode'] = self.dataset_mode + data_info['text'] = phrases + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) + out_data_list.append(data_info) + else: + anno = data['grounding'] + data_info['text'] = anno['caption'] + regions = anno['regions'] + + instances = [] + phrases = {} + for i, region in enumerate(regions): + bbox = region['bbox'] + phrase = region['phrase'] + tokens_positive = region['tokens_positive'] + if not isinstance(bbox[0], list): + bbox = [bbox] + for box in bbox: + instance = {} + x1, y1, x2, y2 = box + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = box + instance['bbox_label'] = i + phrases[i] = { + 'phrase': phrase, + 'tokens_positive': tokens_positive + } + instances.append(instance) + data_info['instances'] = instances + data_info['phrases'] = phrases + data_info['dataset_mode'] = self.dataset_mode + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) + out_data_list.append(data_info) + + del data_list + return out_data_list + + def get_wh_mode(self, idx): + return self.wh_modes[idx] + + def prepare_data(self, idx: int): + """Pass the dataset to the pipeline during training to support mixed + data augmentation, such as Mosaic and MixUp.""" + if self.test_mode is False: + data_info = self.get_data_info(idx) + if self.dataset_mode == 'REC': + data_info['image_to_exp'] = self.image_to_exp + return self.pipeline(data_info) + else: + return super().prepare_data(idx) diff --git a/projects/mm_gdino_clip/script/flickr30k2rec.py b/projects/mm_gdino_clip/script/flickr30k2rec.py new file mode 100644 index 00000000000..f69179737d0 --- /dev/null +++ b/projects/mm_gdino_clip/script/flickr30k2rec.py @@ -0,0 +1,440 @@ +import argparse +import os +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple +from xml.etree.ElementTree import parse + +import numpy as np +import torch +import xmltodict # TODO +from tqdm import tqdm +from torchvision.ops.boxes import box_area, batched_nms +import jsonlines +import copy +import os.path as osp + +""" +data/flickr30k_entities + - annotations + - Annotations + - Sentences + - flickr30k_images + - train.txt +""" + + +def parse_args(): + parser = argparse.ArgumentParser("Conversion script") + + parser.add_argument( + "--flickr_path", + default='data/flickr30k_entities', + type=str, + help="Path to the flickr dataset", + ) + parser.add_argument( + "--out_path", + default="", + type=str, + help="Path where to export the resulting dataset.", + ) + + parser.add_argument( + "--merge_ground_truth", + action="store_true", + help="Whether to follow Bryan Plummer protocol and merge ground 
truth. By default, all the boxes for an entity are kept separate", + ) + + return parser.parse_args() + + +def box_xywh_to_xyxy(x): + """Accepts a list of bounding boxes in coco format (xmin,ymin, width, height) + Returns the list of boxes in pascal format (xmin,ymin,xmax,ymax) + + The boxes are expected as a numpy array + """ + # result = x.copy() + result = x.clone() + result[..., 2:] += result[..., :2] + return result + + +def xyxy2xywh(box: List): + """Accepts a list of bounding boxes in pascal format (xmin,ymin,xmax,ymax) + Returns the list of boxes in coco format (xmin,ymin, width, height) + """ + xmin, ymin, xmax, ymax = box + h = ymax - ymin + w = xmax - xmin + return [xmin, ymin, w, h] + + +def get_sentence_data(filename) -> List[Dict[str, Any]]: + """ + Parses a sentence file from the Flickr30K Entities dataset + + input: + filename - full file path to the sentence file to parse + + output: + a list of dictionaries for each sentence with the following fields: + sentence - the original sentence + phrases - a list of dictionaries for each phrase with the + following fields: + phrase - the text of the annotated phrase + first_word_index - the position of the first word of + the phrase in the sentence + phrase_id - an identifier for this phrase + phrase_type - a list of the coarse categories this + phrase belongs to + + """ + with open(filename, "r") as f: + sentences = f.read().split("\n") + + annotations = [] + for sentence in sentences: + if not sentence: + continue + + first_word = [] + phrases = [] + phrase_id = [] + phrase_type = [] + words = [] + current_phrase = [] + add_to_phrase = False + for token in sentence.split(): + if add_to_phrase: + if token[-1] == "]": + add_to_phrase = False + token = token[:-1] + current_phrase.append(token) + phrases.append(" ".join(current_phrase)) + current_phrase = [] + else: + current_phrase.append(token) + + words.append(token) + else: + if token[0] == "[": + add_to_phrase = True + first_word.append(len(words)) + parts = token.split("/") + phrase_id.append(parts[1][3:]) + phrase_type.append(parts[2:]) + else: + words.append(token) + + sentence_data = {"sentence": " ".join(words), "phrases": []} + for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type): + sentence_data["phrases"].append( + {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type} + ) + + annotations.append(sentence_data) + + return annotations + + +def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]: + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clip(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + return inter, union + + +def box_iou(boxes1: np.array, boxes2: np.array) -> np.array: + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
+ + Args: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[M, 4]) + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + inter, union = _box_inter_union(boxes1, boxes2) + iou = inter / union + return iou + + +class UnionFind: + """Optimized union find structure""" + + def __init__(self, n): + """Initialize a union find with n components""" + self.compo = list(range(n)) + self.weight = [1] * n + self.nb_compo = n + + def get_nb_compo(self): + return self.nb_compo + + def find(self, x): + if self.compo[x] == x: + return x + self.compo[x] = self.find(self.compo[x]) + return self.compo[x] + + def unite(self, a, b): + fa = self.find(a) + fb = self.find(b) + if fa != fb: + self.nb_compo -= 1 + if self.weight[fb] > self.weight[fa]: + fa, fb = fb, fa + self.compo[fb] = fa + self.weight[fa] += self.weight[fb] + + +def get_equivalent_boxes(all_boxes, iou_threshold=0.95): + """Find clusters of highly overlapping boxes + Parameters: + - all_boxes: a list of boxes in [center_x, center_y, w, h] format + - iou_threshold: threshold at which we consider two boxes to be the same + + Returns a dict where the keys are an arbitrary id, and the values are the equivalence lists + """ + if len(all_boxes) == 0: + return {0: []} + uf = UnionFind(len(all_boxes)) + + # xy_boxes = box_xywh_to_xyxy(np.asarray(all_boxes)) + xy_boxes = box_xywh_to_xyxy(torch.as_tensor(all_boxes, dtype=torch.float)) + iou = box_iou(xy_boxes, xy_boxes) + for i, j in zip(*np.where(iou >= iou_threshold)): + uf.unite(i, j) + compo = defaultdict(list) + for i in range(len(all_boxes)): + compo[uf.find(i)].append(i) + return compo + + +def convert( + subset: str, flickr_path: Path, merge_ground_truth: bool, next_img_id: int = 1, + next_id: int = 1 +): + with open(flickr_path / f"{subset}.txt") as fd: + ids = [int(l.strip()) for l in fd] + + multibox_entity_count = 0 + + out_results = [] + total_phrase = 0 + total_bbox = 0 + + print(f"Exporting {subset}...") + for img_id in tqdm(ids): + + with open(flickr_path / "annotations" / "Annotations" / f"{img_id}.xml") as xml_file: + annotation = xmltodict.parse(xml_file.read())["annotation"] + + cur_img = { + "filename": annotation["filename"], + "height": int(annotation["size"]["height"]), + "width": int(annotation["size"]["width"]), + } + + instance_list = [] + # image = cv2.imread(output_path / "flickr30k-images" / annotation["filename"]) + # if image.shape[1] != cur_img["width"] or image.shape[0] != cur_img["height"]: + # print("before exif correction: ", cur_img) + # cur_img["width"], cur_img["height"] = image.shape[1], image.shape[0] + # print("after exif correction: ", cur_img) + + anno_file = os.path.join(flickr_path, "annotations/Annotations/%d.xml" % img_id) + + # Parse Annotation + root = parse(anno_file).getroot() + obj_elems = root.findall("./object") + target_bboxes = {} + + for elem in obj_elems: + if elem.find("bndbox") is None or len(elem.find("bndbox")) == 0: + continue + xmin = float(elem.findtext("./bndbox/xmin")) + ymin = float(elem.findtext("./bndbox/ymin")) + xmax = float(elem.findtext("./bndbox/xmax")) + ymax = float(elem.findtext("./bndbox/ymax")) + assert 0 < xmin and 0 < ymin + + h = ymax - ymin + w = xmax - xmin + + coco_box = [xmin, ymin, w, h] + + for name in elem.findall("name"): + entity_id = int(name.text) + assert 0 < entity_id + if not entity_id in target_bboxes: + target_bboxes[entity_id] = [] + else: + multibox_entity_count += 1 + # Dict from entity_id to list of all the bounding boxes + 
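                    # A single Flickr30k entity id can be annotated with several boxes;
                    # multibox_entity_count above counts how often that happens. With
                    # --merge_ground_truth those duplicates are later collapsed into one
                    # enclosing box (the Bryan Plummer protocol mentioned in the argparse
                    # help); otherwise every box is kept as a separate instance.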
target_bboxes[entity_id].append(coco_box) + + if merge_ground_truth: + merged_bboxes = defaultdict(list) + for eid, bbox_list in target_bboxes.items(): + boxes_xyxy = box_xywh_to_xyxy(torch.as_tensor(bbox_list, dtype=torch.float)) + gt_box_merged = [ + min(boxes_xyxy[:, 0]).item(), + min(boxes_xyxy[:, 1]).item(), + max(boxes_xyxy[:, 2]).item(), + max(boxes_xyxy[:, 3]).item(), + ] + merged_bboxes[eid] = [xyxy2xywh(gt_box_merged)] # convert back to xywh for coco format + + target_bboxes = merged_bboxes + + sents = get_sentence_data(flickr_path / "annotations/Sentences" / f"{img_id}.txt") + for sent_id, sent in enumerate(sents): + + spans = {} # global phrase ID to span in sentence + phraseid2entityid = {} + entityid2phraseid = defaultdict(list) + sentence = sent["sentence"] + entity_ids = [int(p["phrase_id"]) for p in sent["phrases"]] + + for global_phrase_id, phrase in enumerate(sent["phrases"]): + phraseid2entityid[global_phrase_id] = int(phrase["phrase_id"]) + entityid2phraseid[int(phrase["phrase_id"])].append(global_phrase_id) + first_word = phrase["first_word_index"] + beg = sum([len(x) for x in sentence.split()[:first_word]]) + first_word + spans[global_phrase_id] = (beg, beg + len(phrase["phrase"])) + assert sentence[beg: beg + len(phrase["phrase"])] == phrase["phrase"] + + all_boxes_in_sent = [] + for ent_id in entity_ids: + if ent_id in target_bboxes: + for bb in target_bboxes[ent_id]: + all_boxes_in_sent.append({"ent_id": int(ent_id), "coords": bb}) + + equivalences = get_equivalent_boxes([b["coords"] for b in all_boxes_in_sent], 0.95) + + tokens_positive_eval = [] + for gpid, span in spans.items(): + if phraseid2entityid[gpid] in target_bboxes: + tokens_positive_eval.append([span]) + + for equiv in equivalences.values(): + if len(equiv) == 0: + continue + cur_entids = set([all_boxes_in_sent[bid]["ent_id"] for bid in equiv]) + token_spans = [] + for entid in cur_entids: + token_spans += [spans[gid] for gid in entityid2phraseid[entid]] + xmin, ymin, w, h = all_boxes_in_sent[equiv[-1]]["coords"] + + phrase = " ".join([sentence[sp[0]:sp[1]] for sp in token_spans]) + + cur_obj = { + "bbox": [xmin, ymin, w + xmin, h + ymin], + "exp": phrase, + } + next_id += 1 + instance_list.append(cur_obj) + + # 相同图片名的实例合并到一起 + out_instance = {} + for instance in instance_list: + if instance['exp'] in out_instance: + data = out_instance[instance['exp']] + if isinstance(data['bbox'][0], list): + # 如果 bbox 是相同的,就直接留一个就行 + is_same = False + for bbox in data['bbox']: + if bbox == instance['bbox']: + is_same = True + break + if not is_same: + data['bbox'].append(instance['bbox']) + else: + # 如果 bbox 是相同的,就直接留一个就行 + if data['bbox'] != instance['bbox']: + data['bbox'] = [data['bbox'], instance['bbox']] + else: + out_instance[instance['exp']] = copy.deepcopy(instance) + + out_instance = list(out_instance.values()) + + # 不同 phrase 但是 bbox 相同的需要合并 + new_out_instance = [] + temp_bboxes = [] + for instance in out_instance: + bbox = instance['bbox'] + if bbox not in temp_bboxes: + new_out_instance.append(copy.deepcopy(instance)) + temp_bboxes.append(bbox) + else: + index = temp_bboxes.index(bbox) + instance_ = new_out_instance[index] + if isinstance(instance_['exp'], list): + # 如果 phrase 是相同的,就直接留一个就行 + is_same = False + for exp in instance_['exp']: + if exp.lower() == instance['exp'].lower(): + is_same = True + break + if not is_same: + instance_['exp'].append(instance['exp']) + else: + # 如果去除大小写后一样,则只保留其中一个 + if instance_['exp'].lower() != instance['exp'].lower(): + instance_['exp'] = [instance_['exp'], 
instance['exp']] + + # 每条数据 nms + for instance in new_out_instance: + if isinstance(instance['bbox'][0], list): + bboxes = torch.as_tensor(instance['bbox'], dtype=torch.float).reshape(-1, 4) + score = torch.ones(len(bboxes), dtype=torch.float) + index = batched_nms(bboxes, score, score, iou_threshold=0.9) + if len(index) != len(score): + # print('nms vaild', cur_img['filename'], instance['exp'], bboxes, bboxes[index]) + print('nms vaild', cur_img['filename'], instance['exp']) + bboxes = bboxes[index].numpy().tolist() + instance['bbox'] = bboxes + + total_phrase += len(new_out_instance) + total_bbox_ = [len(ins['bbox']) for ins in new_out_instance] + total_bbox += sum(total_bbox_) + next_img_id += 1 + cur_img['referring'] = {} + cur_img['referring']['instances'] = new_out_instance + out_results.append(cur_img) + + print(f'total image: {len(out_results)}, total phrase: {total_phrase}, total bbox: {total_bbox}') + if merge_ground_truth: + filename = f"flickr30k_mergedGT_{subset}_rec.json" + else: + filename = f"flickr30k_separateGT_{subset}_rec.json" + + out_path = osp.join(flickr_path, filename) + + with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) + print(f'save to {out_path}') + + +def main(args): + flickr_path = Path(args.flickr_path) + convert("train", flickr_path, args.merge_ground_truth) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/projects/mm_gdino_clip/script/gqa2rec.py b/projects/mm_gdino_clip/script/gqa2rec.py new file mode 100644 index 00000000000..7c0e55f4422 --- /dev/null +++ b/projects/mm_gdino_clip/script/gqa2rec.py @@ -0,0 +1,383 @@ +""" +data_path : path to original GQA annotations to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +img_path : path to original GQA images to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +sg_path : path to original GQA scene graphs to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +vg_img_data_path : path to image info for VG images to be downloaded from https://visualgenome.org/static/data/dataset/image_data.json.zip + + +data/gqa + - questions1.2 + - sceneGraphs + - image_data.json # from VG +""" + +import argparse +import json +import os +import re +from collections import defaultdict +from pathlib import Path +import sys +from tqdm import tqdm +import os.path as osp +import jsonlines +import torch +import copy +from torchvision.ops.boxes import batched_nms + + +PACKAGE_PARENT = "." 
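# The sys.path manipulation below puts this script's own directory on the
# import path so that the local `utils` package (utils/spans.py, utils/text.py)
# resolves when the file is run directly. A typical invocation, using the
# argparse defaults declared further down (adjust the paths to your layout):
#
#   python projects/mm_gdino_clip/script/gqa2rec.py \
#       --data_path data/gqa/questions1.2/ \
#       --sg_path data/gqa/sceneGraphs/ \
#       --vg_img_data_path data/gqa/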
+SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))) +sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) +from utils.spans import consolidate_spans + + +# pip install nltk spacy +# python -m spacy download en_core_web_sm +def parse_args(): + parser = argparse.ArgumentParser("Conversion script") + + parser.add_argument( + "--data_path", + default='data/gqa/questions1.2/', + type=str, + help="Path to the gqa dataset", + ) + parser.add_argument( + "--sg_path", + default='data/gqa/sceneGraphs/', + type=str, + help="Path to the gqa dataset scene graph", + ) + + parser.add_argument( + "--vg_img_data_path", + default='data/gqa/', + type=str, + help="Path to image meta data for VG" + ) + return parser.parse_args() + + +def convert(out_results, split, data_path, sg_path, imid2data, next_img_id=1, next_id=1): + print("Loading", data_path / f"{split}_balanced_questions.json") + with open(data_path / f"{split}_balanced_questions.json", "r") as f: + data = json.load(f) + print("Loading", sg_path / f"{split}_sceneGraphs.json") + with open(sg_path / f"{split}_sceneGraphs.json", "r") as f: + sg_data = json.load(f) + + img2ann = defaultdict(dict) + for k, v in data.items(): + img2ann[v["imageId"]][k] = v + print(len(img2ann)) + print(img2ann["2354786"]) + print(img2ann[list(img2ann.keys())[0]].keys()) + + # Add missing annotations by inspecting the semantic field + regexp = re.compile(r"([0-9]+)") + regexp2 = re.compile(r"([A-z]+)") + count = 0 + + for k, v in img2ann.items(): + for ann_id, annotations in v.items(): + expected_boxes = [] + for item in annotations["semantic"]: + if item["operation"] == "select": + if len(regexp.findall(item["argument"])) > 0: + expected_boxes.append( + (regexp2.findall(item["argument"])[0].strip(), regexp.findall(item["argument"])[0]) + ) + question_boxes = list(annotations["annotations"]["question"].values()) + + for name, box_id in expected_boxes: + if box_id not in question_boxes: + count += 1 + beg = annotations["question"].find(name) + end = beg + len(name) + annotations["annotations"]["question"][(beg, end)] = box_id + + print(len(img2ann)) + print(img2ann["2354786"]) + print(img2ann[list(img2ann.keys())[0]].keys()) + + # Add annotations for the questions where there is a box for the answer but not for the question (what/where/who questions) + for k, v in img2ann.items(): + for ann_id, ann in v.items(): + question_objects = list(ann["annotations"]["question"].values()) + answer_objects = list(ann["annotations"]["answer"].values()) + if len(set(answer_objects) - set(question_objects)) > 0: + + for box_id in answer_objects: + if box_id not in question_objects: + + if ann["question"].find("What") > -1: + beg = ann["question"].find("What") + end = beg + len("What") + elif ann["question"].find("what") > -1: + beg = ann["question"].find("what") + end = beg + len("what") + elif ann["question"].find("Who") > -1: + beg = ann["question"].find("Who") + end = beg + len("Who") + elif ann["question"].find("who") > -1: + beg = ann["question"].find("who") + end = beg + len("who") + elif ann["question"].find("Where") > -1: + beg = ann["question"].find("Where") + end = beg + len("Where") + elif ann["question"].find("where") > -1: + beg = ann["question"].find("where") + end = beg + len("where") + else: + continue + + ann["annotations"]["question"][(beg, end)] = box_id + + print(f"Dumping {split}...") + # next_img_id = 0 + # next_id = 0 + + for k, v in tqdm(img2ann.items()): + filename = f"{k}.jpg" + cur_img = { 
+ "filename": filename, + "height": imid2data[int(k)]["height"], + "width": imid2data[int(k)]["width"], + # "id": next_img_id, + # "original_id": k, + } + instance_list = [] + + # image = read_image(data_path / "images" / filename, format="BGR") + # if image.shape[1] != cur_img["width"] or image.shape[0] != cur_img["height"]: + # print("before exif correction: ", cur_img) + # cur_img["width"], cur_img["height"] = image.shape[1], image.shape[0] + # print("after exif correction: ", cur_img) + # if filename == "860.jpg": + # print(v) + + for ann_id, annotation in v.items(): + question = annotation["question"] + answer = annotation["answer"] + full_answer = annotation["fullAnswer"] + + if len(annotation["annotations"]["question"]) > 0: + + # assert len(annotation["annotations"]["question"]) == 1 + # if len(annotation["annotations"]["question"]) > 1: + # print(annotation) + phrase_all = [] + for text_tok_id, box_anno_id in annotation["annotations"]["question"].items(): + target_bbox = sg_data[k]["objects"][box_anno_id] + x, y, h, w = target_bbox["x"], target_bbox["y"], target_bbox["h"], target_bbox["w"] + target_bbox = [x, y, w, h] + + if isinstance(text_tok_id, str): + if ":" in text_tok_id: + text_tok_id = text_tok_id.split(":") + if isinstance(text_tok_id, list) and len(text_tok_id) > 1: + beg = sum([len(x) for x in question.split()[: int(text_tok_id[0])]]) + int(text_tok_id[0]) + end = ( + sum([len(x) for x in question.split()[: int(text_tok_id[1]) - 1]]) + + int(text_tok_id[1]) + - 1 + ) + end = end + len(question.split()[int(text_tok_id[1]) - 1]) + else: + beg = sum([len(x) for x in question.split()[: int(text_tok_id)]]) + int(text_tok_id) + end = beg + len(question.split()[int(text_tok_id)]) + else: + beg, end = text_tok_id + + cleaned_span = consolidate_spans([(beg, end)], question) + + question_positive = " ".join([question[sp[0]:sp[1]] for sp in cleaned_span]) + + if question_positive.lower() in ["what", "who", "where"]: + phrase = answer + else: + phrase = question_positive + phrase_all.append(phrase) + + for text_tok_id, box_anno_id in annotation["annotations"]["question"].items(): + target_bbox = sg_data[k]["objects"][box_anno_id] + x, y, h, w = target_bbox["x"], target_bbox["y"], target_bbox["h"], target_bbox["w"] + target_bbox = [x, y, w + x, h + y] + + if isinstance(text_tok_id, str): + if ":" in text_tok_id: + text_tok_id = text_tok_id.split(":") + if isinstance(text_tok_id, list) and len(text_tok_id) > 1: + beg = sum([len(x) for x in question.split()[: int(text_tok_id[0])]]) + int(text_tok_id[0]) + end = ( + sum([len(x) for x in question.split()[: int(text_tok_id[1]) - 1]]) + + int(text_tok_id[1]) + - 1 + ) + end = end + len(question.split()[int(text_tok_id[1]) - 1]) + else: + beg = sum([len(x) for x in question.split()[: int(text_tok_id)]]) + int(text_tok_id) + end = beg + len(question.split()[int(text_tok_id)]) + else: + beg, end = text_tok_id + + cleaned_span = consolidate_spans([(beg, end)], question) + + question_positive = " ".join([question[sp[0]:sp[1]] for sp in cleaned_span]) + + phrase = question_positive + if any([phrase.lower().startswith(p) for p in ["what", "who", "where"]]): + phrase = answer + elif question_positive.lower() == "wh": + phrase = answer + elif question_positive.lower() == "ho": + phrase = answer + + if sum([1 if p in full_answer else 0 for p in phrase_all]) == 1: + if answer in full_answer and phrase in full_answer: + phrase = full_answer + # beg = full_answer.index(phrase) + # end = beg + len(phrase) + # print([[(beg, end)]], full_answer, phrase) 
+ # cleaned_span, phrase = get_canonical_spans([[(beg, end)]], full_answer) + # print(cleaned_span, phrase) + + if phrase.lower() == "he": + if "man" in full_answer or "boy" in full_answer or "guy" in full_answer: + phrase = full_answer + else: + phrase = "man" + if phrase.lower() == "she": + if "woman" in full_answer or "lady" in full_answer or "girl" in full_answer: + phrase = full_answer + else: + phrase = "woman" + + if len(phrase) == 2 and not (phrase.lower() == "tv" or phrase.lower() == "cd"): + phrase = full_answer + + if len(phrase) == 1: + phrase = full_answer + + if phrase.lower().startswith("no, "): + phrase = phrase[4:] + if phrase.lower().startswith("yes, "): + phrase = phrase[5:] + + cur_obj = { + # "area": h * w, + # "iscrowd": 0, + # "category_id": 1, + "bbox": target_bbox, + # "image_id": next_img_id, + # "id": next_id, + # "question": question, + # "answer": answer, + # "full_answer": full_answer, + # "tokens_positive": cleaned_span, + # "question_positive": question_positive, + "exp": phrase, + } + + next_id += 1 + instance_list.append(cur_obj) + + # 相同图片名的实例合并到一起 + out_instance = {} + for instance in instance_list: + if instance['exp'] in out_instance: + data = out_instance[instance['exp']] + if isinstance(data['bbox'][0], list): + # 如果 bbox 是相同的,就直接留一个就行 + is_same = False + for bbox in data['bbox']: + if bbox == instance['bbox']: + is_same = True + break + if not is_same: + data['bbox'].append(instance['bbox']) + else: + # 如果 bbox 是相同的,就直接留一个就行 + if data['bbox'] != instance['bbox']: + data['bbox'] = [data['bbox'], instance['bbox']] + else: + out_instance[instance['exp']] = copy.deepcopy(instance) + + out_instance = list(out_instance.values()) + + # 不同 phrase 但是 bbox 相同的需要合并 + new_out_instance = [] + temp_bboxes = [] + for instance in out_instance: + bbox = instance['bbox'] + if bbox not in temp_bboxes: + new_out_instance.append(copy.deepcopy(instance)) + temp_bboxes.append(bbox) + else: + index = temp_bboxes.index(bbox) + instance_ = new_out_instance[index] + if isinstance(instance_['exp'], list): + # 如果 phrase 是相同的,就直接留一个就行 + is_same = False + for exp in instance_['exp']: + if exp.lower() == instance['exp'].lower(): + is_same = True + break + if not is_same: + instance_['exp'].append(instance['exp']) + else: + # 如果去除大小写后一样,则只保留其中一个 + if instance_['exp'].lower() != instance['exp'].lower(): + instance_['exp'] = [instance_['exp'], instance['exp']] + + # 每条数据 nms + for instance in new_out_instance: + if isinstance(instance['bbox'][0], list): + bboxes = torch.as_tensor(instance['bbox'], dtype=torch.float).reshape(-1, 4) + score = torch.ones(len(bboxes), dtype=torch.float) + index = batched_nms(bboxes, score, score, iou_threshold=0.9) + if len(index) != len(score): + # print('nms vaild', cur_img['filename'], instance['exp'], bboxes, bboxes[index]) + print('nms vaild', cur_img['filename'], instance['exp']) + bboxes = bboxes[index].numpy().tolist() + instance['bbox'] = bboxes + + next_img_id += 1 + cur_img['referring'] = {} + cur_img['referring']['instances'] = new_out_instance + out_results.append(cur_img) + + return out_results, next_img_id, next_id + +def main(args): + data_path = Path(args.data_path) + sg_path = Path(args.sg_path) + + print("Loading", f"{args.vg_img_data_path}/image_data.json") + with open(f"{args.vg_img_data_path}/image_data.json", "r") as f: + image_data = json.load(f) + imid2data = {x["image_id"]: x for x in image_data} + + out_results = [] + out_results, next_img_id, next_id = convert(out_results, "train", data_path, sg_path, imid2data) + out_results, 
_, _ = convert(out_results, "val", data_path, sg_path, imid2data, next_img_id, next_id) + + total_phrase = 0 + total_bbox = 0 + for result in out_results: + total_phrase += len(result['referring']['instances']) + total_bbox_ = [len(ins['bbox']) for ins in result['referring']['instances']] + total_bbox += sum(total_bbox_) + print(f'total image: {len(out_results)}, total phrase: {total_phrase}, total bbox: {total_bbox}') + + filename = f"gqa_rec.json" + out_path = osp.join(args.vg_img_data_path, filename) + + with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) + print(f'save to {out_path}') + + +if __name__ == "__main__": + main(parse_args()) diff --git a/projects/mm_gdino_clip/script/grit_vg_to_rec.py b/projects/mm_gdino_clip/script/grit_vg_to_rec.py new file mode 100644 index 00000000000..f8183e3ad29 --- /dev/null +++ b/projects/mm_gdino_clip/script/grit_vg_to_rec.py @@ -0,0 +1,22 @@ +import json +import jsonlines + +root_path = '/mnt/workspace/zhaoxiangyu/code_new/grounding_mm_mine/grit_try/' +grit_path = root_path + 'grit_ref_all_after_filter.jsonl' + +with open(grit_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +for data in rec_data_list: + referring = data['referring'] + new_dict = {} + for ref in referring: + new_dict['exp'] = ref['phrase'] + new_dict['bbox'] = ref['bbox'] + data['referring'] = {} + data['referring']['instances'] = new_dict + +out_path = root_path + 'grit_ref_all_after_filter_rec.json' +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py b/projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py new file mode 100644 index 00000000000..5fd72f0f591 --- /dev/null +++ b/projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py @@ -0,0 +1,71 @@ +import json +import jsonlines + +root_path = '/home/PJLAB/huanghaian/dataset/flickr30k_entities/' +rec_path = root_path + 'flickr30k_separateGT_train_rec.json' +vg_path = root_path + 'final_flickr_separateGT_train_vg.json' + +with open(rec_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +rec_data_list_name = [data['filename'] for data in rec_data_list] + +with open(vg_path, 'r') as f: + vg_data_list = [json.loads(line) for line in f] + +num = 0 +in_num = 0 + +for vg_data in vg_data_list: + anno = vg_data['grounding'] + regions = anno['regions'] + + # 每个 caption 只有一个 phrase + if len(regions) > 1: + continue + + filename = vg_data['filename'] + caption = anno['caption'] + bbox = regions[0]['bbox'] + + index = rec_data_list_name.index(filename) + if index == -1: + continue + + if not isinstance(bbox[0], list): + bbox = [bbox] + bbox = set([sum(r) for r in bbox]) + + rec_data = rec_data_list[index] + anno = rec_data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + for ins in instances: + rec_bbox = ins['bbox'] + if not isinstance(rec_bbox[0], list): + rec_bbox = [rec_bbox] + rec_bbox = set([sum(r) for r in rec_bbox]) + # 严格匹配 + if rec_bbox == bbox: + if isinstance(ins['exp'], list): + is_same = False + for exp in ins['exp']: + if exp.lower() == caption.lower(): + is_same = True + break + if not is_same: + in_num += 1 + ins['exp'].append(caption) + else: + if ins['exp'].lower() != caption.lower(): + in_num += 1 + ins['exp'] = [ins['exp'], caption] + break + num += 1 + +print(num) # 17233 +print(in_num) # 17111 + +out_path = root_path + 'flickr30k_separateGT_train_mergevg_rec.json' +with 
jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py b/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py new file mode 100644 index 00000000000..297ba3d3c0a --- /dev/null +++ b/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py @@ -0,0 +1,112 @@ +import json +import jsonlines +import re +import tqdm + +root_path = '/home/PJLAB/huanghaian/dataset/gqa/' +rec_path = root_path + 'gqa_rec.json' +vg_path = root_path + 'final_mixed_train_no_coco_vg.json' + +with open(rec_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +rec_data_list_name = [data['filename'] for data in rec_data_list] + +with open(vg_path, 'r') as f: + vg_data_list = [json.loads(line) for line in f] + + +def split_sentence(sentence): + pattern = r'([?.])' # 正则表达式模式,匹配问号 "?" 或句号 "." + sentences = re.split(pattern, sentence) + sentences = [s.strip() + p for s, p in zip(sentences[0::2], sentences[1::2])] + return sentences + + +num = 0 +in_num = 0 +new_results = [] + +for vg_data in tqdm.tqdm(vg_data_list): + filename = vg_data['filename'] + anno = vg_data['grounding'] + regions = anno['regions'] + all_phrase = [r['phrase'] for r in regions] + caption = anno['caption'] + # 按照分隔符切割为多段 + caption_list = split_sentence(caption) + + for caption in caption_list: + if caption.endswith('?'): # 问句不要了 + continue + count = 0 + for i, p in enumerate(all_phrase): + # 如果这个 phrase 是列表,则抛弃 + if isinstance(p, list): + break + # 如果这个 caption 位于多个 phrase 中,则抛弃 + if p in caption: + index = i + count += 1 + if count > 1 or count == 0: + continue + num += 1 + + # 我们只需要这个 caption 中只有一个名词短语的数据 + data = regions[index] + new_results.append({'bbox': data['bbox'], 'exp': caption, 'filename': filename, 'height': vg_data['height'], + 'width': vg_data['width']}) + +print(num) # 989203 +print(len(new_results), new_results[0]) + +new_image = 0 +for new in tqdm.tqdm(new_results): + filename = new.pop('filename') + width = new.pop('width') + height = new.pop('height') + bbox = new['bbox'] + caption = new['exp'] + if not isinstance(bbox[0], list): + bbox = [bbox] + new_bbox = set([sum(r) for r in bbox]) + + if filename not in rec_data_list_name: + new_image += 1 + rec_data_list.append({'filename': filename, 'width': width, 'height': height, + 'referring': {'instances': [{'bbox': new['bbox'], 'exp': new['exp']}]}}) + rec_data_list_name = [data['filename'] for data in rec_data_list] + else: + index = rec_data_list_name.index(filename) + rec_data = rec_data_list[index] + anno = rec_data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + for ins in instances: + rec_bbox = ins['bbox'] + if not isinstance(rec_bbox[0], list): + rec_bbox = [rec_bbox] + rec_bbox = set([sum(r) for r in rec_bbox]) + # 非常严格的匹配策略,确保不会出现错误 + if rec_bbox == new_bbox: + if isinstance(ins['exp'], list): + is_same = False + for exp in ins['exp']: + if exp.lower() == caption.lower(): + is_same = True + break + if not is_same: + in_num += 1 + ins['exp'].append(caption) + else: + if ins['exp'].lower() != caption.lower(): + in_num += 1 + ins['exp'] = [ins['exp'], caption] + break + +print(in_num) # 47266 +print(new_image) # 12052 + +out_path = root_path + 'gqa_mergevg_rec.json' +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/script/refcoco2rec.py b/projects/mm_gdino_clip/script/refcoco2rec.py new file mode 100644 index 
00000000000..00328df7b2b --- /dev/null +++ b/projects/mm_gdino_clip/script/refcoco2rec.py @@ -0,0 +1,97 @@ +import jsonlines +from pycocotools.coco import COCO +from tqdm import tqdm +import os + +ann_path = '/home/PJLAB/huanghaian/dataset/coco2014/mdetr_annotations/finetune_refcocog_train.json' + + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in anno) + + +def has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + return True + + +coco = COCO(ann_path) +ids = list(sorted(coco.imgs.keys())) +out_results = [] + +i = 0 +for img_id in tqdm(ids): + if i > 1000: + break + if isinstance(img_id, str): + ann_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=0) + else: + ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0) + annos = coco.loadAnns(ann_ids) + + if not has_valid_annotation(annos): + continue + + img_info = coco.loadImgs(img_id)[0] + file_name = img_info['file_name'] + caption = img_info['caption'] + instance_list = [] + + for anno in annos: + box = anno['bbox'] + + x1, y1, w, h = box + inter_w = max(0, min(x1 + w, int(img_info['width'])) - max(x1, 0)) + inter_h = max(0, min(y1 + h, int(img_info['height'])) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if anno['area'] <= 0 or w < 1 or h < 1: + continue + + if anno.get('iscrowd', False): + continue + bbox_xyxy = [ + x1, y1, + min(x1 + w, int(img_info['width'])), + min(y1 + h, int(img_info['height'])) + ] + instance_list.append({ + 'bbox': bbox_xyxy, + 'exp': caption, + }) + + # 相同图片名的实例合并到一起 + if i != 0 and file_name == out_results[-1]['filename']: + pre_instance_list = out_results[-1]['referring']['instances'] + for instance in instance_list: + no_find = True + for pre_instance in pre_instance_list: + if instance['bbox'] == pre_instance['bbox'] and instance['exp'] != pre_instance['exp']: + if isinstance(pre_instance['exp'], list): + pre_instance['exp'].append(instance['exp']) + else: + pre_instance['exp'] = [pre_instance['exp'], instance['exp']] + no_find = False + break + if no_find: + pre_instance_list.append(instance) + else: + out_results.append({ + 'filename': file_name, + 'height': img_info['height'], + 'width': img_info['width'], + 'referring': { + 'instances': instance_list + } + }) + i += 1 +file_name = os.path.basename(ann_path) +out_path = os.path.join(os.path.dirname(ann_path), os.path.basename(ann_path)[:-5] + '_ref.json') +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/script/utils/boxes.py b/projects/mm_gdino_clip/script/utils/boxes.py new file mode 100644 index 00000000000..e655602dca5 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/boxes.py @@ -0,0 +1,85 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved +"""Utilities to manipulate and convert boxes""" +from collections import defaultdict +from typing import Any, Dict + +import torch +from torchvision.ops.boxes import box_iou + +from .unionfind import UnionFind + + +def obj_to_box(obj: Dict[str, Any]): + """Extract the bounding box of a given object as a list""" + return [obj["x"], obj["y"], obj["w"], obj["h"]] + + +def region_to_box(obj: Dict[str, Any]): + """Extract the bounding box of a given region as a list""" + return [obj["x"], obj["y"], obj["width"], obj["height"]] + + +def get_boxes_equiv(orig_boxes, iou_threshold): + """Given a set of boxes, returns a dict containing clusters of boxes that are highly overlapping. + For optimization, return None if none of the boxes are overlapping + A high overlap is characterized by the iou_threshold + Boxes are expected as [top_left_x, top_left_y, width, height] + """ + boxes = torch.as_tensor(orig_boxes, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes[:, 2:] += boxes[:, :2] + ious = box_iou(boxes, boxes) + uf = UnionFind(len(boxes)) + for i in range(len(boxes)): + for j in range(i + 1, len(boxes)): + if ious[i][j] >= iou_threshold: + uf.unite(i, j) + if len(orig_boxes) == uf.nb_compo: + # We didn't found any opportunity for merging, returning as-is + # print("no merging") + return None, None + # print("merging") + compo2boxes = defaultdict(list) + compo2id = defaultdict(list) + + for i in range(len(boxes)): + compo2boxes[uf.find(i)].append(boxes[i]) + compo2id[uf.find(i)].append(i) + assert len(compo2boxes) == uf.nb_compo + return compo2boxes, compo2id + + +def xyxy_to_xywh(boxes: torch.Tensor): + """Converts a set of boxes in [top_left_x, top_left_y, bottom_right_x, bottom_right_y] format to + [top_left_x, top_left_y, width, height] format""" + assert boxes.shape[-1] == 4 + converted = boxes.clone() + converted[..., 2:] -= converted[..., :2] + return converted + + +def combine_boxes(orig_boxes, iou_threshold=0.7): + """Given a set of boxes, returns the average of all clusters of boxes that are highly overlapping. + A high overlap is characterized by the iou_threshold + Boxes are expected as [top_left_x, top_left_y, width, height] + """ + compo2boxes, _ = get_boxes_equiv(orig_boxes, iou_threshold) + if compo2boxes is None: + return orig_boxes + result_boxes = [] + for box_list in compo2boxes.values(): + result_boxes.append(xyxy_to_xywh(torch.stack(box_list, 0).mean(0)).tolist()) + return result_boxes + + +def box_iou_helper(b1, b2): + """returns the iou matrix between two sets of boxes + The boxes are expected in the format [top_left_x, top_left_y, w, h] + """ + boxes_r1 = torch.as_tensor(b1, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes_r1[:, 2:] += boxes_r1[:, :2] + boxes_r2 = torch.as_tensor(b2, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes_r2[:, 2:] += boxes_r2[:, :2] + return box_iou(boxes_r1, boxes_r2) diff --git a/projects/mm_gdino_clip/script/utils/dump.py b/projects/mm_gdino_clip/script/utils/dump.py new file mode 100644 index 00000000000..d4e2763a6f2 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/dump.py @@ -0,0 +1,104 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved +import json +from typing import Any, List, NamedTuple, Optional, Tuple + + +class Annotation(NamedTuple): + area: float + iscrowd: int + category_id: int + bbox: List[float] + giou_friendly_bbox: List[float] + tokens_positive: List[Tuple[int, int]] + + +class Datapoint(NamedTuple): + image_id: int + dataset_name: str + tokens_negative: List[Tuple[int, int]] + original_id: int + caption: str + annotations: List[Annotation] + + +def convert2dataset_combined( + datapoint_list_coco: List[Datapoint], + datapoint_list_vg: List[Datapoint], + imgid2imginfo_coco, + imgid2imginfo_vg, + output_path, +): + """""" + print(f"Dumping combined coco and vg images related all training examples...") + next_img_id = 0 + next_id = 0 + + annotations = [] + images = [] + + for datapoint in datapoint_list_coco: + img_id = datapoint.image_id + filename = imgid2imginfo_coco[img_id]["file_name"] + cur_img = { + "file_name": filename, + "height": imgid2imginfo_coco[img_id]["height"], + "width": imgid2imginfo_coco[img_id]["width"], + "id": next_img_id, + "original_id": img_id, + "caption": datapoint.caption, + "tokens_negative": datapoint.tokens_negative, + "data_source": "coco", + "dataset_name": datapoint.dataset_name, + } + + for anns in datapoint.annotations: + cur_obj = { + "area": float(anns.area), + "iscrowd": anns.iscrowd, + "image_id": next_img_id, + "category_id": anns.category_id, + "id": next_id, + "bbox": anns.bbox, + "tokens_positive": anns.tokens_positive, + } + next_id += 1 + annotations.append(cur_obj) + + next_img_id += 1 + images.append(cur_img) + + for datapoint in datapoint_list_vg: + img_id = datapoint.image_id + filename = f"{img_id}.jpg" + cur_img = { + "file_name": filename, + "height": imgid2imginfo_vg[img_id]["height"], + "width": imgid2imginfo_vg[img_id]["width"], + "id": next_img_id, + "original_id": img_id, + "caption": datapoint.caption, + "tokens_negative": datapoint.tokens_negative, + "data_source": "vg", + "dataset_name": datapoint.dataset_name, + } + + for anns in datapoint.annotations: + cur_obj = { + "area": float(anns.area), + "iscrowd": anns.iscrowd, + "image_id": next_img_id, + "category_id": anns.category_id, + "id": next_id, + "bbox": anns.bbox, + "tokens_positive": anns.tokens_positive, + } + next_id += 1 + annotations.append(cur_obj) + + next_img_id += 1 + images.append(cur_img) + + ds = {"info": [], "licenses": [], "images": images, "annotations": annotations, "categories": []} + with open(output_path / f"final_mixed_train.json", "w") as j_file: + json.dump(ds, j_file) + return next_img_id, next_id diff --git a/projects/mm_gdino_clip/script/utils/spans.py b/projects/mm_gdino_clip/script/utils/spans.py new file mode 100644 index 00000000000..b2839ac1e85 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/spans.py @@ -0,0 +1,235 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved +"""""" + +from typing import List, Tuple + +from .text import STOP_WORDS, nlp + + +class PreprocessError(Exception): + pass + + +def span_intersect_span(span1: Tuple[int, int], span2: Tuple[int, int]): + """Returns True if the given spans intersect""" + return (span1[0] <= span2[0] < span1[1]) or (span2[0] <= span1[0] < span2[1]) + + +def span_intersect_spanlist(span: Tuple[int, int], target_spans: List[Tuple[int, int]]): + """Returns True if the given spans intersect with any in the given list""" + for t in target_spans: + if span_intersect_span(span, t): + return True + return False + + +def spanlist_intersect_spanlist(spans: List[Tuple[int, int]], target_spans: List[Tuple[int, int]]): + """Returns True if the given spans intersect with any in the given list""" + for s in spans: + if span_intersect_spanlist(s, target_spans): + return True + return False + + +def consolidate_spans(spans: List[Tuple[int, int]], caption: str, rec=True): + """Accepts a list of spans and the the corresponding caption. + Returns a cleaned list of spans where: + - Overlapping spans are merged + - It is guaranteed that spans start and end on a word + """ + sorted_spans = sorted(spans) + cur_end = -1 + cur_beg = None + final_spans: List[Tuple[int, int]] = [] + for s in sorted_spans: + if s[0] >= cur_end: + if cur_beg is not None: + final_spans.append((cur_beg, cur_end)) + cur_beg = s[0] + cur_end = max(cur_end, s[1]) + + if cur_beg is not None: + final_spans.append((cur_beg, cur_end)) + + # Now clean the begining/end + clean_spans: List[Tuple[int, int]] = [] + for s in final_spans: + beg, end = s + end = min(end, len(caption)) + while beg < len(caption) and not caption[beg].isalnum(): + beg += 1 + while end > 0 and not caption[end - 1].isalnum(): + end -= 1 + # Try to get hyphenated words + if end < len(caption) and caption[end] == "-": + # print("trigg") + next_space = caption.find(" ", end) + if next_space == -1: + end = len(caption) + else: + end = next_space + 1 + if beg > 0 and caption[beg - 1] == "-": + prev_space = caption.rfind(" ", 0, beg) + if prev_space == -1: + beg = 0 + else: + beg = prev_space + 1 + if 0 <= beg < end <= len(caption): + clean_spans.append((beg, end)) + if rec: + return consolidate_spans(clean_spans, caption, False) + return clean_spans + + +def get_canonical_spans(orig_spans: List[List[Tuple[int, int]]], orig_caption: str, whitespace_only=False): + """This functions computes the spans after reduction of the caption to it's normalized version + For example, if the caption is "There is a man wearing sneakers" and the span is [(11,14)] ("man"), + then the normalized sentence is "man wearing sneakers" so the new span is [(0,3)] + """ + # print("orig caption", orig_caption) + # print("orig spans", [orig_caption[t[0]:t[1]] for span in orig_spans for t in span]) + new_spans = [sorted(spans) for spans in orig_spans] + caption = orig_caption.lower() + + def remove_chars(pos, amount): + for i in range(len(new_spans)): + for j in range(len(new_spans[i])): + if pos >= new_spans[i][j][1]: + continue + beg, end = new_spans[i][j] + if span_intersect_span(new_spans[i][j], (pos, pos + amount)): + # assert new_spans[i][j][0] == pos or amount == 1, "unexpected deletion from middle of span" + new_spans[i][j] = (beg, end - amount) + else: + new_spans[i][j] = (beg - amount, end - amount) + + def change_chars(old_beg, old_end, delta): + for i in range(len(new_spans)): + for j in range(len(new_spans[i])): + if old_beg >= new_spans[i][j][1]: + continue + beg, end = new_spans[i][j] + if 
span_intersect_span(new_spans[i][j], (old_beg, old_end)): + if not (new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1]): + raise PreprocessError(f"deleted spans should be contained in known span") + assert ( + new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1] + ), "deleted spans should be contained in known span" + new_spans[i][j] = (beg, end + delta) + else: + new_spans[i][j] = (beg + delta, end + delta) + + # Pre pass, removing double spaces and leading spaces + # Check for leading spaces + while caption[0] == " ": + remove_chars(0, 1) + caption = caption[1:] + cur_start = 0 + pos = caption.find(" ", cur_start) + while pos != -1: + amount = 1 + # print("remvoing", removed, pos) + remove_chars(pos, amount) + caption = caption.replace(" ", " ", 1) + pos = caption.find(" ", cur_start) + # print("after whitespace caption", caption) + # print("after whitespace spans", [caption[t[0]:t[1]] for span in new_spans for t in span]) + if whitespace_only: + return new_spans, caption + + # First pass, removing punctuation + for punct in [".", ",", "!", "?", ":"]: + pos = caption.find(punct) + while pos != -1: + remove_chars(pos, len(punct)) + caption = caption.replace(punct, "", 1) + pos = caption.find(punct) + # print("after punct caption", caption) + # print("after punct spans", [caption[t[0]:t[1]] for span in new_spans for t in span]) + + # parsing needs to happen before stop words removal + all_tokens = nlp(caption) + tokens = [] + + # Second pass, removing stop words + ## Remove from tokenization + for t in all_tokens: + if str(t) not in STOP_WORDS: + tokens.append(t) + ## Remove from actual sentence + for stop in STOP_WORDS: + cur_start = 0 + pos = caption.find(stop, cur_start) + while pos != -1: + # Check that we are matching a full word + if (pos == 0 or caption[pos - 1] == " ") and ( + pos + len(stop) == len(caption) or caption[pos + len(stop)] == " " + ): + removed = stop + spaces = 0 + if pos + len(stop) < len(caption) and caption[pos + len(stop)] == " ": + removed += " " + spaces += 1 + if pos > 0 and caption[pos - 1] == " ": + removed = " " + removed + spaces += 1 + if spaces == 0: + raise PreprocessError( + f"No spaces found in '{caption}', position={pos}, stopword={stop}, len={len(stop)}" + ) + assert spaces > 0 + replaced = "" if spaces == 1 else " " + amount = len(removed) - len(replaced) + # print("remvoing", removed, pos) + remove_chars(pos, amount) + caption = caption.replace(removed, replaced, 1) + # print("cur caption", caption) + # print("cur spans", [caption[t[0]:t[1]] for span in new_spans for t in span if t[0] < t[1]]) + else: + cur_start += 1 + pos = caption.find(stop, cur_start) + + # print("final caption", caption) + # print("final spans", [caption[t[0]:t[1]] for span in new_spans for t in span if t[0] < t[1]]) + + # Third pass, lemmatization + final_caption = [] + if len(tokens) != len(caption.strip().split(" ")): + raise PreprocessError( + f"''{tokens}'', len={len(tokens)}, {caption.strip().split(' ')}, len={len(caption.strip().split(' '))}" + ) + + # tokens = nlp(caption) + cur_beg = 0 + for i, w in enumerate(caption.strip().split(" ")): + if tokens[i].lemma_[0] != "-": + # print(w, "lemmatized to", tokens[i].lemma_) + final_caption.append(tokens[i].lemma_) + change_chars(cur_beg, cur_beg + len(w), len(tokens[i].lemma_) - len(w)) + else: + # print(w, "skipped lemmatized to", tokens[i].lemma_) + final_caption.append(w) + cur_beg += 1 + len(final_caption[-1]) + # print("cur_beg", cur_beg) + # print("cur spans", [caption[t[0]:t[1]] for span in 
new_spans for t in span if t[0] < t[1]], new_spans) + + clean_caption = " ".join(final_caption) + # Cleanup empty spans + clean_spans = [] + for spans in new_spans: + cur = [] + for s in spans: + if 0 <= s[0] < s[1]: + cur.append(s) + clean_spans.append(cur) + + # print("clean caption", clean_caption) + # print("clean spans", [clean_caption[t[0]:t[1]] for span in clean_spans for t in span]) + return clean_spans, clean_caption + + +def shift_spans(spans: List[Tuple[int, int]], offset: int) -> List[Tuple[int, int]]: + final_spans = [] + for beg, end in spans: + final_spans.append((beg + offset, end + offset)) + return final_spans diff --git a/projects/mm_gdino_clip/script/utils/text.py b/projects/mm_gdino_clip/script/utils/text.py new file mode 100644 index 00000000000..b094bf4a239 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/text.py @@ -0,0 +1,135 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""Provides various text related util function""" +import re +from typing import List, Tuple + +import nltk +import spacy + +nlp = spacy.load("en_core_web_sm") + +nltk.download("stopwords") +from nltk.corpus import stopwords + +STOP_WORDS = set(stopwords.words("english")) - set(["above", "below", "between", "further", "he", "she", "they"]) + + +def get_noun_phrase(root): + queue = [root] + all_toks = [root] + while len(queue) > 0: + curr = queue.pop() + if curr.tag_ in ["NN", "NNS", "NNP", "NNPS"]: + queue += curr.lefts + all_toks += curr.lefts + return all_toks + + +def get_root_and_nouns(text: str, lazy=True) -> Tuple[str, str, List[Tuple[int, int]], List[Tuple[int, int]]]: + """Given a sentence, returns a tuple with the following items: + -- root text:str : the text associated with the root of the sentence + -- negative_text:str: all the text that shouldn't be positively matched with a box other than the main one + -- root_span: List[Tuple[int, int]] spans covering the root expressions, returned as a list of (beg, end) character spans + -- negative_span: List[Tuple[int, int]] spans covering the negative expressions, returned as a list of (beg, end) character spans + + If lazy is False, then we try a bit harder to find the precise root of the sentence + """ + sents = nlp(text) + negative_text = [] + + if len([x for x in sents if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]]) <= 1: + if lazy or len([x for x in sents if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]]) == 0: + return text, " ", [(0, len(text))], [(0, len(text))] + + root = None + for token in sents: + if token.dep_ == "ROOT": + if token.tag_ == "UH": + continue + root = token + break + + if root is None: + return text, "", [(0, len(text))], [(0, len(text))] + + if ( + len([c for c in root.children if c.tag_ in ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"] and c.dep_ == "compound"]) + > 0 + ): + return text, "", [(0, len(text))], [(0, len(text))] + + all_toks = [] + if root.tag_ in ["NN", "NNS", "NNP", "NNPS"]: + all_toks = get_noun_phrase(root) + root_text = " ".join([x.text for x in all_toks]) + root_spans = [(x.idx, x.idx + len(x.text)) for x in all_toks] + else: + root = [x for x in root.children if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]] + if len(root) < 1: + return text, "", [(0, len(text))], [(0, len(text))] + else: + root = root[0] + all_toks = list(root.lefts) + [root] + root_text = " ".join([x.text for x in all_toks]) + root_spans = [(x.idx, x.idx + len(x.text)) for x in all_toks] + + everything_else = set() + for token in sents: + if token.tag_ 
in ["NN", "NNS", "NNP", "NNPS"] and token.dep_ not in ["ROOT"] and token not in all_toks: + everything_else = everything_else.union(set(get_noun_phrase(token))) + + negative_tokens = set(sents) - set(everything_else) + negative_text = " ".join([x.text for x in negative_tokens]) + negative_spans = [(x.idx, x.idx + len(x.text)) for x in negative_tokens] + + return root_text, negative_text, root_spans, negative_spans + + +def normalize_sentence(sentence): + """Returns a list of non stopwords for the sentence, obtained after cleaning ponctuation and spaces""" + + sent = sentence.lower() + sent = remove_punctuation(sentence.lower()) + sent = normalize_whitespace(sent) + tokens = nlp(sent) + return " ".join( + [ + tokens[i].lemma_ if tokens[i].lemma_[0] != "-" else w + for i, w in enumerate(sent.split(" ")) + if w not in STOP_WORDS + ] + ) + + +def remove_punctuation(text): + """ + This function removes all ponctuation. + """ + corrected = str(text) + corrected = re.sub(r"([!?,;.:-])", r"", corrected) + return corrected + + +def simplify_punctuation(text): + """ + This function simplifies doubled or more complex punctuation. The exception is '...'. + """ + corrected = str(text) + corrected = re.sub(r"([!?,;:-])\1+", r"\1", corrected) + corrected = re.sub(r"\.{2,}", r"...", corrected) + corrected = re.sub(r"\s?-\s?", r"-", corrected) + return corrected + + +def normalize_whitespace(text): + """ + This function normalizes whitespaces, removing duplicates and converting all to standard spaces + """ + corrected = str(text) + corrected = re.sub(r"//t", r"\t", corrected) + corrected = re.sub(r"\n", r" ", corrected) + corrected = re.sub(r"_", r" ", corrected) + corrected = re.sub(r"\r", r" ", corrected) + corrected = re.sub(r"\t", r" ", corrected) + corrected = re.sub(r"\s+", r" ", corrected) + return corrected.strip(" ") diff --git a/projects/mm_gdino_clip/script/utils/unionfind.py b/projects/mm_gdino_clip/script/utils/unionfind.py new file mode 100644 index 00000000000..9617a7b20b5 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/unionfind.py @@ -0,0 +1,31 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""Simple union find structure implementation""" + + +class UnionFind: + """Optimized union find structure""" + + def __init__(self, n): + """Initialize a union find with n components""" + self.compo = list(range(n)) + self.weight = [1] * n + self.nb_compo = n + + def get_nb_compo(self): + return self.nb_compo + + def find(self, x): + if self.compo[x] == x: + return x + self.compo[x] = self.find(self.compo[x]) + return self.compo[x] + + def unite(self, a, b): + fa = self.find(a) + fb = self.find(b) + if fa != fb: + self.nb_compo -= 1 + if self.weight[fb] > self.weight[fa]: + fa, fb = fb, fa + self.compo[fb] = fa + self.weight[fa] += self.weight[fb] diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py new file mode 100644 index 00000000000..99d50b00cf5 --- /dev/null +++ b/projects/mm_gdino_clip/text_transformers.py @@ -0,0 +1,426 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json
+
+from mmcv.transforms import BaseTransform
+
+from mmdet.registry import TRANSFORMS
+from mmdet.structures.bbox import BaseBoxes
+
+try:
+    from transformers import AutoTokenizer
+    from transformers import BertModel as HFBertModel
+except ImportError:
+    AutoTokenizer = None
+    HFBertModel = None
+
+import random
+import re
+
+import numpy as np
+
+
+def clean_string(phrase):
+    # return re.sub(r"([.,'!?\"()*#:;])", "", phrase.lower()).replace("-", " ").replace("/", " ")
+
+    phrase = re.sub(r"([.,'!?\"()*#:;])", "", phrase.lower()).replace("-", " ").replace("/", " ")
+    phrase = phrase.strip("\n").strip("\r").strip().lstrip(" ").rstrip(" ")
+    phrase = re.sub(" +", " ", phrase)
+
+    replacements = {
+        "½": "half",
+        "—": "-",
+        "™": "",
+        "¢": "cent",
+        "ç": "c",
+        "û": "u",
+        "é": "e",
+        "°": " degree",
+        "è": "e",
+        "…": "",
+    }
+    for k, v in replacements.items():
+        phrase = phrase.replace(k, v)
+
+    return phrase
+
+
+def clean_name(name):
+    name = re.sub(r'\(.*\)', '', name)
+    name = re.sub(r'_', ' ', name)
+    name = re.sub(r'  ', ' ', name)
+    name = name.lower()
+    return name
+
+
+def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer,
+                                max_tokens):
+    # Check if we have too many positive labels
+    # generate a caption by appending the positive labels
+    positive_label_list = np.unique(gt_labels).tolist()
+    # random shuffle so we can sample different annotations
+    # at different epochs
+    random.shuffle(positive_label_list)
+
+    kept_lables = []
+    length = 0
+
+    for index, label in enumerate(positive_label_list):
+
+        label_text = clean_name(text[str(label)]) + '. '
+
+        tokenized = tokenizer.tokenize(label_text)
+
+        length += len(tokenized)
+
+        if length > max_tokens:
+            break
+        else:
+            kept_lables.append(label)
+
+    keep_box_index = []
+    keep_gt_labels = []
+    for i in range(len(gt_labels)):
+        if gt_labels[i] in kept_lables:
+            keep_box_index.append(i)
+            keep_gt_labels.append(gt_labels[i])
+
+    return gt_bboxes[keep_box_index], np.array(
+        keep_gt_labels, dtype=np.int64), length
+
+
+def generate_senetence_given_labels(positive_label_list, negative_label_list,
+                                    text):
+    label_to_positions = {}
+
+    label_list = negative_label_list + positive_label_list
+
+    random.shuffle(label_list)
+
+    pheso_caption = ''
+
+    label_remap_dict = {}
+    for index, label in enumerate(label_list):
+
+        start_index = len(pheso_caption)
+
+        pheso_caption += clean_name(text[str(label)])
+
+        end_index = len(pheso_caption)
+
+        if label in positive_label_list:
+            label_to_positions[index] = [[start_index, end_index]]
+            label_remap_dict[int(label)] = index
+
+        # if index != len(label_list) - 1:
+        #     pheso_caption += '. '
+        pheso_caption += '. '
+
+    return label_to_positions, pheso_caption, label_remap_dict
+
+
+@TRANSFORMS.register_module()
+class RandomSamplingNegPosV2(BaseTransform):
+
+    def __init__(self,
+                 tokenizer_name,
+                 num_sample_negative=85,
+                 max_tokens=256,
+                 full_sampling_prob=0.5,
+                 label_map_file=None):
+        if AutoTokenizer is None:
+            raise RuntimeError(
+                'transformers is not installed, please install it by: '
+                'pip install transformers.')
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.num_sample_negative = num_sample_negative
+        self.full_sampling_prob = full_sampling_prob
+        self.max_tokens = max_tokens
+        self.label_map = None
+        if label_map_file:
+            with open(label_map_file, 'r') as file:
+                self.label_map = json.load(file)
+
+    def transform(self, results: dict) -> dict:
+        dataset_mode = results['dataset_mode']
+        if dataset_mode == 'OD':
+            results['dataset_mode'] = 'VG'
+            return self.od_aug(results)
+        elif dataset_mode == 'VG':
+            return self.vg_aug(results)
+        else:
+            return self.rec_aug(results)
+
+    def rec_aug(self, results):
+        gt_bboxes = results['gt_bboxes']
+        if isinstance(gt_bboxes, BaseBoxes):
+            gt_bboxes = gt_bboxes.tensor
+        gt_labels = results['gt_bboxes_labels']
+
+        if 'text' not in results:
+            assert self.label_map is not None
+            text = self.label_map
+        else:
+            text = results['text']
+
+        if 'image_to_exp' in results:  # REC
+            keys = list(results['image_to_exp'].keys())
+            positive_label_list = np.unique(gt_labels).tolist()
+
+            # 85 is a bit large and consumes quite a lot of GPU memory; consider lowering it
+            full_negative = self.num_sample_negative
+
+            if full_negative > len(keys):
+                full_negative = len(keys)
+
+            outer_prob = random.random()
+
+            if outer_prob < self.full_sampling_prob:
+                # c. probability_full: add both all positive and all negatives
+                num_negatives = full_negative
+            else:
+                if random.random() < 1.0:
+                    num_negatives = np.random.choice(max(1, full_negative)) + 1
+                else:
+                    num_negatives = full_negative
+
+            negative_label_list = set()
+            if num_negatives != -1:
+                if num_negatives > len(keys):
+                    num_negatives = len(keys)
+
+                for i in np.random.choice(
+                        keys, size=num_negatives, replace=False):
+                    if i not in results['img_path']:
+                        others_exp = results['image_to_exp'][i]
+                        if len(others_exp) == 0:
+                            continue
+                        if isinstance(others_exp, list):
+                            others_exp = random.choice(others_exp)
+                        if isinstance(others_exp, list) and len(others_exp) > 0:
+                            others_exp = random.choice(others_exp)
+                        negative_label_list.add(others_exp)
+
+            random.shuffle(positive_label_list)
+
+            negative_label_list = list(negative_label_list)
+
+            text = results['text']  # dict
+            _pos_texts = [text[p] for p in positive_label_list]
+            _flat_pos_texts = []
+            for p in _pos_texts:
+                if isinstance(p, list):
+                    p = [_p.lower() for _p in p]
+                    _flat_pos_texts.extend(p)
+                else:
+                    _flat_pos_texts.append(p.lower())
+            _negative_label_list = []
+            for n in negative_label_list:
+                if n.lower() not in _flat_pos_texts:
+                    _negative_label_list.append(n)
+            negative_label_list = _negative_label_list
+            if len(negative_label_list) > 0:
+                random.shuffle(negative_label_list)
+
+            label_list = positive_label_list + negative_label_list
+
+            random.shuffle(label_list)
+
+            label_remap_dict = {}
+            new_text = []
+            for index, label in enumerate(label_list):
+                if label in positive_label_list:
+                    label_remap_dict[int(label)] = index
+                    _text = text[label]
+                    if isinstance(_text, list):
+                        _text = random.choice(_text)
+                    new_text.append(_text)
+                else:
+                    new_text.append(label)
+
+            if len(gt_labels) > 0:
+                gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels)
+
+            results['gt_bboxes'] = gt_bboxes
+            results['gt_bboxes_labels'] = gt_labels
+
+            new_text = [clean_string(phrase) for phrase in new_text]
+
+            if results.get('flip', False):
+                new_text = [
+                    phrase.replace("left", "@").replace("right", "left").replace("@", "right")
+                    for phrase in new_text
+                ]
+            results['text'] = new_text
+        else:  # OD
+            valid_negative_indexes = list(text.keys())
+
+            positive_label_list = np.unique(gt_labels).tolist()
+            full_negative = self.num_sample_negative
+
+            if full_negative > len(valid_negative_indexes):
+                full_negative = len(valid_negative_indexes)
+
+            outer_prob = random.random()
+
+            if outer_prob < self.full_sampling_prob:
+                # c. probability_full: add both all positive and all negatives
+                num_negatives = full_negative
+            else:
+                if random.random() < 1.0:
+                    num_negatives = np.random.choice(max(1, full_negative)) + 1
+                else:
+                    num_negatives = full_negative
+
+            # Keep some negatives
+            negative_label_list = set()
+            if num_negatives != -1:
+                if num_negatives > len(valid_negative_indexes):
+                    num_negatives = len(valid_negative_indexes)
+
+                for i in np.random.choice(
+                        valid_negative_indexes, size=num_negatives, replace=False):
+                    if int(i) not in positive_label_list:
+                        negative_label_list.add(i)
+
+            random.shuffle(positive_label_list)
+
+            negative_label_list = list(negative_label_list)
+            random.shuffle(negative_label_list)
+
+            label_list = positive_label_list + negative_label_list
+            random.shuffle(label_list)
+
+            label_remap_dict = {}
+            for index, label in enumerate(label_list):
+                if label in positive_label_list:
+                    label_remap_dict[int(label)] = index
+
+            if len(gt_labels) > 0:
+                gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels)
+
+            results['gt_bboxes'] = gt_bboxes
+            results['gt_bboxes_labels'] = gt_labels
+            results['text'] = [text[str(l)] for l in label_list]
+
+        results['dataset_mode'] = 'REC'
+        if 'tokens_positive' in results:
+            del results['tokens_positive']
+
+        return results
+
+    def vg_aug(self, results):
+        gt_bboxes = results['gt_bboxes']
+        if isinstance(gt_bboxes, BaseBoxes):
+            gt_bboxes = gt_bboxes.tensor
+        gt_labels = results['gt_bboxes_labels']
+        text = results['text'].lower().strip()
+        if not text.endswith('.'):
+            text = text + '. '
+
+        phrases = results['phrases']
+        # TODO: add neg
+        positive_label_list = np.unique(gt_labels).tolist()
+        label_to_positions = {}
+        for label in positive_label_list:
+            label_to_positions[label] = phrases[label]['tokens_positive']
+
+        results['gt_bboxes'] = gt_bboxes
+        results['gt_bboxes_labels'] = gt_labels
+
+        results['text'] = text
+        results['tokens_positive'] = label_to_positions
+        return results
+
+    def od_aug(self, results):
+        gt_bboxes = results['gt_bboxes']
+        if isinstance(gt_bboxes, BaseBoxes):
+            gt_bboxes = gt_bboxes.tensor
+        gt_labels = results['gt_bboxes_labels']
+
+        if 'text' not in results:
+            assert self.label_map is not None
+            text = self.label_map
+        else:
+            text = results['text']
+
+        original_box_num = len(gt_labels)
+        # If the category name is in the format of 'a/b' (in object365),
+        # we randomly select one of them.
+        for key, value in text.items():
+            if '/' in value:
+                text[key] = random.choice(value.split('/')).strip()
+
+        gt_bboxes, gt_labels, positive_caption_length = \
+            check_for_positive_overflow(gt_bboxes, gt_labels,
+                                        text, self.tokenizer, self.max_tokens)
+
+        if len(gt_bboxes) < original_box_num:
+            print('WARNING: removed {} boxes due to positive caption overflow'.
+                  format(original_box_num - len(gt_bboxes)))
+
+        valid_negative_indexes = list(text.keys())
+
+        positive_label_list = np.unique(gt_labels).tolist()
+        full_negative = self.num_sample_negative
+
+        if full_negative > len(valid_negative_indexes):
+            full_negative = len(valid_negative_indexes)
+
+        outer_prob = random.random()
+
+        if outer_prob < self.full_sampling_prob:
+            # c. probability_full: add both all positive and all negatives
+            num_negatives = full_negative
+        else:
+            if random.random() < 1.0:
+                num_negatives = np.random.choice(max(1, full_negative)) + 1
+            else:
+                num_negatives = full_negative
+
+        # Keep some negatives
+        negative_label_list = set()
+        if num_negatives != -1:
+            if num_negatives > len(valid_negative_indexes):
+                num_negatives = len(valid_negative_indexes)
+
+            for i in np.random.choice(
+                    valid_negative_indexes, size=num_negatives, replace=False):
+                if i not in positive_label_list:
+                    negative_label_list.add(i)
+
+        random.shuffle(positive_label_list)
+
+        negative_label_list = list(negative_label_list)
+        random.shuffle(negative_label_list)
+
+        negative_max_length = self.max_tokens - positive_caption_length
+        screened_negative_label_list = []
+
+        for negative_label in negative_label_list:
+            label_text = clean_name(text[str(negative_label)]) + '. '
+
+            tokenized = self.tokenizer.tokenize(label_text)
+
+            negative_max_length -= len(tokenized)
+
+            if negative_max_length > 0:
+                screened_negative_label_list.append(negative_label)
+            else:
+                break
+        negative_label_list = screened_negative_label_list
+        label_to_positions, pheso_caption, label_remap_dict = \
+            generate_senetence_given_labels(positive_label_list,
+                                            negative_label_list, text)
+
+        # label remap
+        if len(gt_labels) > 0:
+            gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels)
+
+        results['gt_bboxes'] = gt_bboxes
+        results['gt_bboxes_labels'] = gt_labels
+
+        results['text'] = pheso_caption
+        results['tokens_positive'] = label_to_positions
+
+        return results
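
For reference, here is a minimal sketch (not part of the patch; plain Python with a hypothetical toy label map, no mmdet or transformers imports) of the grounding-caption convention that the OD path builds via generate_senetence_given_labels: labels are shuffled into a single caption, each positive label records character-level token-positive spans, and original label ids are remapped to their index in the shuffled caption order.

    import random

    label_map = {'0': 'person', '1': 'dog', '2': 'car'}  # toy str(label) -> name map (hypothetical)
    positives = [0]                                      # labels actually present in the image
    negatives = [1, 2]                                   # sampled negative labels

    label_list = negatives + positives
    random.shuffle(label_list)

    caption = ''
    token_positive = {}  # caption-order index -> [[char_start, char_end]]
    label_remap = {}     # original label id -> caption-order index
    for index, label in enumerate(label_list):
        start = len(caption)
        caption += label_map[str(label)]
        span = [start, len(caption)]
        if label in positives:
            token_positive[index] = [span]
            label_remap[int(label)] = index
        caption += '. '

    # e.g. caption == 'dog. person. car. ' (order is shuffled) and the span stored
    # for label 0 slices 'person' back out of the caption:
    begin, end = token_positive[label_remap[0]][0]
    assert caption[begin:end] == 'person'

In the transform above, these spans end up in results['tokens_positive'] and the remapped ids in results['gt_bboxes_labels'], which is the format the grounding head consumes downstream.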