diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py index ba1290b91..a37473635 100644 --- a/maskrcnn_benchmark/structures/segmentation_mask.py +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -195,7 +195,9 @@ def __getitem__(self, item): else: # advanced indexing on a single dimension selected_polygons = [] - if isinstance(item, torch.Tensor) and item.dtype == torch.uint8: + if isinstance(item, torch.Tensor) and ( + item.dtype == torch.uint8 or item.dtype == torch.bool + ): item = item.nonzero() item = item.squeeze(1) if item.numel() > 0 else item item = item.tolist()