remove heuristics in masked pooling #1777 (lionel-lig-5887)

Merged
Changes from 3 commits
lightly/models/utils.py (35 changes: 15 additions & 20 deletions)
@@ -28,23 +28,21 @@


 def pool_masked(
-    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
-    """Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
-    index given by `mask` (B, H, W) or (H, W).
+    """Reduce image feature maps :math:`(B, C, H, W)` or :math:`(C, H, W)` according to
+    an integer index given by `mask` :math:`(B, H, W)` or :math:`(H, W)`.

     Args:
-        source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
-        mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.
-        reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax' or
-            'amin'. Defaults to 'mean'.
-        num_cls: The number of classes in the possible masks. If None, the number of classes
-            is inferred from the unique elements in `mask`. This is useful when not all
-            classes are present in the mask.
+        source: Float tensor of shape :math:`(B, C, H, W)` or :math:`(C, H, W)` to be
+            reduced.
+        mask: Integer tensor of shape :math:`(B, H, W)` or :math:`(H, W)` containing the
+            integer indices.
+        num_cls: The number of classes in the possible masks.

     Returns:
-        A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements
-        in `mask` or `num_cls` if specified.
+        A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N`
+        corresponds to `num_cls`.
     """
     if source.dim() == 3:
         return _mask_reduce(source, mask, reduce, num_cls)
@@ -55,29 +53,26 @@ def pool_masked(


 def _mask_reduce(
-    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
     output = _mask_reduce_batched(
-        source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls
+        source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls, reduce=reduce
     )
     return output.squeeze(0)


 def _mask_reduce_batched(
-    source: Tensor, mask: Tensor, num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
     b, c, h, w = source.shape
-    if num_cls is None:
-        cls = mask.unique(sorted=True)
-    else:
-        cls = torch.arange(num_cls, device=mask.device)
+    cls = torch.arange(num_cls, device=mask.device)
     num_cls = cls.size(0)
     # create output tensor
     output = source.new_zeros((b, c, num_cls))  # (B C N)
     mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1)  # (B C HW)
     source = source.view(b, c, -1)  # (B C HW)
     output.scatter_reduce_(
-        dim=2, index=mask, src=source, reduce="mean", include_self=False
+        dim=2, index=mask, src=source, reduce=reduce, include_self=False
     )  # (B C N)
     # scatter_reduce_ produces NaNs if the count is zero
     output = torch.nan_to_num(output, nan=0.0)
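For context, a minimal usage sketch of the updated API (illustrative only, not part of the diff; the tensor values are made up):

    import torch
    from lightly.models.utils import pool_masked

    # One sample, one channel, 2x2 feature map with an integer mask over classes 0 and 1.
    features = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # (B, C, H, W) = (1, 1, 2, 2)
    mask = torch.tensor([[[0, 0], [1, 1]]])                # (B, H, W) = (1, 2, 2)

    # num_cls must now be passed explicitly; class 2 never occurs in the mask but
    # still gets its own (zero-filled) output slot.
    pooled = pool_masked(features, mask, num_cls=3, reduce="mean")
    print(pooled.shape)  # torch.Size([1, 1, 3])
    print(pooled)        # class 0 -> 1.5, class 1 -> 3.5, class 2 -> 0.0

Under the removed heuristic (num_cls=None), the number of output slots was inferred from the unique values in the mask, so the same call would have returned a (1, 1, 2) tensor and the output width would have depended on the mask content.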
tests/models/test_ModelUtils.py (9 changes: 0 additions & 9 deletions)
@@ -104,15 +104,6 @@ def test_masked_pooling_manual(
         assert out_manual.shape == (1, 3, 2)
         assert (out_manual == expected_result2[:, :2]).all()

-    def test_masked_pooling_auto(
-        self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
-    ) -> None:
-        out_auto = pool_masked(
-            feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
-        )
-        assert out_auto.shape == (1, 3, 2)
-        assert (out_auto == expected_result2[:, :2]).all()
-
     # Type ignore because untyped decorator makes function untyped.
     @pytest.mark.parametrize(
         "feature_map, mask, expected_result",
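A brief sketch (not part of this PR) of why the auto-inference test was dropped rather than rewritten: with the heuristic gone there is no num_cls=None code path left to exercise, and omitting the now-required argument fails immediately:

    import pytest
    import torch
    from lightly.models.utils import pool_masked

    features = torch.randn(1, 3, 2, 2)
    mask = torch.zeros(1, 2, 2, dtype=torch.long)

    # num_cls is a required argument after this change; leaving it out raises a TypeError.
    with pytest.raises(TypeError):
        pool_masked(features, mask)  # type: ignore[call-arg]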