Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,5 @@ venv/

# Hidden folder
.hidden/

.DS_Store
33 changes: 25 additions & 8 deletions src/scilpy/cli/scil_tracking_local_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def _build_arg_parser():
track_g.add_argument('--rap_save_entry_exit', default=None,
help='Save RAP entry/exit coordinates as a binary mask.\n'
'Provide output filename (.nii.gz).')
track_g.add_argument('--rap_labels', default=None,
help='Region-Adaptive Propagation label volume (.nii.gz) .\n'
'Voxel values are integer labels (0=background, 1..N=regions) .\n'
'Used with --rap_method switch to select policies per label.')

m_g = p.add_argument_group('Memory options')
add_processes_arg(m_g)
Expand Down Expand Up @@ -195,15 +199,16 @@ def main():
verify_compression_th(args.compress_th)
verify_seed_options(parser, args)

if args.rap_mask is not None and args.rap_method == "None":
if (args.rap_mask is not None or args.rap_labels is not None) and args.rap_method == "None":
parser.error('No RAP method selected.')
if not args.rap_method == "None" and args.rap_mask is None:
parser.error('No RAP mask selected.')
if args.rap_method == 'continue' and args.rap_mask is None:
parser.error('RAP method "continue" requires --rap_mask.')
if args.rap_method == 'switch' and (args.rap_mask is None and args.rap_labels is None):
parser.error('RAP method "switch" requires --rap_mask or --rap_labels.')
if args.rap_method == 'switch' and args.rap_params is None:
parser.error('RAP method "switch" requires --rap_params to be specified.')
if args.rap_params is not None and args.rap_method != 'switch':
parser.error('--rap_params can only be used with --rap_method switch.')

tracts_format = detect_format(args.out_tractogram)
if tracts_format is not TrkFile:
logging.warning("You have selected option --save_seeds but you are "
Expand Down Expand Up @@ -292,21 +297,32 @@ def main():
space=our_space, origin=our_origin, is_legacy=is_legacy)

# ------- INSTANTIATING RAP OBJECT -------
rap_mask = None
rap_labels = None
if args.rap_mask:
logging.info("Loading RAP mask.")
rap_img = nib.load(args.rap_mask)
rap_data = rap_img.get_fdata(caching='unchanged', dtype=float)
rap_res = rap_img.header.get_zooms()[:3]
rap_mask = DataVolume(rap_data, rap_res, args.mask_interp)
else:
rap_mask = None

if args.rap_labels:
logging.info("Loading RAP labels.")
rap_label_img = nib.load(args.rap_labels)

rap_label_data = np.asanyarray(rap_label_img.dataobj).astype(np.uint8)
rap_label_res = rap_label_img.header.get_zooms()[:3]
rap_labels = DataVolume(rap_label_data, rap_label_res, 'nearest')
rap_mask_from_labels = (rap_label_data > 0).astype(np.float32)
rap_mask = DataVolume(rap_mask_from_labels, rap_label_res, 'nearest')

if args.rap_method == "continue":
rap = RAPContinue(rap_mask, propagator, max_nbr_pts,
step_size=vox_step_size)
elif args.rap_method == "switch":
rap = RAPSwitch(rap_mask, propagator, max_nbr_pts,
rap_params_file=args.rap_params)
rap_params_file=args.rap_params,
rap_labels=rap_labels)
else:
rap = None

Expand Down Expand Up @@ -352,7 +368,8 @@ def main():
sft = StatefulTractogram(streamlines, mask_img,
space=our_space, origin=our_origin,
data_per_streamline=data_per_streamline)
save_tractogram(sft, args.out_tractogram)
save_tractogram(sft, args.out_tractogram, bbox_valid_check=False
)


if __name__ == "__main__":
Expand Down
129 changes: 105 additions & 24 deletions src/scilpy/tracking/rap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import json
import logging
import numpy as np
from copy import deepcopy
from dipy.core.geometry import math
from scilpy.tracking.propagator import get_sphere_neighbours


class RAP:
Expand Down Expand Up @@ -65,7 +67,7 @@ def rap_multistep_propagate(self, line, prev_direction):

