1 change: 1 addition & 0 deletions lightly/loss/__init__.py
@@ -16,6 +16,7 @@
from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity
from lightly.loss.ntx_ent_loss import NTXentLoss
from lightly.loss.pmsn_loss import PMSNCustomLoss, PMSNLoss
from lightly.loss.supcon_loss import SupConLoss
from lightly.loss.swav_loss import SwaVLoss
from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss
from lightly.loss.tico_loss import TiCoLoss
168 changes: 168 additions & 0 deletions lightly/loss/supcon_loss.py
@@ -0,0 +1,168 @@
""" Contrastive Loss Functions """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import distributed as torch_dist
from torch import nn

from lightly.utils import dist


class SupConLoss(nn.Module):
"""Implementation of the Supervised Contrastive Loss.

This implementation follows the SupCon paper [0].

- [0] SupCon, 2020, https://arxiv.org/abs/2004.11362

Attributes:
temperature:
Scale logits by the inverse of the temperature.
gather_distributed:
If True, negatives from all GPUs are gathered before the
loss calculation.
rescale:
Optionally rescale the final loss by the temperature for stability.

Raises:
ValueError: If abs(temperature) < 1e-8, to prevent division by zero.

Examples:
>>> # initialize the loss function
>>> loss_fn = SupConLoss()
Contributor comment: Please fix the docstring with the example for SupConLoss.

>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through SimCLR or MoCo model
>>> out0, out1 = model(t0), model(t1)
>>>
>>> # calculate loss (optionally pass per-sample class labels as a third argument)
>>> loss = loss_fn(out0, out1)

"""

def __init__(
self,
temperature: float = 0.5,
Contributor comment: The default temperature in the original paper is 0.1.
gather_distributed: bool = False,
rescale: bool = True,
):
"""Initializes the SupConLoss module with the specified parameters.

Args:
temperature:
Scale logits by the inverse of the temperature.
gather_distributed:
If True, negatives from all GPUs are gathered before the loss calculation.
rescale:
Optionally rescale final loss by the temperature for stability.

Raises:
ValueError: If abs(temperature) is smaller than 1e-8, to prevent division by zero.
ValueError: If gather_distributed is True but torch.distributed is not available.
"""
super().__init__()
self.temperature = temperature
self.gather_distributed = gather_distributed
self.rescale = rescale
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
self.eps = 1e-8

if abs(self.temperature) < self.eps:
raise ValueError(
"Illegal temperature: abs({}) < 1e-8".format(self.temperature)
)
if gather_distributed and not torch_dist.is_available():
raise ValueError(
"gather_distributed is True but torch.distributed is not available. "
"Please set gather_distributed=False or install a torch version with "
"distributed support."
)

def forward(
self, out0: Tensor, out1: Tensor, labels: Optional[Tensor] = None
) -> Tensor:
"""Forward pass through Supervised Contrastive Loss.

If no labels are passed, every sample is treated as its own class and the
loss reduces to the unsupervised NT-Xent (SimCLR) objective.
Contributor comment: There is no contrast mode setting anymore.

Args:
out0:
Output projections of the first set of transformed images.
Shape: (batch_size, embedding_size)
out1:
Output projections of the second set of transformed images.
Shape: (batch_size, embedding_size)
labels:
Class labels for each sample. Must be a vector of length `batch_size`.

Returns:
Supervised Contrastive Loss value.
"""
# Stack the views for efficient computation
# Allows for more views to be added easily
features = (out0, out1)
n_views = len(features)
out_small = torch.vstack(features)

device = out_small.device
batch_size = out_small.shape[0] // n_views

# Normalize the output to length 1
out_small = nn.functional.normalize(out_small, dim=1)

# Gather hidden representations from other processes if distributed
# and compute the diagonal self-contrast mask
if self.gather_distributed and dist.world_size() > 1:
out_large = torch.cat(dist.gather(out_small), 0)
diag_mask = dist.eye_rank(n_views * batch_size, device=device)
else:
# Single process
out_large = out_small
diag_mask = torch.eye(n_views * batch_size, device=device, dtype=torch.bool)

# Calculate similarities; the dot product equals cosine similarity
# because all vectors are normalized to unit length
logits = out_small @ out_large.T
logits /= self.temperature

# Set self-similarities to a large negative value so they vanish in the softmax
logits[diag_mask] = -1e9

# Create labels if None
if labels is None:
labels = torch.arange(batch_size, device=device, dtype=torch.long)
if self.gather_distributed:
labels = labels + dist.rank() * batch_size
labels = labels.repeat(n_views)

# Soft labels are 0 unless the logit represents a similarity between two
# samples of the same class. Self-similarity (the same view of the same
# item) is manually set to 0. Non-zero entries equal 1 / n, where n is the
# number of positives for that item (its other views, plus all views of
# other items that share its class).
soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
soft_labels.fill_diagonal_(0.0)
soft_labels /= soft_labels.sum(dim=1)

# Compute log probabilities
log_proba = F.log_softmax(logits, dim=-1)

# Compute soft cross-entropy loss
loss = (soft_labels * log_proba).sum(-1)
loss = -loss.mean()
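# Written out, for anchor i with positive set P(i) and temperature t this is
#   loss_i = -(1 / |P(i)|) * sum_{p in P(i)} log( exp(z_i . z_p / t) / sum_{a != i} exp(z_i . z_a / t) )
# averaged over all anchors, i.e. the "L_out" formulation of the SupCon paper.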

# Optional: rescale for stable training
if self.rescale:
loss *= self.temperature

return loss
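
For reference, a minimal usage sketch of the new loss (illustration only, not part of the diff; the random tensors stand in for projection-head outputs):

import torch

from lightly.loss import SupConLoss

criterion = SupConLoss(temperature=0.1)

# Two augmented views of the same batch of 8 images, already projected
# to embeddings, plus one integer class label per image.
view0 = torch.randn(8, 128)
view1 = torch.randn(8, 128)
labels = torch.randint(0, 3, (8,))

# Supervised variant: positives are the other view of each image and all
# views of images that share its class label.
loss_supervised = criterion(view0, view1, labels)

# Without labels, every image is its own class and the objective matches
# NT-Xent / SimCLR (up to the optional rescaling by the temperature).
loss_unsupervised = criterion(view0, view1)
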
127 changes: 127 additions & 0 deletions tests/loss/test_supcon_loss.py
@@ -0,0 +1,127 @@
from typing import List

import pytest
import torch
from pytest_mock import MockerFixture
from torch import Tensor
from torch import distributed as dist
from torch import nn

from lightly.loss import NTXentLoss, SupConLoss


class TestSupConLoss:
temperature = 0.5

def test__gather_distributed(self, mocker: MockerFixture) -> None:
mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
SupConLoss(gather_distributed=True)
mock_is_available.assert_called_once()

def test__gather_distributed_dist_not_available(
self, mocker: MockerFixture
) -> None:
mock_is_available = mocker.patch.object(
dist, "is_available", return_value=False
)
with pytest.raises(ValueError):
SupConLoss(gather_distributed=True)
mock_is_available.assert_called_once()

def test_simple_input(self) -> None:
out1 = torch.rand((3, 10))
out2 = torch.rand((3, 10))
my_label = Tensor([0, 1, 1])
my_loss = SupConLoss()
my_loss(out1, out2, my_label)

def test_unsup_equal_to_simclr(self) -> None:
supcon = SupConLoss(temperature=self.temperature, rescale=False)
ntxent = NTXentLoss(temperature=self.temperature)
out1 = torch.rand((8, 10))
out2 = torch.rand((8, 10))
supcon_loss = supcon(out1, out2)
ntxent_loss = ntxent(out1, out2)
assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)

@pytest.mark.parametrize("labels", [[0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 2, 3]])
Contributor comment:
Here you shouldn't compare to my implementation. Your code is basically a rewritten version of mine that is not legitimately tested (even if it is correct). I would instead compare against plain code rewritten in NumPy like done here:

def _calc_ntxent_loss_manual(

def test_equivalence(self, labels: List[int]) -> None:
DistributedSupCon = SupConLoss(temperature=self.temperature)
NonDistributedSupCon = SupConLossNonDistributed(temperature=self.temperature)
out1 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
out2 = nn.functional.normalize(torch.rand(4, 10), dim=-1)
test_labels = Tensor(labels)

loss1 = DistributedSupCon(out1, out2, test_labels)
loss2 = NonDistributedSupCon(
torch.vstack((out1, out2)), test_labels.view(-1, 1)
)

assert (loss1 - loss2).pow(2).item() == pytest.approx(0.0)


class SupConLossNonDistributed(nn.Module):
def __init__(
self,
temperature: float = 0.1,
):
"""Contrastive Learning Loss Function: SupConLoss and InfoNCE Loss. Non-distributed version by Yutong.

SupCon from Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
InfoNCE (NT-Xent) from SimCLR: https://arxiv.org/pdf/2002.05709.pdf.

Adapted from Yonglong Tian's work at https://github.com/HobbitLong/SupContrast/blob/master/losses.py and
https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py.

The function first creates a contrastive mask of shape [batch_size * n_views, batch_size * n_views], where
mask_{i,j}=1 if sample j has the same class as sample i, except for the sample i itself.

Next, it computes the logits from the features and then computes the soft cross-entropy loss.

The loss is rescaled by the temperature parameter.

For self-supervised learning, the labels should be the indices of the samples. In this case it is equivalent to InfoNCE loss.

Attributes:
- temperature (float): A temperature parameter to control the similarity. Default is 0.1.

Args:
- features (torch.Tensor): hidden vector of shape [batch_size * n_views, ...].
- labels (torch.Tensor): ground truth of shape [batch_size, 1].
"""
super().__init__()

self.temperature = temperature

def forward(
self,
features: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
# create n-viewed mask
labels_n_views = labels.contiguous().repeat(
features.shape[0] // labels.shape[0], 1
) # [batch_size * n_views, 1]
contrastive_mask_n_views = torch.eq(
labels_n_views, labels_n_views.T
).float() # [batch_size * n_views, batch_size * n_views]
contrastive_mask_n_views.fill_diagonal_(0) # mask-out self-contrast cases

# compute logits
logits = (
torch.matmul(features, features.T) / self.temperature
) # [batch_size * n_views, batch_size * n_views]
logits.fill_diagonal_(-1e9) # suppress logit for self-contrast cases

# compute log probabilities and soft labels
soft_label = contrastive_mask_n_views / contrastive_mask_n_views.sum(dim=1)
log_proba = nn.functional.log_softmax(logits, dim=-1)

# compute soft cross-entropy loss
loss_all = torch.sum(soft_label * log_proba, dim=-1)
loss = -loss_all.mean()

# rescale for stable training
loss *= self.temperature

return loss
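
As the reviewer suggests above, the equivalence test could compare against a plain re-computation of the loss instead of a second PyTorch implementation. A possible sketch in NumPy (the helper name _calc_supcon_loss_manual and its exact form are illustrative, not taken from the lightly test suite):

import numpy as np


def _calc_supcon_loss_manual(
    out0: np.ndarray, out1: np.ndarray, labels: np.ndarray, temperature: float = 0.1
) -> float:
    # Stack the two views and L2-normalize every embedding.
    feats = np.concatenate([out0, out1], axis=0)
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    labels = np.concatenate([labels, labels], axis=0)
    n = feats.shape[0]

    logits = feats @ feats.T / temperature
    np.fill_diagonal(logits, -1e9)  # exclude self-contrast terms

    losses = []
    for i in range(n):
        # Log-softmax over all other samples for anchor i.
        row = logits[i] - logits[i].max()
        log_prob = row - np.log(np.exp(row).sum())
        positives = [j for j in range(n) if j != i and labels[j] == labels[i]]
        losses.append(-np.mean([log_prob[j] for j in positives]))
    # Average over anchors and apply the same rescaling as SupConLoss(rescale=True).
    return temperature * float(np.mean(losses))

The torch result, e.g. SupConLoss(temperature=0.1)(out0_t, out1_t, labels_t), could then be checked against this value with pytest.approx.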