diff --git a/deeptrack/math.py b/deeptrack/math.py
index 13af56ce..05783551 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -1249,25 +1249,26 @@ def __init__(
         super().__init__(np.mean, ksize=ksize, **kwargs)
 
 
-#TODO ***AL*** revise MaxPooling - torch, typing, docstring, unit test
 class MaxPooling(Pool):
     """Apply max pooling to images.
 
-    This class reduces the resolution of an image by dividing it into
-    non-overlapping blocks of size `ksize` and applying the max function to
-    each block. The result is a downsampled image where each pixel value
-    represents the maximum value within the corresponding block of the
-    original image.
-    This is useful for reducing the size of an image while retaining the
-    most significant features.
+    This class inherits from `Pool` to reduce the resolution of an image by
+    dividing it into non-overlapping blocks of size `ksize` and applying the
+    `max` function to each block. The result is a downsampled image where each
+    pixel value represents the maximum value within the corresponding block of
+    the original image. This is useful for reducing the size of an image while
+    retaining the most significant features.
+
+    If the backend is numpy, the downsampling is performed using
+    `skimage.measure.block_reduce`.
+    If the backend is torch, the downsampling
+    is performed using `torch.nn.functional.max_pool2d`.
 
     Parameters
     ----------
     ksize: int
         Size of the pooling kernel.
-    cval: number
-        Value to pad edges with if necessary. Default 0.
-    func_kwargs: dict
+    **kwargs: dict
         Additional parameters sent to the pooling function.
 
     Examples
@@ -1281,12 +1282,13 @@ class MaxPooling(Pool):
     >>> max_pooling = dt.MaxPooling(ksize=8)
     >>> output_image = max_pooling(input_image)
     >>> print(output_image.shape)
-    (8, 8)
+    (4, 4)
 
     Notes
     -----
-    Calling this feature returns a `np.ndarray` by default. If
-    `store_properties` is set to `True`, the returned array will be
+    Calling this feature returns the pooled image as either a numpy array or
+    a torch tensor, depending on the backend. If `store_properties` is set to
+    `True` and the input is a numpy array, the returned array will be
     automatically wrapped in an `Image` object. This behavior is handled
     internally and does not affect the return type of the `get()` method.
 
@@ -1299,7 +1301,8 @@ def __init__(
     ):
         """Initialize the parameters for max pooling.
 
-        This constructor initializes the parameters for max pooling.
+        This constructor initializes the parameters for max pooling. The numpy
+        or torch implementation is chosen in `get()` from the active backend.
 
         Parameters
         ----------
@@ -1309,9 +1312,110 @@ def __init__(
             Additional keyword arguments.
 
         """
-
         super().__init__(np.max, ksize=ksize, **kwargs)
 
+    def _get_numpy(
+        self,
+        image: NDArray,
+        ksize: int = 3,
+        **kwargs,
+    ) -> NDArray:
+        """Method to perform max pooling with the numpy backend enabled.
+
+        Returns the result of the image passed to the scikit-image block_reduce
+        function with `np.max()` as the pooling function.
+
+        Parameters
+        ----------
+        image: NDArray
+            Input image to be pooled.
+        ksize: int
+            Kernel size of the pooling operation.
+
+        Returns
+        -------
+        NDArray
+            The pooled image as a `NDArray`.
+
+        """
+        return utils.safe_call(
+            skimage.measure.block_reduce,
+            image=image,
+            func=self.pooling,  # This will be np.max for this class.
+            block_size=ksize,
+            **kwargs,
+        )
+
+    def _get_torch(
+        self,
+        image: torch.Tensor,
+        ksize: int = 3,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Method to perform max pooling with the torch backend enabled.
+
+        Returns the result of the image passed to a torch max pooling layer.
+
+        Parameters
+        ----------
+        image: torch.Tensor
+            Input image to be pooled.
+        ksize: int
+            Kernel size of the pooling operation.
+
+        Returns
+        -------
+        torch.Tensor
+            The pooled image as a `torch.Tensor`.
+
+        """
+        # If needed, expand tensor shape
+        if len(image.shape) == 2:
+            expanded_image = image.unsqueeze(0)
+
+            pooled_image = torch.nn.functional.max_pool2d(
+                expanded_image, kernel_size=ksize,
+            )
+            # Remove the expanded dim.
+            return pooled_image.squeeze(0)
+
+        return torch.nn.functional.max_pool2d(
+            image,
+            kernel_size=ksize,
+        )
+
+    def get(
+        self,
+        image: NDArray | torch.Tensor,
+        ksize: int = 3,
+        **kwargs,
+    ):
+        """Method to perform pooling with either torch or numpy backend.
+
+        Checks the current backend and chooses the appropriate function to pool
+        the input image, either `_get_torch` or `_get_numpy`.
+
+        Parameters
+        ----------
+        image: NDArray | torch.Tensor
+            Input image to be pooled.
+        ksize: int
+            Kernel size of the pooling operation.
+
+        Returns
+        -------
+        NDArray | torch.Tensor
+            The pooled image as `NDArray` or `torch.Tensor` depending on
+            the backend.
+
+        """
+        backend = self.get_backend()
+        if backend == "numpy":
+            return self._get_numpy(image, ksize, **kwargs)
+        if backend == "torch":
+            return self._get_torch(image, ksize, **kwargs)
+        raise NotImplementedError(f"Backend {backend} not supported")
+
 
 #TODO ***AL*** revise MinPooling - torch, typing, docstring, unit test
 class MinPooling(Pool):
diff --git a/deeptrack/tests/test_math.py b/deeptrack/tests/test_math.py
index 197afb40..d3593633 100644
--- a/deeptrack/tests/test_math.py
+++ b/deeptrack/tests/test_math.py
@@ -78,16 +78,43 @@ def test_Blur(self):
         #input_image = xp.asarray(np.array([[1, 2], [3, 4]], dtype=float))
         #expected_output = xp.asarray(np.array([[1, 1.5], [2, 2.5]]))
 
-        #eature = math.Blur(filter_function=uniform_filter, size=2)
+        #feature = math.Blur(filter_function=uniform_filter, size=2)
         #blurred_image = feature.resolve(input_image)
         #self.assertTrue(xp.all(blurred_image == expected_output))
 
+    def test_MaxPooling(self):
+        input_image = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float)
+        feature = math.MaxPooling(ksize=2)
+        pooled_image = feature.resolve(input_image)
+        self.assertTrue(np.all(pooled_image == [[6.0, 8.0]]))
+        self.assertEqual(pooled_image.shape, (1, 2))
+
 
 # Extending the test and setting the backend to torch
 @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
 class TestMath_Torch(TestMath_Numpy):
     BACKEND = "torch"
-    pass
+
+    def test_MaxPooling(self):
+        # (1, 1, 2, 4)
+        input_image = torch.tensor([[[ [1.0, 2.0, 3.0, 4.0],
+                                       [5.0, 6.0, 7.0, 8.0] ]]])
+        feature = math.MaxPooling(ksize=2)
+        pooled_image = feature(input_image, ksize=2)
+        expected = torch.tensor([[[[6.0, 8.0]]]])
+        self.assertEqual(pooled_image.shape, expected.shape)
+        self.assertTrue(torch.allclose(pooled_image, expected))
+        self.assertTrue(isinstance(pooled_image, torch.Tensor))
+
+        # (2, 4)
+        input_image = torch.tensor([ [1.0, 2.0, 3.0, 4.0],
+                                     [5.0, 6.0, 7.0, 8.0] ])
+        feature = math.MaxPooling(ksize=2)
+        pooled_image = feature(input_image, ksize=2)
+        expected = torch.tensor([[6.0, 8.0]])
+        self.assertEqual(pooled_image.shape, expected.shape)
+        self.assertTrue(torch.allclose(pooled_image, expected))
+        self.assertTrue(isinstance(pooled_image, torch.Tensor))
 
 
 class TestMath(unittest.TestCase):