diff --git a/lightly/loss/__init__.py b/lightly/loss/__init__.py
index e0e2291cf..973459a3f 100644
--- a/lightly/loss/__init__.py
+++ b/lightly/loss/__init__.py
@@ -16,6 +16,7 @@
 from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity
 from lightly.loss.ntx_ent_loss import NTXentLoss
 from lightly.loss.pmsn_loss import PMSNCustomLoss, PMSNLoss
+from lightly.loss.supcon_loss import SupConLoss
 from lightly.loss.swav_loss import SwaVLoss
 from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss
 from lightly.loss.tico_loss import TiCoLoss
diff --git a/lightly/loss/supcon_loss.py b/lightly/loss/supcon_loss.py
new file mode 100644
index 000000000..26d48d892
--- /dev/null
+++ b/lightly/loss/supcon_loss.py
@@ -0,0 +1,168 @@
+""" Contrastive Loss Functions """
+
+# Copyright (c) 2020. Lightly AG and its affiliates.
+# All Rights Reserved
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import distributed as torch_dist
+from torch import nn
+
+from lightly.utils import dist
+
+
+class SupConLoss(nn.Module):
+    """Implementation of the Supervised Contrastive Loss.
+
+    This implementation follows the SupCon[0] paper. If no labels are passed,
+    the loss reduces to the self-supervised NTXent (SimCLR) loss.
+
+    - [0] SupCon, 2020, https://arxiv.org/abs/2004.11362
+
+    Attributes:
+        temperature:
+            Scale logits by the inverse of the temperature.
+        gather_distributed:
+            If True, then negatives from all GPUs are gathered before the
+            loss calculation.
+        rescale:
+            If True, the final loss is rescaled by the temperature for stability.
+
+    Raises:
+        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
+
+    Examples:
+        >>> # initialize loss function
+        >>> loss_fn = SupConLoss()
+        >>>
+        >>> # generate two random transforms of images
+        >>> t0 = transforms(images)
+        >>> t1 = transforms(images)
+        >>>
+        >>> # feed through a model with a projection head
+        >>> out0, out1 = model(t0), model(t1)
+        >>>
+        >>> # calculate loss using the class labels
+        >>> loss = loss_fn(out0, out1, labels)
+
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.5,
+        gather_distributed: bool = False,
+        rescale: bool = True,
+    ):
+        """Initializes the SupConLoss module with the specified parameters.
+
+        Args:
+            temperature:
+                Scale logits by the inverse of the temperature.
+            gather_distributed:
+                If True, negatives from all GPUs are gathered before the loss calculation.
+            rescale:
+                If True, the final loss is rescaled by the temperature for stability.
+
+        Raises:
+            ValueError: If abs(temperature) is less than 1e-8 to prevent divide by zero.
+            ValueError: If gather_distributed is True but torch.distributed is not available.
+        """
+        super().__init__()
+        self.temperature = temperature
+        self.gather_distributed = gather_distributed
+        self.rescale = rescale
+        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
+        self.eps = 1e-8
+
+        if abs(self.temperature) < self.eps:
+            raise ValueError(
+                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
+            )
+        if gather_distributed and not torch_dist.is_available():
+            raise ValueError(
+                "gather_distributed is True but torch.distributed is not available. "
+                "Please set gather_distributed=False or install a torch version with "
+                "distributed support."
+            )
+
+    def forward(
+        self, out0: Tensor, out1: Tensor, labels: Optional[Tensor] = None
+    ) -> Tensor:
+        """Forward pass through the Supervised Contrastive Loss.
+
+        When labels are provided, all views of all samples that share a label are
+        treated as positives for each other. If labels is None, the loss falls back
+        to the self-supervised NTXent loss, where the only positive for a view is
+        the other view of the same image.
+
+        Args:
+            out0:
+                Output projections of the first set of transformed images.
+                Shape: (batch_size, embedding_size)
+            out1:
+                Output projections of the second set of transformed images.
+                Shape: (batch_size, embedding_size)
+            labels:
+                Integer class labels, one per sample. Must be a vector of length
+                `batch_size`. If None, each sample is treated as its own class.
+
+        Returns:
+            Supervised Contrastive Loss value.
+        """
+        # Stack the views for efficient computation.
+        # This allows more views to be added easily.
+        features = (out0, out1)
+        n_views = len(features)
+        out_small = torch.vstack(features)
+
+        device = out_small.device
+        batch_size = out_small.shape[0] // n_views
+
+        # Normalize the output to unit length
+        out_small = nn.functional.normalize(out_small, dim=1)
+
+        # Gather hidden representations from other processes if distributed
+        # and compute the diagonal self-contrast mask
+        if self.gather_distributed and dist.world_size() > 1:
+            out_large = torch.cat(dist.gather(out_small), 0)
+            diag_mask = dist.eye_rank(n_views * batch_size, device=device)
+        else:
+            # Single process
+            out_large = out_small
+            diag_mask = torch.eye(n_views * batch_size, device=device, dtype=torch.bool)
+
+        # Use the dot product as cosine similarity since all vectors are normalized
+        # to unit length, then scale by the inverse of the temperature
+        logits = out_small @ out_large.T
+        logits /= self.temperature
+
+        # Set self-similarities to a large negative value so the softmax ignores them
+        logits[diag_mask] = -1e9
+
+        # Create labels if None: each sample becomes its own class, so the only
+        # positive for a view is the other view of the same image (NTXent behaviour)
+        if labels is None:
+            labels = torch.arange(batch_size, device=device, dtype=torch.long)
+            if self.gather_distributed:
+                labels = labels + dist.rank() * batch_size
+        labels = labels.repeat(n_views)
+
+        # Gather labels from all processes so that they line up with the columns
+        # of the gathered representations in out_large
+        if self.gather_distributed and dist.world_size() > 1:
+            labels_large = torch.cat(dist.gather(labels), 0)
+        else:
+            labels_large = labels
+
+        # Soft labels are 0 unless the logit relates two views of the same class.
+        # Self-similarity (the same view of the same image) is manually set to 0.
+        # Every non-zero entry is 1 / n, where n is the number of positives for that
+        # view: the other views of the same image plus all views of other images
+        # that share its class.
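+        # For example (illustration only): with labels [0, 1, 1] and two views,
+        # the repeated labels are [0, 1, 1, 0, 1, 1]. The row for index 1 then has
+        # positives at columns 2, 4 and 5, so its soft labels are
+        # [0, 0, 1/3, 0, 1/3, 1/3].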
+        soft_labels = torch.eq(labels_large, labels.view(-1, 1)).float()
+        soft_labels[diag_mask] = 0.0
+        soft_labels /= soft_labels.sum(dim=1, keepdim=True)
+
+        # Compute log probabilities
+        log_proba = F.log_softmax(logits, dim=-1)
+
+        # Compute the soft cross-entropy loss
+        loss = (soft_labels * log_proba).sum(-1)
+        loss = -loss.mean()
+
+        # Optional: rescale for stable training
+        if self.rescale:
+            loss *= self.temperature
+
+        return loss
diff --git a/tests/loss/test_supcon_loss.py b/tests/loss/test_supcon_loss.py
new file mode 100644
index 000000000..2d1496fda
--- /dev/null
+++ b/tests/loss/test_supcon_loss.py
@@ -0,0 +1,127 @@
+from typing import List
+
+import pytest
+import torch
+from pytest_mock import MockerFixture
+from torch import Tensor
+from torch import distributed as dist
+from torch import nn
+
+from lightly.loss import NTXentLoss, SupConLoss
+
+
+class TestSupConLoss:
+    temperature = 0.5
+
+    def test__gather_distributed(self, mocker: MockerFixture) -> None:
+        mock_is_available = mocker.patch.object(
+            dist, "is_available", return_value=True
+        )
+        SupConLoss(gather_distributed=True)
+        mock_is_available.assert_called_once()
+
+    def test__gather_distributed_dist_not_available(
+        self, mocker: MockerFixture
+    ) -> None:
+        mock_is_available = mocker.patch.object(
+            dist, "is_available", return_value=False
+        )
+        with pytest.raises(ValueError):
+            SupConLoss(gather_distributed=True)
+        mock_is_available.assert_called_once()
+
+    def test_simple_input(self) -> None:
+        out1 = torch.rand((3, 10))
+        out2 = torch.rand((3, 10))
+        my_label = Tensor([0, 1, 1])
+        my_loss = SupConLoss()
+        my_loss(out1, out2, my_label)
+
+    def test_unsup_equal_to_simclr(self) -> None:
+        supcon = SupConLoss(temperature=self.temperature, rescale=False)
+        ntxent = NTXentLoss(temperature=self.temperature)
+        out1 = torch.rand((8, 10))
+        out2 = torch.rand((8, 10))
+        supcon_loss = supcon(out1, out2)
+        ntxent_loss = ntxent(out1, out2)
+        assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)
+
+    @pytest.mark.parametrize("labels", [[0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 2, 3]])
+    def test_equivalence(self, labels: List[int]) -> None:
+        supcon = SupConLoss(temperature=self.temperature)
+        supcon_reference = SupConLossNonDistributed(temperature=self.temperature)
+        out1 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
+        out2 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
+        test_labels = Tensor(labels)
+
+        loss1 = supcon(out1, out2, test_labels)
+        loss2 = supcon_reference(
+            torch.vstack((out1, out2)), test_labels.view(-1, 1)
+        )
+
+        assert (loss1 - loss2).pow(2).item() == pytest.approx(0.0)
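+
+    # Suggested extra check: with identical inputs, rescale=True should give the
+    # rescale=False loss multiplied by the temperature.
+    def test_rescale_equals_temperature_times_loss(self) -> None:
+        out1 = torch.rand((4, 10))
+        out2 = torch.rand((4, 10))
+        labels = Tensor([0, 1, 1, 2])
+        rescaled = SupConLoss(temperature=self.temperature, rescale=True)
+        unrescaled = SupConLoss(temperature=self.temperature, rescale=False)
+        expected = unrescaled(out1, out2, labels) * self.temperature
+        loss = rescaled(out1, out2, labels)
+        assert (loss - expected).pow(2).item() == pytest.approx(0.0)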
+
+
+class SupConLossNonDistributed(nn.Module):
+    def __init__(
+        self,
+        temperature: float = 0.1,
+    ):
+        """Contrastive learning loss function: SupConLoss and InfoNCE loss.
+        Non-distributed reference implementation by Yutong.
+
+        SupCon from Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+        InfoNCE (NT-Xent) from SimCLR: https://arxiv.org/pdf/2002.05709.pdf.
+
+        Adapted from Yonglong Tian's work at
+        https://github.com/HobbitLong/SupContrast/blob/master/losses.py and
+        https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py.
+
+        The forward pass first creates a contrastive mask of shape
+        [batch_size * n_views, batch_size * n_views], where mask_{i,j} = 1 if sample j
+        has the same class as sample i, except for sample i itself.
+
+        Next, it computes the logits from the features and then the soft cross-entropy loss.
+
+        The loss is rescaled by the temperature parameter.
+
+        For self-supervised learning, the labels should be the indices of the samples;
+        in this case the loss is equivalent to the InfoNCE loss.
+
+        Attributes:
+            temperature (float): A temperature parameter controlling the sharpness of
+                the similarity distribution. Default is 0.1.
+
+        Forward args:
+            features (torch.Tensor): hidden vectors of shape [batch_size * n_views, ...].
+            labels (torch.Tensor): ground truth of shape [batch_size, 1].
+        """
+        super().__init__()
+
+        self.temperature = temperature
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        labels: torch.Tensor,
+    ) -> torch.Tensor:
+        # create the n-view mask
+        labels_n_views = labels.contiguous().repeat(
+            features.shape[0] // labels.shape[0], 1
+        )  # [batch_size * n_views, 1]
+        contrastive_mask_n_views = torch.eq(
+            labels_n_views, labels_n_views.T
+        ).float()  # [batch_size * n_views, batch_size * n_views]
+        contrastive_mask_n_views.fill_diagonal_(0)  # mask out self-contrast cases
+
+        # compute logits
+        logits = (
+            torch.matmul(features, features.T) / self.temperature
+        )  # [batch_size * n_views, batch_size * n_views]
+        logits.fill_diagonal_(-1e9)  # suppress the logits for self-contrast cases
+
+        # compute soft labels and log probabilities
+        soft_label = contrastive_mask_n_views / contrastive_mask_n_views.sum(dim=1)
+        log_proba = nn.functional.log_softmax(logits, dim=-1)
+
+        # compute the soft cross-entropy loss
+        loss_all = torch.sum(soft_label * log_proba, dim=-1)
+        loss = -loss_all.mean()
+
+        # rescale for stable training
+        loss *= self.temperature
+
+        return loss
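
A minimal end-to-end usage sketch (not part of the diff above): the tiny model, random data,
and hyperparameters are stand-ins; only the SupConLoss call reflects the new loss.

    import torch
    from torch import nn

    from lightly.loss import SupConLoss

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
    criterion = SupConLoss(temperature=0.1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # two "augmented" views of the same batch and one integer class label per image
    view0, view1 = torch.randn(8, 32), torch.randn(8, 32)
    labels = torch.randint(0, 3, (8,))

    loss = criterion(model(view0), model(view1), labels)  # omit labels for the NTXent fallback
    loss.backward()
    optimizer.step()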