class RAPSwitch(RAP):
"""RAP class that switches tracking parameters when inside the RAP mask."""
def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file):
def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file, rap_labels=None):
"""
Parameters
----------
Expand All @@ -87,6 +89,7 @@ def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file):
# Load parameters from JSON file
with open(rap_params_file, 'r') as f:
rap_params = json.load(f)
cfg = rap_params

# Store original parameters
self.original_step_size = propagator.step_size
Expand All @@ -97,6 +100,23 @@ def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file):
# Convert theta from degrees to radians
self.rap_theta = math.radians(rap_params.get('theta',
math.degrees(self.original_theta)))

self.rap_mask = mask_rap
self.rap_labels = rap_labels
self._mode = 'label' if rap_labels is not None else 'mask'

self._base = {
'step_size': propagator.step_size,
'theta': propagator.theta,
'algo' : getattr(propagator, 'algo', None),
'tracking_neighbours' : getattr(propagator, 'tracking_neighbours', None)
}

if self._mode == 'label':
self.default_cfg = cfg.get('default', {})
self.methods_cfg = cfg.get('methods', {})
else:
self.rap_cfg = cfg

logging.info("RAP parameters loaded:")
logging.info(f" Original step_size: {self.original_step_size:.3f}, "
Expand Down Expand Up @@ -125,31 +145,92 @@ def rap_multistep_propagate(self, line, prev_direction):
Whether the line is valid.
"""
# Switch to RAP parameters
self.propagator.step_size = self.rap_step_size
self.propagator.theta = self.rap_theta

# Update tracking neighbours with new theta
from scilpy.tracking.propagator import get_sphere_neighbours
self.propagator.tracking_neighbours = get_sphere_neighbours(
self.propagator.sphere, self.rap_theta)

# Perform propagation with new parameters
new_pos, new_dir, is_direction_valid = \
self.propagator.propagate(line, prev_direction)

# Restore original parameters
self.propagator.step_size = self.original_step_size
self.propagator.theta = self.original_theta
self.propagator.tracking_neighbours = get_sphere_neighbours(
self.propagator.sphere, self.original_theta)

# Add the new point to the line
if is_direction_valid:
is_line_valid = True

# We allow RAP to extend the streamline while it stays inside the RAP region
# In mask mode: "inside" means rap_mask > 0
# In label mode: "inside" means label > 0 and we can switch config per label
while len(line) < self.max_nbr_pts:
curr_pos = line[-1]

# Select config depending on RAP mode
if self._mode == 'label':
label = self._get_label(curr_pos, self.propagator.space, self.propagator.origin)
if label <= 0:
break
cfg = self._merge_cfg(label)
else:
# Classic binary RAP mask behaviour
if not self.is_in_rap_region(curr_pos, self.propagator.space, self.propagator.origin):
break
cfg = self.rap_cfg

# Apply selected params for ONE step, then restore
self._apply_cfg(cfg)
try:
new_pos, new_dir, valid = self.propagator.propagate(line, prev_direction)
finally:
self._restore_base()

is_line_valid = is_line_valid and valid
if not valid:
break

line.append(new_pos)
return line, new_dir, True
else:
return line, prev_direction, False
prev_direction = new_dir

return line, prev_direction, is_line_valid

def _get_label(self, curr_pos, space, origin):
v = self.rap_labels.get_value_at_coordinate(*curr_pos, space=space, origin=origin)
try:
return int(v)
except Exception:
return int(np.round(v))

def _merge_cfg(self, label):
cfg = deepcopy(self.default_cfg)
override = self.methods_cfg.get(str(label), {})
cfg.update(override)
return cfg

def _apply_cfg(self, cfg):
if 'step_size' in cfg and cfg['step_size'] is not None:
self.propagator.step_size = float(cfg['step_size'])
if 'algo' in cfg and cfg['algo'] is not None:
self.propagator.algo = str(cfg['algo'])
if 'theta' in cfg and cfg['theta'] is not None:
theta_rad = np.deg2rad(float(cfg['theta']))
self.propagator.theta = theta_rad
# theta change => neighbours change
self.propagator.tracking_neighbours = get_sphere_neighbours(self.propagator.sphere, self.propagator.theta)

def _restore_base(self):
self.propagator.step_size = self._base['step_size']
self.propagator.theta = self._base['theta']
if self._base['algo'] is not None:
self.propagator.algo = self._base['algo']
if self._base['tracking_neighbours'] is not None:
self.propagator.tracking_neighbours = self._base['tracking_neighbours']

def is_in_rap_region(self, curr_pos, space, origin):
"""Override base class to support label-mode when rap_mask is None.
Tracker uses this to decide whether to enter/exit RAP.
- mask mode: inside if rap_mask > 0
- label mode: inside if rap_labels label > 0
"""
if self._mode == 'label':
if self.rap_labels is None:
return False
val = self.rap_labels.get_value_at_coordinate(
*curr_pos, space=space, origin=origin)
return val > 0

# mask mode (legacy)
if self.rap_mask is None:
return False
return self.rap_mask.get_value_at_coordinate(
*curr_pos, space=space, origin=origin) > 0

class RAPGraph(RAP):
def __init__(self, mask_rap, propagator, max_nbr_pts, neighboorhood_size):
Expand Down