-
Notifications
You must be signed in to change notification settings - Fork 57
Al/math/minpooling #410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Al/math/minpooling #410
Changes from all commits
067882b
b90e77d
52655d4
c40c2a1
d124d72
d49d8d3
1c838e9
60735fc
f025cc5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1313,15 +1313,19 @@ def __init__( | |
super().__init__(np.max, ksize=ksize, **kwargs) | ||
|
||
|
||
#TODO ***AL*** revise MinPooling - torch, typing, docstring, unit test | ||
class MinPooling(Pool): | ||
"""Apply min pooling to images. | ||
|
||
This class reduces the resolution of an image by dividing it into | ||
non-overlapping blocks of size `ksize` and applying the min function to | ||
each block. The result is a downsampled image where each pixel value | ||
represents the minimum value within the corresponding block of the | ||
original image. | ||
This class inherits from `Pool` to reduce the resolution of an image by | ||
dividing it into non-overlapping blocks of size `ksize` and applying the | ||
`min` function to each block. The result is a downsampled image where each | ||
pixel value represents the minimum value within the corresponding block of | ||
the original image. | ||
|
||
If the backend is numpy, the downsampling is performed using | ||
`skimage.measure.block_reduce`. | ||
If the backend is torch, the downsampling is performed by negating the | ||
input, applying `torch.nn.functional.max_pool2d`, and negating the result. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -1339,15 +1343,16 @@ class MinPooling(Pool): | |
>>> input_image = np.random.rand(32, 32) | ||
|
||
Define a min pooling feature: | ||
>>> min_pooling = dt.MinPooling(ksize=3) | ||
>>> min_pooling = dt.MinPooling(ksize=4) | ||
>>> output_image = min_pooling(input_image) | ||
>>> print(output_image.shape) | ||
(32, 32) | ||
(8, 8) | ||
|
||
Notes | ||
----- | ||
Calling this feature returns a `np.ndarray` by default. If | ||
`store_properties` is set to `True`, the returned array will be | ||
Calling this feature returns the pooled image as either a numpy array or | ||
a torch tensor, depending on the backend. If `store_properties` is | ||
set to `True` and the input is a numpy array, the returned array will be | ||
automatically wrapped in an `Image` object. This behavior is handled | ||
internally and does not affect the return type of the `get()` method. | ||
|
||
|
@@ -1360,7 +1365,8 @@ def __init__( | |
): | ||
"""Initialize the parameters for min pooling. | ||
|
||
This constructor initializes the parameters for min pooling. | ||
This constructor initializes the parameters for min pooling and checks | ||
whether to use the numpy or torch implementation, defaults to numpy. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -1373,6 +1379,109 @@ def __init__( | |
|
||
super().__init__(np.min, ksize=ksize, **kwargs) | ||
|
||
def _get_numpy(
    self,
    image: NDArray,
    ksize: int = 3,
    **kwargs,
) -> NDArray:
    """Perform min pooling with the numpy backend enabled.

    Passes the image to `skimage.measure.block_reduce` (through
    `utils.safe_call`) with `self.pooling` as the reduction function.

    Parameters
    ----------
    image: NDArray
        Input image to be pooled.
    ksize: int
        Kernel size of the pooling operation, forwarded as `block_size`.
    **kwargs
        Additional keyword arguments forwarded to `utils.safe_call`.
        Entries named `image`, `func` or `block_size` are dropped because
        those arguments are set explicitly by this method.

    Returns
    -------
    NDArray
        The pooled image as a `NDArray`.

    """
    # Keys that are also passed explicitly below must not be forwarded:
    # a duplicate would raise "got multiple values for keyword argument"
    # before `safe_call` is even entered.
    reserved = ("image", "func", "block_size")
    forwarded = {
        key: value for key, value in kwargs.items() if key not in reserved
    }

    return utils.safe_call(
        skimage.measure.block_reduce,
        image=image,
        func=self.pooling,  # Presumably np.min, as passed to super().__init__.
        block_size=ksize,
        **forwarded,
    )
|
||
def _get_torch( | ||
self, | ||
image: torch.Tensor, | ||
ksize: int=3, | ||
**kwargs, | ||
): | ||
"""Method to perform min pooling with the torch backend enabled. | ||
As torch does not contain a min pooling layer, the equivalent | ||
operation is to first multiply the input image with `-1`, | ||
perform max pooling and multiply the max pooled image with `-1`. | ||
|
||
Parameters | ||
---------- | ||
image: torch.Tensor | ||
Input image to be pooled. | ||
ksize: int | ||
Kernel size of the pooling operation. | ||
|
||
Returns | ||
------- | ||
torch.Tensor | ||
The pooled image as a `torch.Tensor`. | ||
|
||
""" | ||
# If needed, expand tensor shape | ||
if len(image.shape) == 2: | ||
expanded_image = image.unsqueeze(0) | ||
|
||
pooled_image = -torch.nn.functional.max_pool2d( | ||
expanded_image*(-1), kernel_size=ksize, | ||
) | ||
# Remove the expanded dim. | ||
return pooled_image.squeeze(0) | ||
|
||
return -torch.nn.functional.max_pool2d( | ||
image*(-1), | ||
kernel_size=ksize, | ||
) | ||
|
||
def get( | ||
self, | ||
image: NDArray | torch.Tensor, | ||
ksize: int=3, | ||
**kwargs, | ||
): | ||
"""Method to perform pooling with either torch or numpy backend. | ||
|
||
Checks the current backend and chooses the appropriate function to pool | ||
the input image, either `_get_torch` or `_get_numpy`. | ||
|
||
Parameters | ||
---------- | ||
image: NDArray | torch.Tensor | ||
Input image to be pooled. | ||
ksize: int | ||
Kernel size of the pooling operation. | ||
|
||
Returns | ||
------- | ||
NDArray | torch.Tensor | ||
The pooled image as `NDArray` or `torch.Tensor` depending on | ||
the backend. | ||
|
||
""" | ||
|
||
if self.get_backend() == "numpy": | ||
return self._get_numpy(image, ksize, **kwargs,) | ||
elif self.get_backend() == "torch": | ||
return self._get_torch(image, ksize, **kwargs,) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer to check the type of the image, instead of checking the backend. So "if apc.is_torch_array(array): ...", like we have done in several of the features. But I don't think that the way it is done now is wrong There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can always do both but I think maintaining consistency is more important in this case like you pointed out. |
||
else: | ||
raise NotImplementedError(f"Backend {self.backend} not supported") | ||
|
||
#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test | ||
class MedianPooling(Pool): | ||
|
Uh oh!
There was an error while loading. Please reload this page.