Skip to content
4 changes: 3 additions & 1 deletion sample-apps/radiology/lib/infers/deepedit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import logging
from typing import Callable, Sequence, Union

from lib.transforms.transforms import GetCentroidsd
from lib.transforms.transforms import GetCentroidsd, OrientationGuidanceMultipleLabelDeepEditd

from monai.apps.deepedit.transforms import (
AddGuidanceFromPointsDeepEditd,
AddGuidanceSignalDeepEditd,
Expand Down Expand Up @@ -88,6 +89,7 @@ def pre_transforms(self, data=None):
if self.type == InferType.DEEPEDIT:
t.extend(
[
OrientationGuidanceMultipleLabelDeepEditd(ref_image="image", label_names=self.labels),
AddGuidanceFromPointsDeepEditd(ref_image="image", guidance="guidance", label_names=self.labels),
Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
ResizeGuidanceMultipleLabelDeepEditd(guidance="guidance", ref_image="image"),
Expand Down
34 changes: 34 additions & 0 deletions sample-apps/radiology/lib/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import torch
from einops import rearrange
from monai.config import KeysCollection, NdarrayOrTensor
from monai.data import MetaTensor
from monai.networks.layers import GaussianFilter
Expand Down Expand Up @@ -511,6 +512,39 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d


class OrientationGuidanceMultipleLabelDeepEditd(Transform):
def __init__(self, ref_image: str, label_names=None):
"""
Convert the guidance to the RAS orientation
"""
self.ref_image = ref_image
self.label_names = label_names

def transform_points(self, point, affine):
"""transform point to the coordinates of the transformed image
point: numpy array [bs, N, 3]
"""
bs, N = point.shape[:2]
point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1)
point = rearrange(point, "b n d -> d (b n)")
point = affine @ point
point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
return point

def __call__(self, data):
d: Dict = dict(data)
for key_label in self.label_names.keys():
points = d.get(key_label, [])
if len(points) < 1:
continue
reoriented_points = self.transform_points(
np.array(points)[None],
np.linalg.inv(d[self.ref_image].meta["affine"].numpy()) @ d[self.ref_image].meta["original_affine"],
)
d[key_label] = reoriented_points[0]
return d


def get_guidance_tensor_for_key_label(data, key_label, device) -> torch.Tensor:
"""Makes sure the guidance is in a tensor format."""
tmp_gui = data.get(key_label, torch.tensor([], dtype=torch.int32, device=device))
Expand Down
Loading