From 378abce1f86ac62afdc55168ed1bfbdecf6a0b0a Mon Sep 17 00:00:00 2001 From: Lionel Date: Fri, 10 Jan 2025 10:39:24 +0100 Subject: [PATCH 1/4] remove heuristics --- lightly/models/utils.py | 35 ++++++++++++++------------------- tests/models/test_ModelUtils.py | 9 --------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 1b0b3da5a..3a29d0c43 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -28,23 +28,21 @@ def pool_masked( - source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None + source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean" ) -> Tensor: - """Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer - index given by `mask` (B, H, W) or (H, W). + """Reduce image feature maps :math:`(B, C, H, W)` or :math:`(C, H, W)` according to + an integer index given by `mask` :math:`(B, H, W)` or :math:`(H, W)`. Args: - source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced. - mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices. - reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax' or - 'amin'. Defaults to 'mean'. - num_cls: The number of classes in the possible masks. If None, the number of classes - is inferred from the unique elements in `mask`. This is useful when not all - classes are present in the mask. + source: Float tensor of shape :math:`(B, C, H, W)` or :math:`(C, H, W)` to be + reduced. + mask: Integer tensor of shape :math:`(B, H, W)` or :math:`(H, W)` containing the + integer indices. + num_cls: The number of classes in the possible masks. Returns: - A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements - in `mask` or `num_cls` if specified. + A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` + corresponds to `num_cls`. """ if source.dim() == 3: return _mask_reduce(source, mask, reduce, num_cls) @@ -55,29 +53,26 @@ def pool_masked( def _mask_reduce( - source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None + source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean" ) -> Tensor: output = _mask_reduce_batched( - source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls + source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls, reduce=reduce ) return output.squeeze(0) def _mask_reduce_batched( - source: Tensor, mask: Tensor, num_cls: Optional[int] = None + source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean" ) -> Tensor: b, c, h, w = source.shape - if num_cls is None: - cls = mask.unique(sorted=True) - else: - cls = torch.arange(num_cls, device=mask.device) + cls = torch.arange(num_cls, device=mask.device) num_cls = cls.size(0) # create output tensor output = source.new_zeros((b, c, num_cls)) # (B C N) mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1) # (B C HW) source = source.view(b, c, -1) # (B C HW) output.scatter_reduce_( - dim=2, index=mask, src=source, reduce="mean", include_self=False + dim=2, index=mask, src=source, reduce=reduce, include_self=False ) # (B C N) # scatter_reduce_ produces NaNs if the count is zero output = torch.nan_to_num(output, nan=0.0) diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py index 0c2b292b4..2bde9f4fd 100644 --- a/tests/models/test_ModelUtils.py +++ b/tests/models/test_ModelUtils.py @@ -104,15 +104,6 @@ def test_masked_pooling_manual( assert out_manual.shape == (1, 3, 2) assert (out_manual == expected_result2[:, :2]).all() - def test_masked_pooling_auto( - self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor - ) -> None: - out_auto = pool_masked( - feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None - ) - assert out_auto.shape == (1, 3, 2) - assert (out_auto == expected_result2[:, :2]).all() - # Type ignore because untyped decorator makes function untyped. @pytest.mark.parametrize( "feature_map, mask, expected_result", From e6929631c9ab529dee65e74404f763ba666c4559 Mon Sep 17 00:00:00 2001 From: Lionel Date: Fri, 10 Jan 2025 10:46:28 +0100 Subject: [PATCH 2/4] small change in docstring --- lightly/models/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 3a29d0c43..9ac703182 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -40,9 +40,9 @@ def pool_masked( integer indices. num_cls: The number of classes in the possible masks. - Returns: - A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` - corresponds to `num_cls`. + Returns: + A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` + corresponds to `num_cls`. """ if source.dim() == 3: return _mask_reduce(source, mask, reduce, num_cls) From d35cbd1b4e2b19a2158edbf2b802a6163e98275a Mon Sep 17 00:00:00 2001 From: Lionel Date: Fri, 10 Jan 2025 10:48:03 +0100 Subject: [PATCH 3/4] another formatting error --- lightly/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 9ac703182..347181cd2 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -40,8 +40,8 @@ def pool_masked( integer indices. num_cls: The number of classes in the possible masks. - Returns: - A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` + Returns: + A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` corresponds to `num_cls`. """ if source.dim() == 3: From 62875c427246f1d26224488526933bc803b30376 Mon Sep 17 00:00:00 2001 From: Lionel Peer Date: Fri, 10 Jan 2025 11:25:45 +0100 Subject: [PATCH 4/4] Update lightly/models/utils.py Co-authored-by: guarin <43336610+guarin@users.noreply.github.com> --- lightly/models/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 347181cd2..b1911ce16 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -41,8 +41,7 @@ def pool_masked( num_cls: The number of classes in the possible masks. Returns: - A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N` - corresponds to `num_cls`. + A tensor of shape :math:`(B, C, num_cls)` or :math:`(C, num_cls)`. """ if source.dim() == 3: return _mask_reduce(source, mask, reduce, num_cls)