-
Notifications
You must be signed in to change notification settings - Fork 289
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add IBOTPatchLoss * Add Center * Refactor DINOLoss to use Center
- Loading branch information
Showing
7 changed files
with
321 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import torch | ||
from torch import Tensor | ||
from torch.nn import Module | ||
from torch.nn import functional as F | ||
|
||
from lightly.models.modules.center import Center | ||
|
||
|
||
class IBOTPatchLoss(Module): | ||
"""Implementation of the iBOT patch loss [0] as used in DINOv2 [1]. | ||
Implementation is based on [2]. | ||
- [0]: iBOT, 2021, https://arxiv.org/abs/2111.07832 | ||
- [1]: DINOv2, 2023, https://arxiv.org/abs/2304.07193 | ||
- [2]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py | ||
Attributes: | ||
output_dim: | ||
Dimension of the model output. | ||
teacher_temp: | ||
Temperature for the teacher output. | ||
student_temp: | ||
Temperature for the student output. | ||
center_mode: | ||
Mode for center calculation. Only 'mean' is supported. | ||
center_momentum: | ||
Momentum term for the center update. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
output_dim: int, | ||
teacher_temp: float = 0.04, | ||
student_temp: float = 0.1, | ||
center_mode: str = "mean", | ||
center_momentum: float = 0.9, | ||
) -> None: | ||
super().__init__() | ||
self.teacher_temp = teacher_temp | ||
self.student_temperature = student_temp | ||
self.center = Center( | ||
size=(1, output_dim), | ||
mode=center_mode, | ||
momentum=center_momentum, | ||
) | ||
|
||
def forward( | ||
self, | ||
teacher_out: Tensor, | ||
student_out: Tensor, | ||
mask: Tensor, | ||
) -> Tensor: | ||
"""Forward pass through the iBOT patch loss. | ||
Args: | ||
teacher_out: | ||
Tensor with shape (batch_size * sequence_length, embed_dim) containing | ||
the teacher output of the masked tokens. | ||
student_out: | ||
Tensor with shape (batch_size * sequence_length, embed_dim) containing | ||
the student output of the masked tokens. | ||
mask: | ||
Boolean tensor with shape (batch_size, height, width) containing the | ||
token mask. Exactly batch_size * sequence_length entries must be set to | ||
True in the mask. | ||
Returns: | ||
Loss value. | ||
""" | ||
# B = batch size, N = sequence length = number of masked tokens, D = embed dim | ||
# H = height (in tokens), W = width (in tokens) | ||
# Note that N <= H * W depending on how many tokens are masked. | ||
|
||
# Calculate cross entropy loss. | ||
teacher_softmax = F.softmax( | ||
(teacher_out - self.center.value) / self.teacher_temp, dim=-1 | ||
) | ||
student_log_softmax = F.log_softmax( | ||
student_out / self.student_temperature, dim=-1 | ||
) | ||
# (B * N, D) -> (B * N) | ||
loss = -torch.sum(teacher_softmax * student_log_softmax, dim=-1) | ||
|
||
# Get weights. | ||
# (B, H, W) -> (B, 1, 1) | ||
num_masked_per_image = mask.sum(dim=(1, 2), keepdim=True).clamp(min=1.0) | ||
# (B, 1, 1) -> (B, H, W) -> (B * N) | ||
weight = (1.0 / num_masked_per_image).expand_as(mask)[mask] | ||
|
||
# Apply weighting. | ||
B = mask.shape[0] | ||
loss = (loss * weight).sum() / B | ||
|
||
self.center.update(teacher_out) | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Tuple | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from torch import Tensor | ||
from torch.nn import Module | ||
|
||
|
||
class Center(Module): | ||
"""Center module to compute and store the center of a feature tensor as used | ||
in DINO [0]. | ||
- [0]: DINO, 2021, https://arxiv.org/abs/2104.14294 | ||
Attributes: | ||
size: | ||
Size of the tracked center tensor. Dimensions across which the center | ||
is computed must be set to 1. For example, if the feature tensor has shape | ||
(batch_size, sequence_length, feature_dim) and the center should be computed | ||
across the batch and sequence dimensions, the size should be | ||
(1, 1, feature_dim). | ||
mode: | ||
Mode to compute the center. Currently only 'mean' is supported. | ||
momentum: | ||
Momentum term for the center calculation. | ||
_register_buffer: | ||
Deprecated, do not use. This argument is only kept for backwards | ||
compatibility with DINOLoss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
size: Tuple[int, ...], | ||
mode: str = "mean", | ||
momentum: float = 0.9, | ||
_register_buffer: bool = True, | ||
) -> None: | ||
super().__init__() | ||
|
||
mode_to_fn = { | ||
"mean": self._center_mean, | ||
} | ||
if mode not in mode_to_fn: | ||
raise ValueError( | ||
f"Unknown mode '{mode}'. Valid modes are {sorted(mode_to_fn.keys())}." | ||
) | ||
self._center_fn = mode_to_fn[mode] | ||
self.size = size | ||
self.dim = tuple(i for i, s in enumerate(size) if s == 1) | ||
|
||
center = torch.zeros(self.size) | ||
if _register_buffer: | ||
self.register_buffer("center", center) | ||
else: | ||
# Do not register buffer for backwards compatilibity with DINOLoss as the | ||
# loss already registers the buffer. If we register it here again there will | ||
# be an extra entry in the state dict. | ||
self.center = center | ||
|
||
self.momentum = momentum | ||
|
||
@property | ||
def value(self) -> Tensor: | ||
"""The current value of the center. Use this property to do any operations based | ||
on the center.""" | ||
return self.center | ||
|
||
@torch.no_grad() | ||
def update(self, x: Tensor) -> None: | ||
"""Update the center with a new batch of features. | ||
Args: | ||
x: | ||
Feature tensor used to update the center. Must have the same number of | ||
dimensions as self.size. | ||
""" | ||
batch_center = self._center_fn(x) | ||
# Use copy for backwards compatibility with DINOLoss. | ||
self.center.copy_(self._center_momentum(batch_center)) | ||
|
||
@torch.no_grad() | ||
def _center_mean(self, x: Tensor) -> Tensor: | ||
"""Returns the center of the input tensor by calculating the mean.""" | ||
batch_center = torch.mean(x, dim=self.dim, keepdim=True) | ||
if dist.is_available() and dist.is_initialized(): | ||
dist.all_reduce(batch_center) | ||
batch_center = batch_center / dist.get_world_size() | ||
return batch_center | ||
|
||
@torch.no_grad() | ||
def _center_momentum(self, batch_center: Tensor) -> Tensor: | ||
"""Returns the new center with momentum update.""" | ||
return self.center * self.momentum + batch_center * (1 - self.momentum) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import pytest | ||
import torch | ||
|
||
from lightly.loss.ibot_loss import IBOTPatchLoss | ||
|
||
|
||
class TestIBOTPatchLoss: | ||
@pytest.mark.parametrize("device", ["cpu", "cuda"]) | ||
def test_forward(self, device: str) -> None: | ||
if not torch.cuda.is_available() and device == "cuda": | ||
pytest.skip("CUDA not available") | ||
|
||
criterion = IBOTPatchLoss( | ||
output_dim=2, | ||
teacher_temp=0.1, | ||
student_temp=0.2, | ||
center_mode="mean", | ||
center_momentum=0.9, | ||
) | ||
teacher_out = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) | ||
student_out = torch.tensor([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]) | ||
mask = torch.tensor( | ||
[ | ||
[[True, False], [True, False]], | ||
[[False, False], [False, True]], | ||
[[False, False], [False, False]], | ||
] | ||
) | ||
|
||
loss = criterion.forward( | ||
teacher_out=teacher_out, student_out=student_out, mask=mask | ||
) | ||
assert loss == pytest.approx(0.4057, rel=0.0001) | ||
expected_center = 0.1 * teacher_out.mean(0) | ||
assert torch.all(torch.isclose(criterion.center.value, expected_center)) | ||
# Loss value was calculated with the original implementation from: | ||
# https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py | ||
# | ||
# Code: | ||
# orig_criterion = iBOTPatchLoss(patch_out_dim=2, student_temp=0.2) | ||
# orig_t_center = orig_criterion.softmax_center_teacher(teacher_out, 0.1) | ||
# orig_loss = orig_criterion.forward_masked( | ||
# student_patch_tokens_masked=student_out, | ||
# teacher_patch_tokens_masked=orig_t_center, | ||
# student_masks_flat=mask.flatten(start_dim=1), | ||
# ) |
Oops, something went wrong.