diff --git a/.gitignore b/.gitignore index ad1a91c78..3ad5a6417 100644 --- a/.gitignore +++ b/.gitignore @@ -75,3 +75,5 @@ venv/ # Hidden folder .hidden/ + +.DS_Store \ No newline at end of file diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index c97b7b97a..5ae2ea804 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -167,6 +167,10 @@ def _build_arg_parser(): track_g.add_argument('--rap_save_entry_exit', default=None, help='Save RAP entry/exit coordinates as a binary mask.\n' 'Provide output filename (.nii.gz).') + track_g.add_argument('--rap_labels', default=None, + help='Region-Adaptive Propagation label volume (.nii.gz) .\n' + 'Voxel values are integer labels (0=background, 1..N=regions) .\n' + 'Used with --rap_method switch to select policies per label.') m_g = p.add_argument_group('Memory options') add_processes_arg(m_g) @@ -195,15 +199,16 @@ def main(): verify_compression_th(args.compress_th) verify_seed_options(parser, args) - if args.rap_mask is not None and args.rap_method == "None": + if (args.rap_mask is not None or args.rap_labels is not None) and args.rap_method == "None": parser.error('No RAP method selected.') - if not args.rap_method == "None" and args.rap_mask is None: - parser.error('No RAP mask selected.') + if args.rap_method == 'continue' and args.rap_mask is None: + parser.error('RAP method "continue" requires --rap_mask.') + if args.rap_method == 'switch' and (args.rap_mask is None and args.rap_labels is None): + parser.error('RAP method "switch" requires --rap_mask or --rap_labels.') if args.rap_method == 'switch' and args.rap_params is None: parser.error('RAP method "switch" requires --rap_params to be specified.') if args.rap_params is not None and args.rap_method != 'switch': parser.error('--rap_params can only be used with --rap_method switch.') - tracts_format = detect_format(args.out_tractogram) if tracts_format is not TrkFile: logging.warning("You have selected option --save_seeds but you are " @@ -292,21 +297,32 @@ def main(): space=our_space, origin=our_origin, is_legacy=is_legacy) # ------- INSTANTIATING RAP OBJECT ------- + rap_mask = None + rap_labels = None if args.rap_mask: logging.info("Loading RAP mask.") rap_img = nib.load(args.rap_mask) rap_data = rap_img.get_fdata(caching='unchanged', dtype=float) rap_res = rap_img.header.get_zooms()[:3] rap_mask = DataVolume(rap_data, rap_res, args.mask_interp) - else: - rap_mask = None + + if args.rap_labels: + logging.info("Loading RAP labels.") + rap_label_img = nib.load(args.rap_labels) + + rap_label_data = np.asanyarray(rap_label_img.dataobj).astype(np.uint8) + rap_label_res = rap_label_img.header.get_zooms()[:3] + rap_labels = DataVolume(rap_label_data, rap_label_res, 'nearest') + rap_mask_from_labels = (rap_label_data > 0).astype(np.float32) + rap_mask = DataVolume(rap_mask_from_labels, rap_label_res, 'nearest') if args.rap_method == "continue": rap = RAPContinue(rap_mask, propagator, max_nbr_pts, step_size=vox_step_size) elif args.rap_method == "switch": rap = RAPSwitch(rap_mask, propagator, max_nbr_pts, - rap_params_file=args.rap_params) + rap_params_file=args.rap_params, + rap_labels=rap_labels) else: rap = None @@ -352,7 +368,8 @@ def main(): sft = StatefulTractogram(streamlines, mask_img, space=our_space, origin=our_origin, data_per_streamline=data_per_streamline) - save_tractogram(sft, args.out_tractogram) + save_tractogram(sft, args.out_tractogram, bbox_valid_check=False + ) if __name__ == "__main__": diff --git a/src/scilpy/tracking/rap.py b/src/scilpy/tracking/rap.py index 5fa6733b9..53769c82e 100644 --- a/src/scilpy/tracking/rap.py +++ b/src/scilpy/tracking/rap.py @@ -3,7 +3,9 @@ import json import logging import numpy as np +from copy import deepcopy from dipy.core.geometry import math +from scilpy.tracking.propagator import get_sphere_neighbours class RAP: @@ -65,7 +67,7 @@ def rap_multistep_propagate(self, line, prev_direction): class RAPSwitch(RAP): """RAP class that switches tracking parameters when inside the RAP mask.""" - def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file): + def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file, rap_labels=None): """ Parameters ---------- @@ -87,6 +89,7 @@ def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file): # Load parameters from JSON file with open(rap_params_file, 'r') as f: rap_params = json.load(f) + cfg = rap_params # Store original parameters self.original_step_size = propagator.step_size @@ -97,6 +100,23 @@ def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file): # Convert theta from degrees to radians self.rap_theta = math.radians(rap_params.get('theta', math.degrees(self.original_theta))) + + self.rap_mask = mask_rap + self.rap_labels = rap_labels + self._mode = 'label' if rap_labels is not None else 'mask' + + self._base = { + 'step_size': propagator.step_size, + 'theta': propagator.theta, + 'algo' : getattr(propagator, 'algo', None), + 'tracking_neighbours' : getattr(propagator, 'tracking_neighbours', None) + } + + if self._mode == 'label': + self.default_cfg = cfg.get('default', {}) + self.methods_cfg = cfg.get('methods', {}) + else: + self.rap_cfg = cfg logging.info("RAP parameters loaded:") logging.info(f" Original step_size: {self.original_step_size:.3f}, " @@ -125,31 +145,92 @@ def rap_multistep_propagate(self, line, prev_direction): Whether the line is valid. """ # Switch to RAP parameters - self.propagator.step_size = self.rap_step_size - self.propagator.theta = self.rap_theta - - # Update tracking neighbours with new theta - from scilpy.tracking.propagator import get_sphere_neighbours - self.propagator.tracking_neighbours = get_sphere_neighbours( - self.propagator.sphere, self.rap_theta) - - # Perform propagation with new parameters - new_pos, new_dir, is_direction_valid = \ - self.propagator.propagate(line, prev_direction) - - # Restore original parameters - self.propagator.step_size = self.original_step_size - self.propagator.theta = self.original_theta - self.propagator.tracking_neighbours = get_sphere_neighbours( - self.propagator.sphere, self.original_theta) - - # Add the new point to the line - if is_direction_valid: + is_line_valid = True + + # We allow RAP to extend the streamline while it stays inside the RAP region + # In mask mode: "inside" means rap_mask > 0 + # In label mode: "inside" means label > 0 and we can switch config per label + while len(line) < self.max_nbr_pts: + curr_pos = line[-1] + + # Select config depending on RAP mode + if self._mode == 'label': + label = self._get_label(curr_pos, self.propagator.space, self.propagator.origin) + if label <= 0: + break + cfg = self._merge_cfg(label) + else: + # Classic binary RAP mask behaviour + if not self.is_in_rap_region(curr_pos, self.propagator.space, self.propagator.origin): + break + cfg = self.rap_cfg + + # Apply selected params for ONE step, then restore + self._apply_cfg(cfg) + try: + new_pos, new_dir, valid = self.propagator.propagate(line, prev_direction) + finally: + self._restore_base() + + is_line_valid = is_line_valid and valid + if not valid: + break + line.append(new_pos) - return line, new_dir, True - else: - return line, prev_direction, False + prev_direction = new_dir + return line, prev_direction, is_line_valid + + def _get_label(self, curr_pos, space, origin): + v = self.rap_labels.get_value_at_coordinate(*curr_pos, space=space, origin=origin) + try: + return int(v) + except Exception: + return int(np.round(v)) + + def _merge_cfg(self, label): + cfg = deepcopy(self.default_cfg) + override = self.methods_cfg.get(str(label), {}) + cfg.update(override) + return cfg + + def _apply_cfg(self, cfg): + if 'step_size' in cfg and cfg['step_size'] is not None: + self.propagator.step_size = float(cfg['step_size']) + if 'algo' in cfg and cfg['algo'] is not None: + self.propagator.algo = str(cfg['algo']) + if 'theta' in cfg and cfg['theta'] is not None: + theta_rad = np.deg2rad(float(cfg['theta'])) + self.propagator.theta = theta_rad + # theta change => neighbours change + self.propagator.tracking_neighbours = get_sphere_neighbours(self.propagator.sphere, self.propagator.theta) + + def _restore_base(self): + self.propagator.step_size = self._base['step_size'] + self.propagator.theta = self._base['theta'] + if self._base['algo'] is not None: + self.propagator.algo = self._base['algo'] + if self._base['tracking_neighbours'] is not None: + self.propagator.tracking_neighbours = self._base['tracking_neighbours'] + + def is_in_rap_region(self, curr_pos, space, origin): + """Override base class to support label-mode when rap_mask is None. + Tracker uses this to decide whether to enter/exit RAP. + - mask mode: inside if rap_mask > 0 + - label mode: inside if rap_labels label > 0 + """ + if self._mode == 'label': + if self.rap_labels is None: + return False + val = self.rap_labels.get_value_at_coordinate( + *curr_pos, space=space, origin=origin) + return val > 0 + + # mask mode (legacy) + if self.rap_mask is None: + return False + return self.rap_mask.get_value_at_coordinate( + *curr_pos, space=space, origin=origin) > 0 class RAPGraph(RAP): def __init__(self, mask_rap, propagator, max_nbr_pts, neighboorhood_size):