# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

- from enum import Enum
- from typing import Optional, Tuple
+ from typing import Optional

import torch
import torch.nn.functional as F
from lightly.utils import dist

- def divide_no_nan(numerator: Tensor, denominator: Tensor) -> Tensor:
-     """Performs tensor division, setting result to zero where denominator is zero.
-
-     Args:
-         numerator:
-             Numerator tensor.
-         denominator:
-             Denominator tensor with possible zeroes.
-
-     Returns:
-         Result with zeros where denominator is zero.
-     """
-     result = torch.zeros_like(numerator)
-     nonzero_mask = denominator != 0
-     result[nonzero_mask] = numerator[nonzero_mask] / denominator[nonzero_mask]
-     return result
-
-
- class ContrastMode(Enum):
-     """Contrast Mode Enum for SupCon Loss.
-
-     Offers the three contrast modes as enum for the SupCon loss. The three modes are:
-
-     - ContrastMode.ALL: Uses all positives and negatives.
-     - ContrastMode.ONE_POSITIVE: Uses only one positive, and all negatives.
-     - ContrastMode.ONLY_NEGATIVES: Uses no positives, only negatives.
-     """
-
-     ALL = 1
-     ONE_POSITIVE = 2
-     ONLY_NEGATIVES = 3
-
-
- VALID_CONTRAST_MODES = set(item.name for item in ContrastMode)
-
-
class SupConLoss(nn.Module):
    """Implementation of the Supervised Contrastive Loss.

@@ -61,64 +24,55 @@ class SupConLoss(nn.Module):
    Attributes:
        temperature:
            Scale logits by the inverse of the temperature.
-         contrast_mode:
-             Whether to use all positives, one positive, or none. All negatives are
-             used in all cases.
        gather_distributed:
            If True then negatives from all GPUs are gathered before the
-             loss calculation.
-
+             loss calculation.
+         rescale:
+             Optionally rescale the final loss by the temperature for stability.
    Raises:
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
-         ValueError: If gather_distributed is True but torch.distributed is not available.
-         ValueError: If contrast_mode is outside the accepted ContrastMode values.

    Examples:
-         >>> # initialize loss function
-         >>> loss_fn = SupConLoss()
+         >>> # initialize loss function
+         >>> loss_fn = SupConLoss()
        >>>
-         >>> # generate two or more views of images
+         >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
-         >>> # feed through SimCLR model
+         >>> # feed through SimCLR or MoCo model
        >>> out0, out1 = model(t0), model(t1)
        >>>
-         >>> # Stack views along the 2nd dimension
-         >>> features = torch.stack([out0, out1], dim=1)
-         >>>
        >>> # calculate loss
-         >>> loss = loss_fn(features, labels)
+         >>> loss = loss_fn(out0, out1)

    """
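A minimal supervised usage sketch, assuming the class is exported as lightly.loss.SupConLoss; the tensor shapes, temperature value, and number of classes are illustrative only:

    import torch
    from lightly.loss import SupConLoss  # assumed export path

    loss_fn = SupConLoss(temperature=0.1)
    out0 = torch.randn(8, 128)           # projections of the first view
    out1 = torch.randn(8, 128)           # projections of the second view
    labels = torch.randint(0, 3, (8,))   # one class index per image
    loss = loss_fn(out0, out1, labels=labels)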

    def __init__(
        self,
        temperature: float = 0.5,
-         contrast_mode: ContrastMode = ContrastMode.ALL,
        gather_distributed: bool = False,
+         rescale: bool = True,
    ):
        """Initializes the SupConLoss module with the specified parameters.

        Args:
            temperature:
                Scale logits by the inverse of the temperature.
-             contrast_mode:
-                 Whether to use all positives, one positive, or none. All negatives are
-                 used in all cases.
            gather_distributed:
                If True, negatives from all GPUs are gathered before the loss calculation.
+             rescale:
+                 Optionally rescale the final loss by the temperature for stability.

        Raises:
            ValueError: If temperature is less than 1e-8 to prevent divide by zero.
            ValueError: If gather_distributed is True but torch.distributed is not available.
-             ValueError: If contrast_mode is outside the accepted ContrastMode values.
        """
        super().__init__()
        self.temperature = temperature
-         self.contrast_mode = contrast_mode
-         self.positives_cap = -1  # Unused at the moment
        self.gather_distributed = gather_distributed
+         self.rescale = rescale
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

@@ -133,206 +87,82 @@ def __init__(
13387 "distributed support."
13488 )
13589
136- if contrast_mode .name not in VALID_CONTRAST_MODES :
137- raise ValueError (
138- f"contrast_mode is { contrast_mode } but must be one of ContrastMode.{ VALID_CONTRAST_MODES } "
139- )
140-
-     def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
+     def forward(
+         self, out0: Tensor, out1: Tensor, labels: Optional[Tensor] = None
+     ) -> Tensor:
        """Forward pass through Supervised Contrastive Loss.

        Computes the loss based on contrast_mode setting.

        Args:
-             features:
-                 Tensor of at least 3 dimensions, corresponding to
-                 (batch_size, num_views, ...)
+             out0:
+                 Output projections of the first set of transformed images.
+                 Shape: (batch_size, embedding_size)
+             out1:
+                 Output projections of the second set of transformed images.
+                 Shape: (batch_size, embedding_size)
            labels:
-                 Onehot labels for each sample. Must match shape
-                 (batch_size, num_classes)
-
-         Raises:
-             ValueError: If features does not have at least 3 dimensions.
-             ValueError: If number of labels does not match batch_size.
-             ValueError: If labels is not one-hot encoded.
+                 Class labels for each sample. Must be a vector of length `batch_size`.

        Returns:
            Supervised Contrastive Loss value.
        """
+         # Stack the views for efficient computation
+         # Allows for more views to be added easily
+         features = (out0, out1)
+         n_views = len(features)
+         out_small = torch.vstack(features)

-         if len(features.shape) < 3:
-             raise ValueError(
-                 f"Features must have at least 3 dimensions, got {len(features.shape)}."
-             )
-
-         device = features.device
-         batch_size, num_views = features.shape[:2]
-
-         if labels is not None and labels.size(0) != batch_size:
-             raise ValueError(
-                 f"When setting labels, labels must match batch_size {batch_size}, got {labels.size(0)}."
-             )
-
-         if labels is not None:
-             if not self._is_one_hot(labels):
-                 raise ValueError(
-                     "labels must be a 2D matrix representing the one-hot encoded classes."
-                 )
-
-         # Flatten the features in case they are still images or other
-         features = features.flatten(2)
-
-         # Normalize the features to length 1
-         features = F.normalize(features, dim=2)
-
-         # Memory bank could be used here but labelled samples are not yet supported.
+         device = out_small.device
+         batch_size = out_small.shape[0] // n_views

-         # Use cosine similarity (dot product) as all vectors are normalized to unit length
+         # Normalize the output to length 1
+         out_small = nn.functional.normalize(out_small, dim=1)

-         # Use other samples from different classes in batch as negatives
-         # and create diagonal mask that only selects similarities between
-         # views of the same image / same class
+         # Gather hidden representations from other processes if distributed
+         # and compute the diagonal self-contrast mask
        if self.gather_distributed and dist.world_size() > 1:
-             # Gather hidden representations and optional labels from other processes
-             global_features = torch.cat(dist.gather(features), 0)
-             diag_mask = dist.eye_rank(batch_size, device=device)
-             if labels is not None:
-                 global_labels = torch.cat(dist.gather(labels), 0)
+             out_large = torch.cat(dist.gather(out_small), 0)
+             diag_mask = dist.eye_rank(n_views * batch_size, device=device)
        else:
            # Single process
-             global_features = features
-             diag_mask = torch.eye(batch_size, device=device, dtype=torch.bool)
-             if labels is not None:
-                 global_labels = labels
-
-         # Use the diagonal mask if labels is None, else compute the mask based on labels
-         if labels is None:
-             # No labels, typical self-supervised contrastive learning like SimCLR
-             mask = diag_mask
-         else:
-             mask = (labels @ global_labels.T).to(device)
+             out_large = out_small
+             diag_mask = torch.eye(n_views * batch_size, device=device, dtype=torch.bool)

-         # Get features in shape [num_views * batch_size, c]
-         all_global_features = global_features.permute(1, 0, 2).reshape(
-             -1, global_features.size(-1)
-         )
-
-         if self.contrast_mode == ContrastMode.ONE_POSITIVE:
-             # We take only the first view as anchor
-             anchor_features = features[:, 0]
-             num_anchor_views = 1
-         else:
-             # We take all views as anchors in the same shape as the global features
-             anchor_features = features.permute(1, 0, 2).reshape(-1, features.size(-1))
-             num_anchor_views = num_views
-
-         # Obtain the logits between anchor features and features across all processes
-         # Logits will be shaped [local_batch_size * num_anchor_views, global_batch_size * num_views]
-         # We then temperature scale it and subtract the max to improve numerical stability
-         # In the einsum, n is local_batch_size * num_anchor_views, m is global_batch_size * num_views,
-         # and c is the flattened feature length
-         # Note: features are ordered by view first, i.e. first all samples of view 0, then all samples
-         # of view 1, and so on.
-         logits = torch.einsum("nc,mc->nm", anchor_features, all_global_features)
+         # Use cosine similarity (dot product) as all vectors are normalized to unit length
+         # Calculate similarities
+         logits = out_small @ out_large.T
        logits /= self.temperature
-         logits -= logits.max(dim=1, keepdim=True)[0].detach()
-         exp_logits = torch.exp(logits)

-         # Get the positive and negative masks for numerator & denominator
-         positives_mask, negatives_mask = self._create_tiled_masks(
-             mask.long(),
-             diag_mask.long(),
-             num_views,
-             num_anchor_views,
-             self.positives_cap,
-         )
-         num_positives_per_row = positives_mask.sum(dim=1)
+         # Set self-similarities to a very low value so the softmax ignores them
+         logits[diag_mask] = -1e9

-         # Calculate denominator based on contrast_mode
-         if self.contrast_mode == ContrastMode.ONE_POSITIVE:
-             denominator = exp_logits + (exp_logits * negatives_mask).sum(
-                 dim=1, keepdim=True
-             )
-         elif self.contrast_mode == ContrastMode.ALL:
-             denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)
-             denominator += (exp_logits * positives_mask).sum(dim=1, keepdim=True)
-         else:  # ContrastMode.ONLY_NEGATIVES
-             denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)
-
-         # num_positives_per_row can be zero iff 1 view is used. Here we use a safe
-         # dividing method setting those values to zero to prevent division by zero errors.
-
-         # Only implements SupCon_{out}.
-         log_probs = (logits - torch.log(denominator)) * positives_mask
-         log_probs = log_probs.sum(dim=1)
-         log_probs = divide_no_nan(log_probs, num_positives_per_row)
-
-         loss = -log_probs
-
-         # Adjust for num_positives_per_row being zero when using exactly 1 view
-         if num_views != 1:
-             loss = loss.mean(dim=0)
-         else:
-             num_valid_views_per_sample = num_positives_per_row.unsqueeze(0)
-             loss = divide_no_nan(loss, num_valid_views_per_sample).squeeze()
+         # Create labels if None
+         if labels is None:
+             labels = torch.arange(batch_size, device=device, dtype=torch.long)
+             if self.gather_distributed:
+                 labels = labels + dist.rank() * batch_size
+         labels = labels.repeat(n_views)
+
+         # Soft labels are 0 unless the logit represents a similarity
+         # between two of the same classes. We manually set self-similarity
+         # (same view of the same item) to 0. When not 0, the value is
+         # 1 / n, where n is the number of positive samples
+         # (different views of the same item, and all views of other items sharing
+         # classes with the item)
+         soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
+         soft_labels.fill_diagonal_(0.0)
+         soft_labels /= soft_labels.sum(dim=1)
+
+         # Compute log probabilities
+         log_proba = F.log_softmax(logits, dim=-1)
+
+         # Compute soft cross-entropy loss
+         loss = (soft_labels * log_proba).sum(-1)
+         loss = -loss.mean()
+
+         # Optional: rescale for stable training
+         if self.rescale:
+             loss *= self.temperature

        return loss
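To make the soft-label construction in the new forward pass concrete, here is a small sketch of the target matrix for batch_size=2, two views, and no class labels; the values are reproduced by hand under the assumption that the steps match the code above:

    import torch

    labels = torch.arange(2).repeat(2)  # tensor([0, 1, 0, 1])
    soft_labels = torch.eq(labels, labels.view(-1, 1)).float()
    soft_labels.fill_diagonal_(0.0)
    soft_labels /= soft_labels.sum(dim=1)
    # soft_labels is now
    # [[0., 0., 1., 0.],
    #  [0., 0., 0., 1.],
    #  [1., 0., 0., 0.],
    #  [0., 1., 0., 0.]]
    # Each row places all probability mass on the other view of the same image,
    # which reduces to the usual NT-Xent targets when no class labels are given.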
-
-     def _create_tiled_masks(
-         self,
-         untiled_mask: Tensor,
-         diagonal_mask: Tensor,
-         num_views: int,
-         num_anchor_views: int,
-         positives_cap: int,
-     ) -> Tuple[Tensor, Tensor]:
-         # Get total batch size across all processes
-         global_batch_size = untiled_mask.size(1)
-
-         # Find index of the anchor for each sample
-         labels = torch.argmax(diagonal_mask, dim=1)
-
-         # Generate tiled labels across views
-         tiled_labels = []
-         for i in range(num_anchor_views):
-             tiled_labels.append(labels + global_batch_size * i)
-         tiled_labels_tensor = torch.cat(tiled_labels, 0)
-         tiled_diagonal_mask = F.one_hot(
-             tiled_labels_tensor, global_batch_size * num_views
-         )
-
-         # Mask to zero the diagonal at the end
-         all_but_diagonal_mask = 1 - tiled_diagonal_mask
-
-         # All tiled positives
-         uncapped_positives_mask = torch.tile(
-             untiled_mask, [num_anchor_views, num_views]
-         )
-
-         # The negatives are simply the bit-flipped positives
-         negatives_mask = 1.0 - uncapped_positives_mask
-
-         # For when positives_cap is implemented
-         if positives_cap > -1:
-             raise NotImplementedError("Capping positives is not yet implemented.")
-         else:
-             positives_mask = uncapped_positives_mask
-
-         # Zero out the self-contrast
-         positives_mask *= all_but_diagonal_mask
-
-         return positives_mask, negatives_mask
-
-     def _is_one_hot(self, tensor: Tensor) -> bool:
-         # Tensor is not a 2D matrix
-         if tensor.ndim != 2:
-             return False
-
-         # Check values are only 0 or 1
-         is_binary = ((tensor == 0) | (tensor == 1)).all()
-
-         # Check each row sums to 1
-         row_sums = tensor.sum(dim=1)
-         has_single_one = (row_sums == 1).all()
-
-         return bool(is_binary.item() and has_single_one.item())