Commit 7a5a73d

Simplify SupConLoss (lightly-ai#1877)
1 parent 9e38fcd commit 7a5a73d

2 files changed: 185 additions, 251 deletions

lightly/loss/supcon_loss.py

Lines changed: 71 additions & 241 deletions
@@ -3,8 +3,7 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved

-from enum import Enum
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -15,42 +14,6 @@
 from lightly.utils import dist


-def divide_no_nan(numerator: Tensor, denominator: Tensor) -> Tensor:
-    """Performs tensor division, setting result to zero where denominator is zero.
-
-    Args:
-        numerator:
-            Numerator tensor.
-        denominator:
-            Denominator tensor with possible zeroes.
-
-    Returns:
-        Result with zeros where denominator is zero.
-    """
-    result = torch.zeros_like(numerator)
-    nonzero_mask = denominator != 0
-    result[nonzero_mask] = numerator[nonzero_mask] / denominator[nonzero_mask]
-    return result
-
-
-class ContrastMode(Enum):
-    """Contrast Mode Enum for SupCon Loss.
-
-    Offers the three contrast modes as enum for the SupCon loss. The three modes are:
-
-    - ContrastMode.ALL: Uses all positives and negatives.
-    - ContrastMode.ONE_POSITIVE: Uses only one positive, and all negatives.
-    - ContrastMode.ONLY_NEGATIVES: Uses no positives, only negatives.
-    """
-
-    ALL = 1
-    ONE_POSITIVE = 2
-    ONLY_NEGATIVES = 3
-
-
-VALID_CONTRAST_MODES = set(item.name for item in ContrastMode)
-
-
 class SupConLoss(nn.Module):
     """Implementation of the Supervised Contrastive Loss.

@@ -61,64 +24,55 @@ class SupConLoss(nn.Module):
     Attributes:
         temperature:
             Scale logits by the inverse of the temperature.
-        contrast_mode:
-            Whether to use all positives, one positive, or none. All negatives are
-            used in all cases.
         gather_distributed:
             If True then negatives from all GPUs are gathered before the
-            loss calculation.
-
+            loss calculation. If a memory bank is used and gather_distributed is True,
+            then tensors from all gpus are gathered before the memory bank is updated.
+        rescale:
+            Optionally rescale final loss by the temperature for stability.
     Raises:
         ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
-        ValueError: If gather_distributed is True but torch.distributed is not available.
-        ValueError: If contrast_mode is outside the accepted ContrastMode values.

     Examples:
-        >>> # initialize loss function
-        >>> loss_fn = SupConLoss()
+        >>> # initialize loss function without memory bank
+        >>> loss_fn = NTXentLoss(memory_bank_size=0)
         >>>
-        >>> # generate two or more views of images
+        >>> # generate two random transforms of images
         >>> t0 = transforms(images)
         >>> t1 = transforms(images)
         >>>
-        >>> # feed through SimCLR model
+        >>> # feed through SimCLR or MoCo model
         >>> out0, out1 = model(t0), model(t1)
         >>>
-        >>> # Stack views along 2nd dimensions
-        >>> features = torch.stack([out0, out1], dim=1)
-        >>>
         >>> # calculate loss
-        >>> loss = loss_fn(features, labels)
+        >>> loss = loss_fn(out0, out1)

     """

     def __init__(
         self,
         temperature: float = 0.5,
-        contrast_mode: ContrastMode = ContrastMode.ALL,
         gather_distributed: bool = False,
+        rescale: bool = True,
     ):
         """Initializes the SupConLoss module with the specified parameters.

         Args:
             temperature:
                 Scale logits by the inverse of the temperature.
-            contrast_mode:
-                Whether to use all positives, one positive, or none. All negatives are
-                used in all cases.
             gather_distributed:
                 If True, negatives from all GPUs are gathered before the loss calculation.
+            rescale:
+                Optionally rescale final loss by the temperature for stability.

         Raises:
             ValueError: If temperature is less than 1e-8 to prevent divide by zero.
             ValueError: If gather_distributed is True but torch.distributed is not available.
-            ValueError: If contrast_mode is outside the accepted ContrastMode values.
         """
         super().__init__()
         self.temperature = temperature
-        self.contrast_mode = contrast_mode
-        self.positives_cap = -1  # Unused at the moment
         self.gather_distributed = gather_distributed
+        self.rescale = rescale
         self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
         self.eps = 1e-8

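After this hunk, the constructor exposes only temperature, gather_distributed, and the new rescale flag. Below is a minimal usage sketch of the simplified interface; the random embeddings and class labels are placeholders standing in for real model outputs, and the import path follows the file shown in this diff.

import torch

from lightly.loss.supcon_loss import SupConLoss

# Hypothetical embeddings standing in for model(t0) and model(t1).
out0 = torch.randn(8, 128)
out1 = torch.randn(8, 128)
class_labels = torch.randint(0, 3, (8,))  # optional, one class index per sample

# rescale=True multiplies the final loss by the temperature (see the added code).
criterion = SupConLoss(temperature=0.5, gather_distributed=False, rescale=True)

loss_unsupervised = criterion(out0, out1)              # SimCLR-style, no labels
loss_supervised = criterion(out0, out1, class_labels)  # supervised contrastive
print(loss_unsupervised.item(), loss_supervised.item())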
@@ -133,206 +87,82 @@ def __init__(
                 "distributed support."
             )

-        if contrast_mode.name not in VALID_CONTRAST_MODES:
-            raise ValueError(
-                f"contrast_mode is {contrast_mode} but must be one of ContrastMode.{VALID_CONTRAST_MODES}"
-            )
-
-    def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self, out0: Tensor, out1: Tensor, labels: Optional[Tensor] = None
+    ) -> Tensor:
         """Forward pass through Supervised Contrastive Loss.

         Computes the loss based on contrast_mode setting.

         Args:
-            features:
-                Tensor of at least 3 dimensions, corresponding to
-                (batch_size, num_views, ...)
+            out0:
+                Output projections of the first set of transformed images.
+                Shape: (batch_size, embedding_size)
+            out1:
+                Output projections of the second set of transformed images.
+                Shape: (batch_size, embedding_size)
             labels:
-                Onehot labels for each sample. Must match shape
-                (batch_size, num_classes)
-
-        Raises:
-            ValueError: If features does not have at least 3 dimensions.
-            ValueError: If number of labels does not match batch_size.
-            ValueError: If labels is not one-hot encoded.
+                Onehot labels for each sample. Must be a vector of length `batch_size`.

         Returns:
             Supervised Contrastive Loss value.
         """
+        # Stack the views for efficient computation
+        # Allows for more views to be added easily
+        features = (out0, out1)
+        n_views = len(features)
+        out_small = torch.vstack(features)

-        if len(features.shape) < 3:
-            raise ValueError(
-                f"Features must have at least 3 dimensions, got {len(features.shape)}."
-            )
-
-        device = features.device
-        batch_size, num_views = features.shape[:2]
-
-        if labels is not None and labels.size(0) != batch_size:
-            raise ValueError(
-                f"When setting labels, labels must match batch_size {batch_size}, got {labels.size(0)}."
-            )
-
-        if labels is not None:
-            if not self._is_one_hot(labels):
-                raise ValueError(
-                    "labels must be a 2D matrix representing the one-hot encoded classes."
-                )
-
-        # Flatten the features in case they are still images or other
-        features = features.flatten(2)
-
-        # Normalize the features to length 1
-        features = F.normalize(features, dim=2)
-
-        # Memory bank could be used here but labelled samples are not yet supported.
+        device = out_small.device
+        batch_size = out_small.shape[0] // n_views

-        # Use cosine similarity (dot product) as all vectors are normalized to unit length
+        # Normalize the output to length 1
+        out_small = nn.functional.normalize(out_small, dim=1)

-        # Use other samples from different classes in batch as negatives
-        # and create diagonal mask that only selects similarities between
-        # views of the same image / same class
+        # Gather hidden representations from other processes if distributed
+        # and compute the diagonal self-contrast mask
         if self.gather_distributed and dist.world_size() > 1:
-            # Gather hidden representations and optional labels from other processes
-            global_features = torch.cat(dist.gather(features), 0)
-            diag_mask = dist.eye_rank(batch_size, device=device)
-            if labels is not None:
-                global_labels = torch.cat(dist.gather(labels), 0)
+            out_large = torch.cat(dist.gather(out_small), 0)
+            diag_mask = dist.eye_rank(n_views * batch_size, device=device)
         else:
             # Single process
-            global_features = features
-            diag_mask = torch.eye(batch_size, device=device, dtype=torch.bool)
-            if labels is not None:
-                global_labels = labels
-
-        # Use the diagonal mask if labels is none, else compute the mask based on labels
-        if labels is None:
-            # No labels, typical semi-supervised contrastive learning like SimCLR
-            mask = diag_mask
-        else:
-            mask = (labels @ global_labels.T).to(device)
+            out_large = out_small
+            diag_mask = torch.eye(n_views * batch_size, device=device, dtype=torch.bool)

-        # Get features in shape [num_views * batch_size, c]
-        all_global_features = global_features.permute(1, 0, 2).reshape(
-            -1, global_features.size(-1)
-        )
-
-        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
-            # We take only the first view as anchor
-            anchor_features = features[:, 0]
-            num_anchor_views = 1
-        else:
-            # We take all views as anchors in the same shape as the global features
-            anchor_features = features.permute(1, 0, 2).reshape(-1, features.size(-1))
-            num_anchor_views = num_views
-
-        # Obtain the logits between anchor features and features across all processes
-        # Logits will be shaped [local_batch_size * num_anchor_views, global_batch_size * num_views]
-        # We then temperature scale it and subtract the max to improve numerical stability
-        # In the einsum, n is local_batch_size * num_anchor_views, m is global_batch_size * num_views,
-        # and c is the flattened feature length
-        # Note: features are ordered by view first, i.e. first all samples of view 0, then all samples
-        # of view 1, and so on.
-        logits = torch.einsum("nc,mc->nm", anchor_features, all_global_features)
+        # Use cosine similarity (dot product) as all vectors are normalized to unit length
+        # Calculate similiarities
+        logits = out_small @ out_large.T
         logits /= self.temperature
-        logits -= logits.max(dim=1, keepdim=True)[0].detach()
-        exp_logits = torch.exp(logits)

-        # Get the positive and negative masks for numerator & denominator
-        positives_mask, negatives_mask = self._create_tiled_masks(
-            mask.long(),
-            diag_mask.long(),
-            num_views,
-            num_anchor_views,
-            self.positives_cap,
-        )
-        num_positives_per_row = positives_mask.sum(dim=1)
+        # Set self-similarities to infinitely small value
+        logits[diag_mask] = -1e9

-        # Calculate denominator based on contrast_mode
-        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
-            denominator = exp_logits + (exp_logits * negatives_mask).sum(
-                dim=1, keepdim=True
-            )
-        elif self.contrast_mode == ContrastMode.ALL:
-            denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)
-            denominator += (exp_logits * positives_mask).sum(dim=1, keepdim=True)
-        else:  # ContrastMode.ONLY_NEGATIVES
-            denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)
-
-        # num_positives_per_row can be zero iff 1 view is used. Here we use a safe
-        # dividing method seting those values to zero to prevent division by zero errors.
-
-        # Only implements SupCon_{out}.
-        log_probs = (logits - torch.log(denominator)) * positives_mask
-        log_probs = log_probs.sum(dim=1)
-        log_probs = divide_no_nan(log_probs, num_positives_per_row)
-
-        loss = -log_probs
-
-        # Adjust for num_positives_per_row being zero when using exactly 1 view
-        if num_views != 1:
-            loss = loss.mean(dim=0)
-        else:
-            num_valid_views_per_sample = num_positives_per_row.unsqueeze(0)
-            loss = divide_no_nan(loss, num_valid_views_per_sample).squeeze()
+        # Create labels if None
+        if labels is None:
+            labels = torch.arange(batch_size, device=device, dtype=torch.long)
+            if self.gather_distributed:
+                labels = labels + dist.rank() * batch_size
+        labels = labels.repeat(n_views)
+
+        # Soft labels are 0 unless the logit represents a similarity
+        # between two of the same classes. We manually set self-similarity
+        # (same view of the same item) to 0. When not 0, the value is
+        # 1 / n, where n is the number of positive samples
+        # (different views of the same item, and all views of other items sharing
+        # classes with the item)
+        soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
+        soft_labels.fill_diagonal_(0.0)
+        soft_labels /= soft_labels.sum(dim=1)
+
+        # Compute log probabilities
+        log_proba = F.log_softmax(logits, dim=-1)
+
+        # Compute soft cross-entropy loss
+        loss = (soft_labels * log_proba).sum(-1)
+        loss = -loss.mean()
+
+        # Optional: rescale for stable training
+        if self.rescale:
+            loss *= self.temperature

         return loss
-
-    def _create_tiled_masks(
-        self,
-        untiled_mask: Tensor,
-        diagonal_mask: Tensor,
-        num_views: int,
-        num_anchor_views: int,
-        positives_cap: int,
-    ) -> Tuple[Tensor, Tensor]:
-        # Get total batch size across all processes
-        global_batch_size = untiled_mask.size(1)
-
-        # Find index of the anchor for each sample
-        labels = torch.argmax(diagonal_mask, dim=1)
-
-        # Generate tiled labels across views
-        tiled_labels = []
-        for i in range(num_anchor_views):
-            tiled_labels.append(labels + global_batch_size * i)
-        tiled_labels_tensor = torch.cat(tiled_labels, 0)
-        tiled_diagonal_mask = F.one_hot(
-            tiled_labels_tensor, global_batch_size * num_views
-        )
-
-        # Mask to zero the diagonal at the end
-        all_but_diagonal_mask = 1 - tiled_diagonal_mask
-
-        # All tiled positives
-        uncapped_positives_mask = torch.tile(
-            untiled_mask, [num_anchor_views, num_views]
-        )
-
-        # The negatives is simply the bitflipped positives
-        negatives_mask = 1.0 - uncapped_positives_mask
-
-        # For when positives_cap is implemented
-        if positives_cap > -1:
-            raise NotImplementedError("Capping positives is not yet implemented.")
-        else:
-            positives_mask = uncapped_positives_mask
-
-        # Zero out the self-contrast
-        positives_mask *= all_but_diagonal_mask
-
-        return positives_mask, negatives_mask
-
-    def _is_one_hot(self, tensor: Tensor) -> bool:
-        # Tensor is not a 2D matrix
-        if tensor.ndim != 2:
-            return False
-
-        # Check values are only 0 or 1
-        is_binary = ((tensor == 0) | (tensor == 1)).all()
-
-        # Check each row sums to 1
-        row_sums = tensor.sum(dim=1)
-        has_single_one = (row_sums == 1).all()
-
-        return bool(is_binary.item() and has_single_one.item())
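Taken together, the rewritten forward pass reduces SupCon to a soft cross-entropy over the stacked views: similarities are temperature-scaled, self-similarities are masked out, and each row of targets spreads 1/n over its n positives. The standalone sketch below mirrors that computation with plain torch; batch and embedding sizes are arbitrary, and the row normalization is written with keepdim=True for clarity rather than copied verbatim from the diff.

import torch
import torch.nn.functional as F

temperature = 0.5
batch_size, n_views, dim = 4, 2, 16

# Two views per sample, stacked and normalized as in the new forward pass.
out0, out1 = torch.randn(batch_size, dim), torch.randn(batch_size, dim)
out_small = F.normalize(torch.vstack((out0, out1)), dim=1)

# Cosine similarities scaled by the temperature; self-similarities masked out.
logits = out_small @ out_small.T / temperature
diag_mask = torch.eye(n_views * batch_size, dtype=torch.bool)
logits[diag_mask] = -1e9

# Without class labels, each sample is its own class, repeated once per view.
labels = torch.arange(batch_size).repeat(n_views)

# Soft targets: 1 / n over the n positives of each row, 0 elsewhere.
soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
soft_labels.fill_diagonal_(0.0)
soft_labels /= soft_labels.sum(dim=1, keepdim=True)

# Soft cross-entropy, rescaled by the temperature as with rescale=True.
loss = -(soft_labels * F.log_softmax(logits, dim=-1)).sum(-1).mean()
loss = loss * temperature
print(loss.item())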
