diff --git a/mmdet2trt/apis/inference.py b/mmdet2trt/apis/inference.py index 6ac84aa..f3db70b 100644 --- a/mmdet2trt/apis/inference.py +++ b/mmdet2trt/apis/inference.py @@ -77,9 +77,11 @@ def get_classes_from_config(model_cfg): data_cfg = model_cfg.data def get_module_from_train_val(train_val_cfg): - while train_val_cfg.type == 'RepeatDataset' or \ - train_val_cfg.type == 'MultiImageMixDataset': - train_val_cfg = train_val_cfg.dataset + while train_val_cfg.type in ('RepeatDataset', + 'MultiImageMixDataset', + 'ConcatDataset'): + train_val_cfg = train_val_cfg.datasets[0] if hasattr( + train_val_cfg, 'datasets') else train_val_cfg.dataset return module_dict[train_val_cfg.type] data_cfg_type_list = ['train', 'val', 'test'] @@ -182,19 +184,20 @@ def forward(self, img, img_metas, *args, **kwargs): masks = masks.detach().cpu().numpy() num_classes = len(self.CLASSES) class_agnostic = True - segms_results = [] - for i in range(batch_size): - segms_results = FCNMaskHead.get_seg_masks( - Addict( - num_classes=num_classes, - class_agnostic=class_agnostic), - masks, - old_dets, - labels, - rcnn_test_cfg=Addict(mask_thr_binary=0.5), - ori_shape=img_metas[i]['ori_shape'], - scale_factor=scale_factor, - rescale=rescale) + segms_results = [[] for _ in range(num_classes)] + if num_dets>0: + for i in range(batch_size): + segms_results = FCNMaskHead.get_seg_masks( + Addict( + num_classes=num_classes, + class_agnostic=class_agnostic), + masks, + old_dets, + labels, + rcnn_test_cfg=Addict(mask_thr_binary=0.5), + ori_shape=img_metas[i]['ori_shape'], + scale_factor=scale_factor, + rescale=rescale) results.append((dets_results, segms_results)) else: results.append(dets_results)