Start on SupCon loss #1554 #1877
New file (+168 lines):

""" Contrastive Loss Functions """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import distributed as torch_dist
from torch import nn

from lightly.utils import dist


class SupConLoss(nn.Module):
    """Implementation of the Supervised Contrastive Loss.

    This implementation follows the SupCon[0] paper.

    - [0] SupCon, 2020, https://arxiv.org/abs/2004.11362

    Attributes:
        temperature:
            Scale logits by the inverse of the temperature.
        gather_distributed:
            If True then negatives from all GPUs are gathered before the
            loss calculation. If a memory bank is used and gather_distributed is True,
            then tensors from all GPUs are gathered before the memory bank is updated.
        rescale:
            Optionally rescale final loss by the temperature for stability.

    Raises:
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.

    Examples:
        >>> # initialize loss function without memory bank
        >>> loss_fn = NTXentLoss(memory_bank_size=0)
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed through SimCLR or MoCo model
        >>> out0, out1 = model(t0), model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(out0, out1)

    """

    def __init__(
        self,
        temperature: float = 0.5,
Review comment: The default temperature in the original paper is 0.1.
        gather_distributed: bool = False,
        rescale: bool = True,
    ):
        """Initializes the SupConLoss module with the specified parameters.

        Args:
            temperature:
                Scale logits by the inverse of the temperature.
            gather_distributed:
                If True, negatives from all GPUs are gathered before the loss calculation.
            rescale:
                Optionally rescale final loss by the temperature for stability.

        Raises:
            ValueError: If temperature is less than 1e-8 to prevent divide by zero.
            ValueError: If gather_distributed is True but torch.distributed is not available.
        """
        super().__init__()
        self.temperature = temperature
        self.gather_distributed = gather_distributed
        self.rescale = rescale
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

        if abs(self.temperature) < self.eps:
            raise ValueError(
                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
            )
        if gather_distributed and not torch_dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )

    def forward(
        self, out0: Tensor, out1: Tensor, labels: Optional[Tensor] = None
    ) -> Tensor:
        """Forward pass through Supervised Contrastive Loss.

        Computes the loss based on contrast_mode setting.
Review comment: There is no contrast mode setting anymore.
        Args:
            out0:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)
            labels:
                Class labels for each sample. Must be a vector of length `batch_size`.

        Returns:
            Supervised Contrastive Loss value.
        """
        # Stack the views for efficient computation
        # Allows for more views to be added easily
        features = (out0, out1)
        n_views = len(features)
        out_small = torch.vstack(features)

        device = out_small.device
        batch_size = out_small.shape[0] // n_views

        # Normalize the output to length 1
        out_small = nn.functional.normalize(out_small, dim=1)

        # Gather hidden representations from other processes if distributed
        # and compute the diagonal self-contrast mask
        if self.gather_distributed and dist.world_size() > 1:
            out_large = torch.cat(dist.gather(out_small), 0)
            diag_mask = dist.eye_rank(n_views * batch_size, device=device)
        else:
            # Single process
            out_large = out_small
            diag_mask = torch.eye(n_views * batch_size, device=device, dtype=torch.bool)

        # Use cosine similarity (dot product) as all vectors are normalized to unit length
        # Calculate similarities
        logits = out_small @ out_large.T
        logits /= self.temperature

        # Set self-similarities to an infinitely small value
        logits[diag_mask] = -1e9

        # Create labels if None
        if labels is None:
            labels = torch.arange(batch_size, device=device, dtype=torch.long)
            if self.gather_distributed:
                labels = labels + dist.rank() * batch_size
        labels = labels.repeat(n_views)

        # Soft labels are 0 unless the logit represents a similarity
        # between two of the same classes. We manually set self-similarity
        # (same view of the same item) to 0. When not 0, the value is
        # 1 / n, where n is the number of positive samples
        # (different views of the same item, and all views of other items sharing
        # classes with the item)
        soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
        soft_labels.fill_diagonal_(0.0)
        soft_labels /= soft_labels.sum(dim=1)

        # Compute log probabilities
        log_proba = F.log_softmax(logits, dim=-1)

        # Compute soft cross-entropy loss
        loss = (soft_labels * log_proba).sum(-1)
        loss = -loss.mean()

        # Optional: rescale for stable training
        if self.rescale:
            loss *= self.temperature

        return loss
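Since the docstring example above still uses NTXentLoss, here is a hedged sketch of how the new loss would be called with class labels; the embedding size is arbitrary, and the `lightly.loss` export of `SupConLoss` is assumed from the tests further below.

```python
import torch

from lightly.loss import SupConLoss  # export assumed, as in the tests below

# embeddings of two augmented views of the same batch of 8 images
out0 = torch.randn(8, 128)
out1 = torch.randn(8, 128)
labels = torch.randint(0, 4, (8,))  # one class index per image

loss_fn = SupConLoss(temperature=0.1)

# supervised: all views of images that share a class count as positives
loss = loss_fn(out0, out1, labels)

# without labels, each image is only positive with its other view (unsupervised case)
unsupervised_loss = loss_fn(out0, out1)
```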
New file (+127 lines):
from typing import List

import pytest
import torch
from pytest_mock import MockerFixture
from torch import Tensor
from torch import distributed as dist
from torch import nn

from lightly.loss import NTXentLoss, SupConLoss


class TestSupConLoss:
    temperature = 0.5

    def test__gather_distributed(self, mocker: MockerFixture) -> None:
        mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
        SupConLoss(gather_distributed=True)
        mock_is_available.assert_called_once()

    def test__gather_distributed_dist_not_available(
        self, mocker: MockerFixture
    ) -> None:
        mock_is_available = mocker.patch.object(
            dist, "is_available", return_value=False
        )
        with pytest.raises(ValueError):
            SupConLoss(gather_distributed=True)
        mock_is_available.assert_called_once()

    def test_simple_input(self) -> None:
        out1 = torch.rand((3, 10))
        out2 = torch.rand((3, 10))
        my_label = Tensor([0, 1, 1])
        my_loss = SupConLoss()
        my_loss(out1, out2, my_label)

    def test_unsup_equal_to_simclr(self) -> None:
        supcon = SupConLoss(temperature=self.temperature, rescale=False)
        ntxent = NTXentLoss(temperature=self.temperature)
        out1 = torch.rand((8, 10))
        out2 = torch.rand((8, 10))
        supcon_loss = supcon(out1, out2)
        ntxent_loss = ntxent(out1, out2)
        assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)
Review comment (on test_equivalence below): Here you shouldn't compare to my implementation. Your code is basically a rewritten version of mine that is not legitimately tested (even if it is correct). I would instead compare against plain code rewritten in NumPy, like done here: lightly/tests/loss/test_ntx_ent_loss.py, line 151 in ee30cd4.
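For illustration, a plain-NumPy reference along the lines the reviewer suggests might look like the sketch below. The function name `supcon_loss_reference`, its signature, and the per-row 1/|positives| normalization are assumptions based on the SupCon paper; the code is not taken from lightly's test suite or from this PR.

```python
import numpy as np


def supcon_loss_reference(out0, out1, labels, temperature=0.5, rescale=True):
    """SupCon loss on two views, computed with NumPy only (test reference sketch)."""
    features = np.vstack([out0, out1])  # (2B, D): view 0 stacked on view 1
    features = features / np.linalg.norm(features, axis=1, keepdims=True)
    labels = np.concatenate([labels, labels])  # labels repeated per view

    # cosine similarities scaled by the temperature; suppress self-similarity
    logits = features @ features.T / temperature
    np.fill_diagonal(logits, -1e9)

    # positives share a label and are not the sample itself; weight each by 1/|positives|
    mask = (labels[:, None] == labels[None, :]).astype(float)
    np.fill_diagonal(mask, 0.0)
    soft_labels = mask / mask.sum(axis=1, keepdims=True)

    # soft cross-entropy between the soft labels and the log-softmax of the logits
    log_proba = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    loss = -(soft_labels * log_proba).sum(axis=1).mean()
    return loss * temperature if rescale else loss
```

A parametrized test could then compare the output of `SupConLoss` against this function with `pytest.approx`, instead of against `SupConLossNonDistributed`.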
    @pytest.mark.parametrize("labels", [[0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 2, 3]])
    def test_equivalence(self, labels: List[int]) -> None:
        DistributedSupCon = SupConLoss(temperature=self.temperature)
        NonDistributedSupCon = SupConLossNonDistributed(temperature=self.temperature)
        out1 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
        out2 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
        test_labels = Tensor(labels)

        loss1 = DistributedSupCon(out1, out2, test_labels)
        loss2 = NonDistributedSupCon(
            torch.vstack((out1, out2)), test_labels.view(-1, 1)
        )

        assert (loss1 - loss2).pow(2).item() == pytest.approx(0.0)


class SupConLossNonDistributed(nn.Module):
    def __init__(
        self,
        temperature: float = 0.1,
    ):
        """Contrastive Learning Loss Function: SupConLoss and InfoNCE Loss. Non-distributed version by Yutong.

        SupCon from Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
        InfoNCE (NT-Xent) from SimCLR: https://arxiv.org/pdf/2002.05709.pdf.

        Adapted from Yonglong Tian's work at https://github.com/HobbitLong/SupContrast/blob/master/losses.py and
        https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py.

        The function first creates a contrastive mask of shape [batch_size * n_views, batch_size * n_views], where
        mask_{i,j}=1 if sample j has the same class as sample i, except for the sample i itself.

        Next, it computes the logits from the features and then computes the soft cross-entropy loss.

        The loss is rescaled by the temperature parameter.

        For self-supervised learning, the labels should be the indices of the samples. In this case it is equivalent to InfoNCE loss.

        Attributes:
            - temperature (float): A temperature parameter to control the similarity. Default is 0.1.

        Args:
            - features (torch.Tensor): hidden vector of shape [batch_size * n_views, ...].
            - labels (torch.Tensor): ground truth of shape [batch_size, 1].
        """
        super().__init__()

        self.temperature = temperature

    def forward(
        self,
        features: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        # create n-viewed mask
        labels_n_views = labels.contiguous().repeat(
            features.shape[0] // labels.shape[0], 1
        )  # [batch_size * n_views, 1]
        contrastive_mask_n_views = torch.eq(
            labels_n_views, labels_n_views.T
        ).float()  # [batch_size * n_views, batch_size * n_views]
        contrastive_mask_n_views.fill_diagonal_(0)  # mask out self-contrast cases

        # compute logits
        logits = (
            torch.matmul(features, features.T) / self.temperature
        )  # [batch_size * n_views, batch_size * n_views]
        logits.fill_diagonal_(-1e9)  # suppress logit for self-contrast cases

        # compute log probabilities and soft labels
        soft_label = contrastive_mask_n_views / contrastive_mask_n_views.sum(dim=1)
        log_proba = nn.functional.log_softmax(logits, dim=-1)

        # compute soft cross-entropy loss
        loss_all = torch.sum(soft_label * log_proba, dim=-1)
        loss = -loss_all.mean()

        # rescale for stable training
        loss *= self.temperature

        return loss
Review comment: Please fix the docstring with the example for SupConLoss.
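A corrected docstring example might look like the sketch below, keeping the placeholder `transforms`, `images`, and `model` names from the original docstring and adding the labels argument (a sketch, not the author's wording):

```python
>>> # initialize the supervised contrastive loss
>>> loss_fn = SupConLoss(temperature=0.1)
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through a model with a projection head
>>> out0, out1 = model(t0), model(t1)
>>>
>>> # calculate the loss using the class labels of the images
>>> # (omit labels for the unsupervised case)
>>> loss = loss_fn(out0, out1, labels)
```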