Add IBOTPatchLoss (#1616)
* Add IBOTPatchLoss
* Add Center
* Refactor DINOLoss to use Center
guarin authored Jul 31, 2024
1 parent a06a954 commit 4879525
Showing 7 changed files with 321 additions and 26 deletions.
3 changes: 3 additions & 0 deletions docs/source/lightly.loss.rst
@@ -20,6 +20,9 @@ lightly.loss
.. autoclass:: lightly.loss.hypersphere_loss.HypersphereLoss
:members:

.. autoclass:: lightly.loss.ibot_loss.IBOTPatchLoss
:members:

.. autoclass:: lightly.loss.koleo_loss.KoLeoLoss
:members:

1 change: 1 addition & 0 deletions lightly/loss/__init__.py
@@ -6,6 +6,7 @@
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.dino_loss import DINOLoss
from lightly.loss.emp_ssl_loss import EMPSSLLoss
from lightly.loss.ibot_loss import IBOTPatchLoss
from lightly.loss.koleo_loss import KoLeoLoss
from lightly.loss.mmcr_loss import MMCRLoss
from lightly.loss.msn_loss import MSNLoss
67 changes: 41 additions & 26 deletions lightly/loss/dino_loss.py
@@ -1,12 +1,14 @@
from typing import List

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

from lightly.models.modules.center import Center

class DINOLoss(nn.Module):

class DINOLoss(Module):
"""
Implementation of the loss described in 'Emerging Properties in
Self-Supervised Vision Transformers'. [0]
@@ -61,14 +63,21 @@ def __init__(
warmup_teacher_temp_epochs: int = 30,
student_temp: float = 0.1,
center_momentum: float = 0.9,
center_mode: str = "mean",
):
super().__init__()
self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
self.teacher_temp = teacher_temp
self.student_temp = student_temp
self.center_momentum = center_momentum

self.register_buffer("center", torch.zeros(1, 1, output_dim))
self._center = Center(
size=(1, 1, output_dim),
mode=center_mode,
momentum=center_momentum,
_register_buffer=False,
)
self.register_buffer("center", self._center.center)

# We apply a warmup to the teacher temperature because too high a temperature
# makes training unstable at the beginning.
self.teacher_temp_schedule = torch.linspace(
@@ -77,26 +86,39 @@ def __init__(
steps=warmup_teacher_temp_epochs,
)

# Center momentum is exposed as a property for backwards compatibility, as it
# used to be stored as a plain attribute.
@property
def center_momentum(self) -> float:
return self._center.momentum

@center_momentum.setter
def center_momentum(self, value: float) -> None:
self._center.momentum = value

def forward(
self,
teacher_out: List[torch.Tensor],
student_out: List[torch.Tensor],
teacher_out: List[Tensor],
student_out: List[Tensor],
epoch: int,
) -> torch.Tensor:
) -> Tensor:
"""Cross-entropy between softmax outputs of the teacher and student
networks.
Args:
teacher_out:
List of view feature tensors from the teacher model. Each
tensor is assumed to contain features from one view of the batch
and have length batch_size.
List of tensors with shape (batch_size, output_dim) containing features
from the teacher model. Each tensor must represent one view of the
batch.
student_out:
List of view feature tensors from the student model. Each tensor
is assumed to contain features from one view of the batch and
have length batch_size.
List of tensors with shape (batch_size, output_dim) containing features
from the student model. Each tensor must represent one view of the
batch.
epoch:
The current training epoch.
Returns:
The average cross-entropy loss.
@@ -109,7 +131,7 @@ def forward(
teacher_temp = self.teacher_temp

teacher_out = torch.stack(teacher_out)
t_out = F.softmax((teacher_out - self.center) / teacher_temp, dim=-1)
t_out = F.softmax((teacher_out - self._center.value) / teacher_temp, dim=-1)

student_out = torch.stack(student_out)
s_out = F.log_softmax(student_out / self.student_temp, dim=-1)
@@ -129,20 +151,13 @@ def forward(
return loss

@torch.no_grad()
def update_center(self, teacher_out: torch.Tensor) -> None:
def update_center(self, teacher_out: Tensor) -> None:
"""Moving average update of the center used for the teacher output.
Args:
teacher_out:
Stacked output from the teacher model.
Tensor with shape (num_views, batch_size, output_dim) containing
features from the teacher model.
"""
batch_center = torch.mean(teacher_out, dim=(0, 1), keepdim=True)
if dist.is_available() and dist.is_initialized():
dist.all_reduce(batch_center)
batch_center = batch_center / dist.get_world_size()

# ema update
self.center = self.center * self.center_momentum + batch_center * (
1 - self.center_momentum
)
self._center.update(teacher_out)
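
A minimal usage sketch of the refactored loss (the view count, batch size, and output dimension below are illustrative assumptions, not part of the commit):

import torch

from lightly.loss import DINOLoss

criterion = DINOLoss(output_dim=16, center_mode="mean")

# One tensor of shape (batch_size, output_dim) per view.
teacher_out = [torch.randn(4, 16) for _ in range(2)]
student_out = [torch.randn(4, 16, requires_grad=True) for _ in range(2)]

loss = criterion(teacher_out, student_out, epoch=0)
loss.backward()

Passing center_mode explicitly is optional here; "mean" is the default and currently the only supported mode.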
96 changes: 96 additions & 0 deletions lightly/loss/ibot_loss.py
@@ -0,0 +1,96 @@
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import functional as F

from lightly.models.modules.center import Center


class IBOTPatchLoss(Module):
"""Implementation of the iBOT patch loss [0] as used in DINOv2 [1].
The implementation is based on [2].
- [0]: iBOT, 2021, https://arxiv.org/abs/2111.07832
- [1]: DINOv2, 2023, https://arxiv.org/abs/2304.07193
- [2]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py
Attributes:
output_dim:
Dimension of the model output.
teacher_temp:
Temperature for the teacher output.
student_temp:
Temperature for the student output.
center_mode:
Mode for center calculation. Only 'mean' is supported.
center_momentum:
Momentum term for the center update.
"""

def __init__(
self,
output_dim: int,
teacher_temp: float = 0.04,
student_temp: float = 0.1,
center_mode: str = "mean",
center_momentum: float = 0.9,
) -> None:
super().__init__()
self.teacher_temp = teacher_temp
self.student_temperature = student_temp
self.center = Center(
size=(1, output_dim),
mode=center_mode,
momentum=center_momentum,
)

def forward(
self,
teacher_out: Tensor,
student_out: Tensor,
mask: Tensor,
) -> Tensor:
"""Forward pass through the iBOT patch loss.
Args:
teacher_out:
Tensor with shape (batch_size * sequence_length, embed_dim) containing
the teacher output of the masked tokens.
student_out:
Tensor with shape (batch_size * sequence_length, embed_dim) containing
the student output of the masked tokens.
mask:
Boolean tensor with shape (batch_size, height, width) containing the
token mask. Exactly batch_size * sequence_length entries must be set to
True in the mask.
Returns:
Loss value.
"""
# B = batch size, N = sequence length = number of masked tokens, D = embed dim
# H = height (in tokens), W = width (in tokens)
# Note that N <= H * W depending on how many tokens are masked.

# Calculate cross entropy loss.
teacher_softmax = F.softmax(
(teacher_out - self.center.value) / self.teacher_temp, dim=-1
)
student_log_softmax = F.log_softmax(
student_out / self.student_temperature, dim=-1
)
# (B * N, D) -> (B * N)
loss = -torch.sum(teacher_softmax * student_log_softmax, dim=-1)

# Get weights.
# (B, H, W) -> (B, 1, 1)
num_masked_per_image = mask.sum(dim=(1, 2), keepdim=True).clamp(min=1.0)
# (B, 1, 1) -> (B, H, W) -> (B * N)
weight = (1.0 / num_masked_per_image).expand_as(mask)[mask]

# Apply weighting.
B = mask.shape[0]
loss = (loss * weight).sum() / B

self.center.update(teacher_out)
return loss
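
A minimal usage sketch for the new loss; the token grid, mask pattern, and output dimension are illustrative assumptions:

import torch

from lightly.loss import IBOTPatchLoss

criterion = IBOTPatchLoss(output_dim=8)

# 2 images with a 2x2 token grid; 3 tokens are masked in total.
mask = torch.zeros(2, 2, 2, dtype=torch.bool)
mask[0, 0, 0] = mask[0, 1, 1] = mask[1, 0, 1] = True

# Head outputs for the masked tokens only, one row per True entry in the mask.
teacher_out = torch.randn(3, 8)
student_out = torch.randn(3, 8, requires_grad=True)

loss = criterion(teacher_out, student_out, mask)
loss.backward()

Note that the per-image weighting divides each token's loss by the number of masked tokens in its image, so every image contributes equally regardless of how many of its tokens are masked.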
93 changes: 93 additions & 0 deletions lightly/models/modules/center.py
@@ -0,0 +1,93 @@
from typing import Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module


class Center(Module):
"""Center module to compute and store the center of a feature tensor as used
in DINO [0].
- [0]: DINO, 2021, https://arxiv.org/abs/2104.14294
Attributes:
size:
Size of the tracked center tensor. Dimensions across which the center
is computed must be set to 1. For example, if the feature tensor has shape
(batch_size, sequence_length, feature_dim) and the center should be computed
across the batch and sequence dimensions, the size should be
(1, 1, feature_dim).
mode:
Mode to compute the center. Currently only 'mean' is supported.
momentum:
Momentum term for the center calculation.
_register_buffer:
Deprecated, do not use. This argument is only kept for backwards
compatibility with DINOLoss.
"""

def __init__(
self,
size: Tuple[int, ...],
mode: str = "mean",
momentum: float = 0.9,
_register_buffer: bool = True,
) -> None:
super().__init__()

mode_to_fn = {
"mean": self._center_mean,
}
if mode not in mode_to_fn:
raise ValueError(
f"Unknown mode '{mode}'. Valid modes are {sorted(mode_to_fn.keys())}."
)
self._center_fn = mode_to_fn[mode]
self.size = size
self.dim = tuple(i for i, s in enumerate(size) if s == 1)

center = torch.zeros(self.size)
if _register_buffer:
self.register_buffer("center", center)
else:
# Do not register the buffer for backwards compatibility with DINOLoss, as the
# loss already registers it. If we registered it here again, there would be an
# extra entry in the state dict.
self.center = center

self.momentum = momentum

@property
def value(self) -> Tensor:
"""The current value of the center. Use this property to do any operations based
on the center."""
return self.center

@torch.no_grad()
def update(self, x: Tensor) -> None:
"""Update the center with a new batch of features.
Args:
x:
Feature tensor used to update the center. Must have the same number of
dimensions as self.size.
"""
batch_center = self._center_fn(x)
# Use copy for backwards compatibility with DINOLoss.
self.center.copy_(self._center_momentum(batch_center))

@torch.no_grad()
def _center_mean(self, x: Tensor) -> Tensor:
"""Returns the center of the input tensor by calculating the mean."""
batch_center = torch.mean(x, dim=self.dim, keepdim=True)
if dist.is_available() and dist.is_initialized():
dist.all_reduce(batch_center)
batch_center = batch_center / dist.get_world_size()
return batch_center

@torch.no_grad()
def _center_momentum(self, batch_center: Tensor) -> Tensor:
"""Returns the new center with momentum update."""
return self.center * self.momentum + batch_center * (1 - self.momentum)
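
Center can also be used on its own. A minimal sketch, assuming (batch_size, feature_dim) features with an illustrative feature dimension of 8:

import torch

from lightly.models.modules.center import Center

center = Center(size=(1, 8), mode="mean", momentum=0.9)

x = torch.randn(32, 8)
centered = x - center.value  # The initial center is all zeros.
center.update(x)             # EMA update: 0.9 * old center + 0.1 * batch mean.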
46 changes: 46 additions & 0 deletions tests/loss/test_ibot_loss.py
@@ -0,0 +1,46 @@
import pytest
import torch

from lightly.loss.ibot_loss import IBOTPatchLoss


class TestIBOTPatchLoss:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_forward(self, device: str) -> None:
if not torch.cuda.is_available() and device == "cuda":
pytest.skip("CUDA not available")

criterion = IBOTPatchLoss(
output_dim=2,
teacher_temp=0.1,
student_temp=0.2,
center_mode="mean",
center_momentum=0.9,
)
teacher_out = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
student_out = torch.tensor([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]])
mask = torch.tensor(
[
[[True, False], [True, False]],
[[False, False], [False, True]],
[[False, False], [False, False]],
]
)

loss = criterion.forward(
teacher_out=teacher_out, student_out=student_out, mask=mask
)
assert loss == pytest.approx(0.4057, rel=0.0001)
expected_center = 0.1 * teacher_out.mean(0)
assert torch.all(torch.isclose(criterion.center.value, expected_center))
# Loss value was calculated with the original implementation from:
# https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py
#
# Code:
# orig_criterion = iBOTPatchLoss(patch_out_dim=2, student_temp=0.2)
# orig_t_center = orig_criterion.softmax_center_teacher(teacher_out, 0.1)
# orig_loss = orig_criterion.forward_masked(
# student_patch_tokens_masked=student_out,
# teacher_patch_tokens_masked=orig_t_center,
# student_masks_flat=mask.flatten(start_dim=1),
# )
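
For reference, the expected center in this test is a single momentum update of the initially zero center, matching the 0.1 * teacher_out.mean(0) asserted above:

# new_center = momentum * old_center + (1 - momentum) * batch_mean
#            = 0.9 * 0 + 0.1 * teacher_out.mean(dim=0)
#            = 0.1 * [0.3, 0.4] = [0.03, 0.04]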