diff --git a/src/scilpy/cli/scil_labels_from_mask.py b/src/scilpy/cli/scil_labels_from_mask.py index 6991b254f..bb172d786 100755 --- a/src/scilpy/cli/scil_labels_from_mask.py +++ b/src/scilpy/cli/scil_labels_from_mask.py @@ -38,6 +38,9 @@ def _build_arg_parser(): p.add_argument('--min_volume', type=float, default=7, help='Minimum volume in mm3 [%(default)s],' 'Useful for lesions.') + p.add_argument('--min_distance', type=int, default=None, + help='Minimum distance in voxels [%(default)s],' + 'Useful for confluent lesions.') add_verbose_arg(p) add_overwrite_arg(p) @@ -57,11 +60,13 @@ def main(): mask_data = get_data_as_mask(mask_img) voxel_volume = np.prod(np.diag(mask_img.affine)[:3]) min_voxel_count = args.min_volume // voxel_volume + min_distance = args.min_distance # Get labels from mask label_map = get_labels_from_mask( mask_data, args.labels, args.background_label, - min_voxel_count=min_voxel_count) + min_voxel_count=min_voxel_count, + min_distance=min_distance) # Save result out_img = nib.Nifti1Image(label_map.astype(np.uint16), mask_img.affine) nib.save(out_img, args.out_labels) diff --git a/src/scilpy/image/labels.py b/src/scilpy/image/labels.py index 60c09ad2c..e8d1292c2 100644 --- a/src/scilpy/image/labels.py +++ b/src/scilpy/image/labels.py @@ -8,6 +8,8 @@ import numpy as np from scipy import ndimage as ndi from scipy.spatial import cKDTree +from skimage.segmentation import watershed +from skimage.feature import peak_local_max from scilpy.tractanalysis.reproducibility_measures import compute_bundle_adjacency_voxel @@ -72,7 +74,7 @@ def get_binary_mask_from_labels(atlas, label_list): def get_labels_from_mask(mask_data, labels=None, background_label=0, - min_voxel_count=0): + min_voxel_count=0, min_distance=0): """ Get labels from a binary mask which contains multiple blobs. Each blob will be assigned a label, by default starting from 1. Background will @@ -90,6 +92,11 @@ def get_labels_from_mask(mask_data, labels=None, background_label=0, min_voxel_count: int, optional Minimum number of voxels for a blob to be considered. Blobs with fewer voxels will be ignored. + min_distance : int, optional + The minimum voxels separating the peaks of two + neighboring blobs. If a confluent blob is detected, it will be + split into multiple labels based on this distance. If None, no + splitting is performed. Returns ------- @@ -97,7 +104,22 @@ def get_labels_from_mask(mask_data, labels=None, background_label=0, The labels. """ # Get the number of structures and assign labels to each blob - label_map, nb_structures = ndi.label(mask_data) + if min_distance: + distance = ndi.distance_transform_edt(mask_data) + coords = peak_local_max( + distance, + min_distance=min_distance, + labels=mask_data, + threshold_abs=0 + ) + mask = np.zeros(distance.shape, dtype=bool) + mask[tuple(coords.T)] = True + markers, _ = ndi.label(mask) + label_map = watershed(-distance, markers, mask=mask_data) + nb_structures = np.max(label_map) + else: + label_map, nb_structures = ndi.label(mask_data) + if min_voxel_count: new_count = 0 for label in range(1, nb_structures + 1):