def mask_to_image(mask: np.ndarray, mapping=None):
    """Convert a predicted mask into a PIL image.

    For multi-class semantic segmentation, pass the same colour mapping that
    was used during training, for example::

        mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2}

    Args:
        mask: Either a 2-D array (binary mask, scaled by 255 on output) or a
            3-D array whose first axis holds one score per class. Tensor-like
            objects exposing ``detach()`` are converted to NumPy first.
        mapping: Optional dict from colour tuple to class index. When it is
            ``None`` or empty, a multi-class mask is rendered as a grayscale
            image with class indices spread over the 0-255 range.

    Returns:
        A ``PIL.Image.Image``: single-channel for 2-D masks or when no mapping
        is given, otherwise a 3-channel image painted with the mapped colours.
    """
    # The original code called torch.argmax here, which fails on the NumPy
    # arrays this function is annotated with; normalise to NumPy up front so
    # both arrays and tensors are accepted.
    if hasattr(mask, 'detach'):
        mask = mask.detach().cpu().numpy()

    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))

    # Per-pixel index of the highest-scoring class, shape (h, w).
    class_idx = np.argmax(mask, axis=0)

    if not mapping:
        # Without a colour mapping, keep the grayscale rendering instead of
        # silently producing an all-black image.
        return Image.fromarray((class_idx * 255 / mask.shape[0]).astype(np.uint8))

    # Paint each class with the colour it was encoded from during training.
    # Pixels whose class has no colour in the mapping stay black.
    painted = np.zeros((class_idx.shape[0], class_idx.shape[1], 3), dtype=np.uint8)
    for color, index in mapping.items():
        painted[class_idx == index] = color
    return Image.fromarray(painted)
@classmethod
def mask_to_class(cls, mask: np.ndarray, mapping):
    """Translate a colour-coded mask into a 2-D map of class indices.

    Args:
        mask: Channel-first array of shape (channels, h, w), or a 2-D array
            of shape (h, w) which is treated as a single channel.
        mapping: Dict from colour to class index, e.g.
            ``{(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2}``. A key may
            be a scalar for single-channel masks. When the mapping is empty
            or ``None`` the mask is returned unchanged, so masks that are not
            colour-coded keep their labels instead of being zeroed out.

    Returns:
        An (h, w) array holding the class index of every pixel. Pixels whose
        colour is not in ``mapping`` are left at 0.

    Raises:
        ValueError: If a colour key does not have one value per mask channel.
    """
    if not mapping:
        return mask

    if mask.ndim == 2:
        # Give grayscale masks an explicit channel axis so the comparison
        # below is uniform.
        mask = mask[np.newaxis, ...]

    n_channels = mask.shape[0]
    class_mask = np.zeros(mask.shape[1:])
    for color, class_index in mapping.items():
        # Shape the key as (channels, 1, 1) so it broadcasts over h and w.
        key = np.atleast_1d(np.asarray(color)).reshape(-1, 1, 1)
        if key.shape[0] != n_channels:
            raise ValueError(
                f'Mapping key {color!r} has {key.shape[0]} value(s) '
                f'but the mask has {n_channels} channel(s)'
            )
        # A pixel belongs to the class only if every channel matches; this
        # avoids hard-coding the channel count to 3.
        class_mask[(mask == key).all(axis=0)] = class_index
    return class_mask
class CarvanaDataset(BasicDataset):
    """BasicDataset variant whose mask files carry the ``_mask`` suffix."""

    def __init__(self, images_dir, masks_dir, scale=1, mapping=None):
        """Create the dataset.

        Args:
            images_dir: Directory containing the input images.
            masks_dir: Directory containing the mask files.
            scale: Scale value forwarded to ``BasicDataset``.
            mapping: Optional dict from colour tuple to class index for
                multi-class masks. ``None`` is forwarded as an empty dict.
        """
        # Forward the mapping to the base class; previously it was accepted
        # here but silently dropped. A fresh dict replaces the shared mutable
        # default argument.
        super().__init__(
            images_dir,
            masks_dir,
            scale,
            mask_suffix='_mask',
            mapping={} if mapping is None else mapping,
        )