Merge branch 'master' into philipp-typecheck-vicregl-loss
philippmwirth authored Dec 30, 2024
2 parents 1267910 + 26bdfa6 commit 441eafa
Showing 10 changed files with 377 additions and 382 deletions.
21 changes: 12 additions & 9 deletions lightly/loss/dino_loss.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter

 from lightly.models.modules import center
 from lightly.models.modules.center import CENTER_MODE_TO_FUNCTION
@@ -83,6 +83,7 @@ def __init__(

         # TODO(Guarin, 08/24): Refactor this to use the Center module directly once
         # we do a breaking change.
+        self.center: Parameter
         self.register_buffer("center", torch.zeros(1, 1, output_dim))

         # we apply a warm up for the teacher temperature because
@@ -123,13 +124,15 @@ def forward(
         if epoch < self.warmup_teacher_temp_epochs:
             teacher_temp = self.teacher_temp_schedule[epoch]
         else:
-            teacher_temp = self.teacher_temp
+            teacher_temp = torch.tensor(self.teacher_temp)

-        teacher_out = torch.stack(teacher_out)
-        t_out = F.softmax((teacher_out - self.center) / teacher_temp, dim=-1)
+        teacher_out_stacked = torch.stack(teacher_out)
+        t_out: Tensor = F.softmax(
+            (teacher_out_stacked - self.center) / teacher_temp, dim=-1
+        )

-        student_out = torch.stack(student_out)
-        s_out = F.log_softmax(student_out / self.student_temp, dim=-1)
+        student_out_stacked = torch.stack(student_out)
+        s_out = F.log_softmax(student_out_stacked / self.student_temp, dim=-1)

         # Calculate feature similarities, ignoring the diagonal
         # b = batch_size, t = n_views_teacher, s = n_views_student, d = output_dim
@@ -138,12 +141,12 @@

         # Number of loss terms, ignoring the diagonal
         n_terms = loss.numel() - loss.diagonal().numel()
-        batch_size = teacher_out.shape[1]
+        batch_size = teacher_out_stacked.shape[1]

         loss = loss.sum() / (n_terms * batch_size)

         # Update the center used for the teacher output
-        self.update_center(teacher_out)
+        self.update_center(teacher_out_stacked)

         return loss

@@ -161,6 +164,6 @@ def update_center(self, teacher_out: Tensor) -> None:
         batch_center = self._center_fn(x=teacher_out, dim=(0, 1))

         # Update the center with a moving average
-        self.center = center.center_momentum(
+        self.center.data = center.center_momentum(
             center=self.center, batch_center=batch_center, momentum=self.center_momentum
         )
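The pattern above is what makes a dynamically registered buffer typecheck: a bare self.center: Parameter annotation gives mypy a type for the attribute that register_buffer assigns at runtime, and routing updates through .data keeps the registered buffer object in place instead of rebinding the attribute. A minimal standalone sketch of the same pattern (the CenterEMA module and all its names are hypothetical, not part of lightly):

import torch
from torch import Tensor
from torch.nn import Module, Parameter


class CenterEMA(Module):
    def __init__(self, dim: int, momentum: float = 0.9) -> None:
        super().__init__()
        self.momentum = momentum
        # Bare annotation: register_buffer assigns the attribute dynamically,
        # so mypy needs an explicit type for self.center.
        self.center: Parameter
        self.register_buffer("center", torch.zeros(1, dim))

    @torch.no_grad()
    def update(self, batch: Tensor) -> None:
        # Exponential moving average of the batch mean. Assigning through
        # .data updates the registered buffer in place rather than replacing
        # it with a plain Tensor attribute.
        batch_center = batch.mean(dim=0, keepdim=True)
        self.center.data = self.momentum * self.center + (1.0 - self.momentum) * batch_center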
17 changes: 10 additions & 7 deletions lightly/loss/hypersphere_loss.py
@@ -5,9 +5,11 @@

 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module


-class HypersphereLoss(torch.nn.Module):
+class HypersphereLoss(Module):
     """Implementation of the loss described in 'Understanding Contrastive Representation Learning through
     Alignment and Uniformity on the Hypersphere.' [0]
@@ -44,7 +46,7 @@ class HypersphereLoss(torch.nn.Module):
         >>> loss = loss_fn(out0, out1)
     """

-    def __init__(self, t=1.0, lam=1.0, alpha=2.0):
+    def __init__(self, t: float = 1.0, lam: float = 1.0, alpha: float = 2.0):
         """Initializes the HypersphereLoss module with the specified parameters.
         Parameters as described in [0]
@@ -63,7 +65,7 @@ def __init__(self, t=1.0, lam=1.0, alpha=2.0):
         self.lam = lam
         self.alpha = alpha

-    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
+    def forward(self, z_a: Tensor, z_b: Tensor) -> Tensor:
         """Computes the Hypersphere loss, which combines alignment and uniformity loss terms.
         Args:
@@ -80,13 +82,14 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
         y = F.normalize(z_b)

         # Calculate alignment loss
-        def lalign(x, y):
-            return (x - y).norm(dim=1).pow(self.alpha).mean()
+        def lalign(x: Tensor, y: Tensor) -> Tensor:
+            lalign_: Tensor = (x - y).norm(dim=1).pow(self.alpha).mean()
+            return lalign_

         # Calculate uniformity loss
-        def lunif(x):
+        def lunif(x: Tensor) -> Tensor:
             sq_pdist = torch.pdist(x, p=2).pow(2)
             return sq_pdist.mul(-self.t).exp().mean().log()

         # Combine alignment and uniformity loss terms
-        return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2
+        return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2.0
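For orientation, a usage sketch consistent with the docstring example in this file (batch size, embedding dimension, and hyperparameter values below are illustrative assumptions, not taken from the commit):

import torch

from lightly.loss.hypersphere_loss import HypersphereLoss

# Embeddings of two augmented views of the same images.
out0 = torch.randn(32, 128)
out1 = torch.randn(32, 128)

# t, lam, and alpha follow the paper's notation; lam weights the
# uniformity term against the alignment term.
loss_fn = HypersphereLoss(t=1.0, lam=1.0, alpha=2.0)
loss = loss_fn(out0, out1)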
10 changes: 6 additions & 4 deletions lightly/loss/sym_neg_cos_sim_loss.py
@@ -6,9 +6,11 @@

 import warnings

 import torch
+from torch import Tensor
+from torch.nn import Module


-class SymNegCosineSimilarityLoss(torch.nn.Module):
+class SymNegCosineSimilarityLoss(Module):
     """Implementation of the Symmetrized Loss used in the SimSiam[0] paper.
     - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566
@@ -43,7 +45,7 @@ def __init__(self) -> None:
             DeprecationWarning,
         )

-    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
+    def forward(self, out0: Tensor, out1: Tensor) -> Tensor:
         """Forward pass through Symmetric Loss.
         Args:
@@ -64,14 +66,14 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor):
         z0, p0 = out0
         z1, p1 = out1

-        loss = (
+        loss: Tensor = (
             self._neg_cosine_simililarity(p0, z1) / 2
             + self._neg_cosine_simililarity(p1, z0) / 2
         )

         return loss

-    def _neg_cosine_simililarity(self, x, y):
+    def _neg_cosine_simililarity(self, x: Tensor, y: Tensor) -> Tensor:
         """Calculates the negative cosine similarity between two tensors.
         Args:
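The forward pass above symmetrizes the negative cosine similarity between each view's prediction and the other view's projection. A self-contained sketch of that computation (hypothetical helper names; the detach mimicking SimSiam's stop-gradient is an assumption of this sketch, not shown in the diff):

import torch
import torch.nn.functional as F
from torch import Tensor


def neg_cosine_similarity(x: Tensor, y: Tensor) -> Tensor:
    # Negative mean cosine similarity; y is detached to mimic the SimSiam
    # stop-gradient (an assumption of this sketch).
    return -F.cosine_similarity(x, y.detach(), dim=-1).mean()


# Predictions p and projections z for two views; shapes are illustrative.
z0, p0 = torch.randn(8, 64), torch.randn(8, 64)
z1, p1 = torch.randn(8, 64), torch.randn(8, 64)
loss = neg_cosine_similarity(p0, z1) / 2 + neg_cosine_similarity(p1, z0) / 2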
6 changes: 0 additions & 6 deletions pyproject.toml
@@ -192,9 +192,6 @@ exclude = '''(?x)(
     lightly/cli/train_cli.py |
     lightly/cli/_cli_simclr.py |
     lightly/cli/_helpers.py |
-    lightly/loss/hypersphere_loss.py |
-    lightly/loss/dino_loss.py |
-    lightly/loss/sym_neg_cos_sim_loss.py |
     lightly/loss/dcl_loss.py |
     lightly/loss/regularizer/co2.py |
     lightly/loss/barlow_twins_loss.py |
@@ -242,13 +239,10 @@ exclude = '''(?x)(
     tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
     tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
     tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
-    tests/loss/test_DINOLoss.py |
     tests/loss/test_CO2Regularizer.py |
     tests/loss/test_DCLLoss.py |
     tests/loss/test_barlow_twins_loss.py |
-    tests/loss/test_SymNegCosineSimilarityLoss.py |
     tests/loss/test_MemoryBank.py |
-    tests/loss/test_HyperSphere.py |
     tests/core/test_Core.py |
     tests/data/test_multi_view_collate.py |
     tests/data/test_data_collate.py |