Mg/features/sample to masks #401

Draft: wants to merge 2 commits into base `develop`.
246 changes: 192 additions & 54 deletions deeptrack/features.py
@@ -217,7 +217,7 @@ def propagate_data_to_dependencies(
"OneOf",
"OneOfDict",
"LoadImage", # TODO ***MG***
"SampleToMasks", # TODO ***MG***
"SampleToMasks",
"AsType", # TODO ***MG***
"ChannelFirst2d",
"Upscale", # TODO ***AL***
@@ -7451,21 +7451,21 @@ def get(
class SampleToMasks(Feature):
    """Create a mask from a list of images.

    This feature applies a transformation function to each input image and
    merges the resulting masks into a single multi-layer image. Each input
    image must have a `position` property that determines its placement within
    the final mask. When used with scatterers, the `voxel_size` property must
    be provided for correct object sizing.

    Parameters
    ----------
    transformation_function: Callable[[Image], Image]
        A function that transforms each input image into a mask with
        `number_of_masks` layers.
    number_of_masks: PropertyLike[int], optional
        The number of mask layers to generate. Default is 1.
    output_region: PropertyLike[tuple[int, int, int, int]], optional
        The size and position of the output mask, typically aligned with
        `optics.output_region`.
    merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
        Method for merging individual masks into the final image. Can be:
@@ -7474,16 +7474,32 @@ class SampleToMasks(Feature):
- "or": Combine masks using a logical OR operation.
- "mul": Multiply masks.
- Function: Custom function taking two images and merging them.
- List: Specifies the merge method for each mask layer, where each
element is either a string (one of the above) or a callable.

**kwargs: dict[str, Any]
Additional keyword arguments passed to the parent `Feature` class.

Methods
-------
`get(image: np.ndarray | Image, transformation_function: Callable[[Image], Image], **kwargs: dict[str, Any]) -> Image`
Applies the transformation function to the input image.
`_process_and_get(images: list[np.ndarray] | np.ndarray | list[Image] | Image, **kwargs: dict[str, Any]) -> Image | np.ndarray`
Processes a list of images and generates a multi-layer mask.
`get(
image: np.ndarray | torch.Tensor | Image,
transformation_function: Callable[[Image], Image],
**kwargs: dict[str, Any]
) -> Image`
Apply the transformation function to the input image.
`_process_and_get(
images: (
list[np.ndarray]
| np.ndarray
| list[torch.Tensor]
| torch.Tensor
| list[Image]
| Image
),
**kwargs: dict[str, Any]
) -> Image | np.ndarray | torch.Tensor`
Process a list of images and generate a multi-layer mask.

Returns
-------
Expand All @@ -7504,7 +7520,7 @@ class SampleToMasks(Feature):

    Define optics and particles:
    >>> import numpy as np
    >>>
    >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
    >>> particle = dt.PointParticle(
    >>>     position=lambda: np.random.uniform(5, 55, size=2),
@@ -7539,14 +7555,84 @@ class SampleToMasks(Feature):
    >>> plt.title("Mask")
    >>> plt.show()

    Example demonstrating different merge methods:
    >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
    >>> particle = dt.Ellipse(
    ...     position=lambda: np.random.uniform(10, 50, size=2),
    ...     radius=2e-6
    ... )
    >>> particles = particle ^ 2
    >>> sim_im_pip = optics(particles)

    Define a custom mask function that generates three identical layers:
    >>> def mask_function(obj):
    ...     value = np.random.randint(1, 10)
    ...     mask = np.squeeze(obj > 0) * value
    ...     h, w = mask.shape
    ...     masks = np.zeros((h, w, 3), dtype=np.uint8)
    ...     masks[..., 0] = mask
    ...     masks[..., 1] = mask
    ...     masks[..., 2] = mask
    ...     return masks

    Create a mask pipeline with a different merge strategy for each layer:
    >>> sim_mask_pip = particles >> dt.SampleToMasks(
    ...     lambda: mask_function,
    ...     number_of_masks=3,
    ...     output_region=optics.output_region,
    ...     merge_method=['add', 'or', 'overwrite']
    ... )
    >>> pipeline = sim_im_pip & sim_mask_pip
    >>> pipeline.store_properties()

    Generate the image and the three-channel mask:
    >>> image, mask = pipeline.update()()

    Visualize the image and the three mask layers:
    >>> import matplotlib.pyplot as plt
    >>> plt.subplot(1, 4, 1)
    >>> plt.imshow(image, cmap='gray')
    >>> plt.title("Image")
    >>> plt.subplot(1, 4, 2)
    >>> plt.imshow(mask[..., 0], cmap='gray')
    >>> plt.title("Mask: 'add'")
    >>> plt.subplot(1, 4, 3)
    >>> plt.imshow(mask[..., 1], cmap='gray')
    >>> plt.title("Mask: 'or'")
    >>> plt.subplot(1, 4, 4)
    >>> plt.imshow(mask[..., 2], cmap='gray')
    >>> plt.title("Mask: 'overwrite'")
    >>> plt.show()

    """

    def __init__(
        self: Feature,
        transformation_function: PropertyLike[
            Callable[
                [
                    NDArray
                    | list[NDArray]
                    | torch.Tensor
                    | list[torch.Tensor]
                    | Image
                    | list[Image]
                ],
                NDArray
                | list[NDArray]
                | torch.Tensor
                | list[torch.Tensor]
                | Image
                | list[Image],
            ]
        ],
        number_of_masks: PropertyLike[int] = 1,
        output_region: PropertyLike[tuple[int, int, int, int]] | None = None,
        merge_method: PropertyLike[
            str
            | Callable[[Any, Any], Any]
            | list[str | Callable[[Any, Any], Any]]
        ] = "add",
        **kwargs: Any,
    ):
"""Initialize the SampleToMasks feature.
@@ -7556,14 +7642,16 @@ def __init__(
        transformation_function: Callable
            Function to transform input images (arrays, tensors, or
            `Image` objects) into masks.
        number_of_masks: PropertyLike[int], optional
            Number of mask layers. It defaults to 1.
        output_region: PropertyLike[tuple[int, int, int, int]], optional
            Output region of the mask. It defaults to `None`.
        merge_method: PropertyLike[
            str | Callable | list[str | Callable]
        ], optional
            Method to merge masks. It defaults to "add".
        **kwargs: dict[str, Any]
            Additional keyword arguments passed to the parent class.

        """

        super().__init__(
@@ -7576,15 +7664,15 @@ def __init__(
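A hedged usage sketch of the constructor above (`binary_mask` and the particle pipeline are illustrative, not part of this diff; `optics.output_region` mirrors the docstring examples):

```python
import numpy as np
import deeptrack as dt

optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
particle = dt.PointParticle(position=lambda: np.random.uniform(5, 55, size=2))

def binary_mask(obj):
    # One binary layer per object; matches number_of_masks=1.
    return (np.asarray(obj) > 0).astype(np.uint8)

masks = (particle ^ 2) >> dt.SampleToMasks(
    lambda: binary_mask,              # properties are sampled, hence the lambda
    number_of_masks=1,
    output_region=optics.output_region,
    merge_method="or",
)
```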

    def get(
        self: Feature,
        image: NDArray | torch.Tensor | Image,
        transformation_function: Callable[[Image], Image],
        **kwargs: Any,
    ) -> Image:
        """Apply the transformation function to a single image.

        Parameters
        ----------
        image: np.ndarray | torch.Tensor | Image
            The input image.
        transformation_function: Callable[[Image], Image]
            Function to transform the image.
@@ -7602,34 +7690,45 @@ def get(
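A transformation function passed through `get` must return one array with `number_of_masks` layers; an illustrative example (the two-layer design here is hypothetical):

```python
import numpy as np

def two_layer_mask(obj):
    # Layer 0: binary occupancy; layer 1: intensity-weighted occupancy.
    arr = np.squeeze(np.asarray(obj, dtype=np.float32))
    masks = np.zeros((*arr.shape, 2), dtype=np.float32)
    masks[..., 0] = arr > 0
    masks[..., 1] = arr * (arr > 0)
    return masks
```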

    def _process_and_get(
        self: Feature,
        images: (
            list[NDArray]
            | NDArray
            | list[torch.Tensor]
            | torch.Tensor
            | list[Image]
            | Image
        ),
        **kwargs: Any,
    ) -> NDArray | torch.Tensor | Image:
        """Process a list of images and generate a multi-layer mask.

        Parameters
        ----------
        images: np.ndarray or list[np.ndarray] or torch.Tensor or
                list[torch.Tensor] or Image or list[Image]
            List of input images or a single image.
        **kwargs: dict[str, Any]
            Additional parameters including `output_region`,
            `number_of_masks`, and `merge_method`.

        Returns
        -------
        Image, np.ndarray, or torch.Tensor
            The final mask image.

        """

        # Handle list of images.
        if isinstance(images, list) and len(images) != 1:
            list_of_labels = super()._process_and_get(images, **kwargs)
            if not self._wrap_array_with_image:
                for idx, (label, image) in enumerate(
                    zip(list_of_labels, images)
                ):
                    list_of_labels[idx] = Image(
                        label,
                        copy=False
                    ).merge_properties_from(image)
        else:
            if isinstance(images, list):
                images = images[0]
@@ -7638,50 +7737,90 @@ def _process_and_get(

if "position" in prop:

inp = Image(np.array(images))
if apc.is_torch_array(images):
inp = Image(images)

else:
inp = Image(np.array(images))

inp.append(prop)
out = Image(self.get(inp, **kwargs))
out.merge_properties_from(inp)
list_of_labels.append(out)

# Create an empty output image.
output_region = kwargs["output_region"]
output = np.zeros(
(
shape = (
output_region[2] - output_region[0],
output_region[3] - output_region[1],
kwargs["number_of_masks"],
)
)
if apc.is_torch_array(images):
output = torch.zeros(shape)
else:
output = np.zeros(shape)
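        # For example, with output_region=(0, 0, 64, 64) and number_of_masks=3
        # (hypothetical values), shape == (64, 64, 3): height and width come
        # from the region bounds, with one channel per mask layer.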

        from deeptrack.optics import _get_position

        # Merge masks into the output.
        for label in list_of_labels:
            position = _get_position(label)
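            # p0: the label's top-left corner relative to the output region.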
            p0 = xp.round(position - output_region[0:2])

            if xp.any(p0 > output.shape[0:2]) or xp.any(
                p0 + label.shape[0:2] < 0
            ):
                continue

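            # Clip the label to the part that overlaps the output region,
            # using torch or NumPy depending on the input backend.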
            if apc.is_torch_array(images):
                crop_x = int(
                    -torch.minimum(p0[0], torch.tensor(0, device=p0.device))
                )
                crop_y = int(
                    -torch.minimum(p0[1], torch.tensor(0, device=p0.device))
                )
                crop_x_end = int(
                    label.shape[0] - torch.max(torch.stack([
                        p0[0] + label.shape[0] - output.shape[0],
                        torch.tensor(0, device=p0.device)
                    ]))
                )
                crop_y_end = int(
                    label.shape[1] - torch.max(torch.stack([
                        p0[1] + label.shape[1] - output.shape[1],
                        torch.tensor(0, device=p0.device)
                    ]))
                )

                labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]

                p0[0] = torch.max(
                    p0[0], torch.tensor(0, dtype=p0.dtype, device=p0.device)
                )
                p0[1] = torch.max(
                    p0[1], torch.tensor(0, dtype=p0.dtype, device=p0.device)
                )

                p0 = p0.int()

            else:
                crop_x = int(-np.min([p0[0], 0]))
                crop_y = int(-np.min([p0[1], 0]))
                crop_x_end = int(
                    label.shape[0]
                    - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
                )
                crop_y_end = int(
                    label.shape[1]
                    - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
                )

                labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]

                p0[0] = np.max([p0[0], 0])
                p0[1] = np.max([p0[1], 0])

                p0 = p0.astype(int)

            output_slice = output[
                p0[0] : p0[0] + labelarg.shape[0],
Expand All @@ -7705,8 +7844,7 @@ def _process_and_get(
                elif merge == "overwrite":
                    output_slice[
                        labelarg[..., label_index] != 0, label_index
                    ] = labelarg[labelarg[..., label_index] != 0, label_index]
            output[
                p0[0] : p0[0] + labelarg.shape[0],
                p0[1] : p0[1] + labelarg.shape[1],
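The torch/NumPy dispatch added throughout this method follows one pattern: test the backend, then call the matching allocator or reduction. A standalone sketch of the idea (`is_torch_array` below is a local stand-in for deeptrack's `apc.is_torch_array`, and the helper name is hypothetical):

```python
import numpy as np

def is_torch_array(x):
    # Stand-in for apc.is_torch_array: check the array's module of origin.
    return type(x).__module__.startswith("torch")

def zeros_like_backend(reference, shape):
    # Allocate zeros on the same backend as `reference`.
    if is_torch_array(reference):
        import torch
        return torch.zeros(shape)
    return np.zeros(shape)

print(zeros_like_backend(np.ones(3), (2, 2)))  # NumPy zeros
```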