Skip to content

Commit

Permalink
Typecheck DCLLoss (#1765)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth authored Dec 30, 2024
1 parent d9b8de1 commit 30b2e45
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 93 deletions.
12 changes: 7 additions & 5 deletions lightly/loss/dcl_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def negative_mises_fisher_weights(
out0: Tensor, out1: Tensor, sigma: float = 0.5
) -> torch.Tensor:
) -> Tensor:
"""Negative Mises-Fisher weighting function as presented in Decoupled Contrastive Learning [0].
The implementation was inspired by [1].
Expand All @@ -35,7 +35,7 @@ def negative_mises_fisher_weights(
similarity = torch.einsum("nm,nm->n", out0.detach(), out1.detach()) / sigma

# Return negative Mises-Fisher weights
return 2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0)
return torch.tensor(2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0))


class DCLLoss(nn.Module):
Expand Down Expand Up @@ -148,13 +148,15 @@ def forward(
out1_all = out1

# Calculate symmetric loss
loss0 = self._loss(out0, out1, out0_all, out1_all)
loss1 = self._loss(out1, out0, out1_all, out0_all)
loss0: Tensor = self._loss(out0, out1, out0_all, out1_all)
loss1: Tensor = self._loss(out1, out0, out1_all, out0_all)

# Return the mean loss over the mini-batch
return 0.5 * (loss0 + loss1)

def _loss(self, out0, out1, out0_all, out1_all):
def _loss(
self, out0: Tensor, out1: Tensor, out0_all: Tensor, out1_all: Tensor
) -> Tensor:
"""Calculates DCL loss for out0 with respect to its positives in out1
and the negatives in out1, out0_all, and out1_all.
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ exclude = '''(?x)(
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/dcl_loss.py |
lightly/loss/barlow_twins_loss.py |
lightly/data/dataset.py |
lightly/data/collate.py |
Expand Down Expand Up @@ -238,7 +237,6 @@ exclude = '''(?x)(
tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
tests/loss/test_DCLLoss.py |
tests/loss/test_barlow_twins_loss.py |
tests/loss/test_MemoryBank.py |
tests/core/test_Core.py |
Expand Down
86 changes: 0 additions & 86 deletions tests/loss/test_DCLLoss.py

This file was deleted.

103 changes: 103 additions & 0 deletions tests/loss/test_dcl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
import torch
from pytest_mock import MockerFixture
from torch import distributed as dist

from lightly.loss.dcl_loss import DCLLoss, DCLWLoss, negative_mises_fisher_weights


class TestDCLLoss:
def test__gather_distributed(self, mocker: MockerFixture) -> None:
mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
DCLLoss(gather_distributed=True)
mock_is_available.assert_called_once()

def test__gather_distributed_dist_not_available(
self, mocker: MockerFixture
) -> None:
mock_is_available = mocker.patch.object(
dist, "is_available", return_value=False
)
with pytest.raises(ValueError):
DCLLoss(gather_distributed=True)
mock_is_available.assert_called_once()

@pytest.mark.parametrize("sigma", [0.0000001, 0.5, 10000])
def test_negative_mises_fisher_weights(self, sigma: float, seed: int = 0) -> None:
torch.manual_seed(seed)
out0 = torch.rand((3, 5))
out1 = torch.rand((3, 5))
negative_mises_fisher_weights(out0, out1, sigma)

@pytest.mark.parametrize("batch_size", [2, 3])
@pytest.mark.parametrize("dim", [1, 3])
@pytest.mark.parametrize("temperature", [0.1, 0.5, 1.0])
@pytest.mark.parametrize("gather_distributed", [False, True])
def test_dclloss_forward(
self,
batch_size: int,
dim: int,
temperature: float,
gather_distributed: bool,
seed: int = 0,
) -> None:
torch.manual_seed(seed=seed)
out0 = torch.rand((batch_size, dim))
out1 = torch.rand((batch_size, dim))
criterion = DCLLoss(
temperature=temperature,
gather_distributed=gather_distributed,
weight_fn=negative_mises_fisher_weights,
)
loss0 = criterion(out0, out1)
loss1 = criterion(out1, out0)
assert loss0 > 0
assert loss0 == pytest.approx(loss1)

@pytest.mark.parametrize("batch_size", [2, 3])
@pytest.mark.parametrize("dim", [1, 3])
@pytest.mark.parametrize("temperature", [0.1, 0.5, 1.0])
@pytest.mark.parametrize("gather_distributed", [False, True])
def test_dclloss_forward__no_weight_fn(
self,
batch_size: int,
dim: int,
temperature: float,
gather_distributed: bool,
seed: int = 0,
) -> None:
torch.manual_seed(seed=seed)
out0 = torch.rand((batch_size, dim))
out1 = torch.rand((batch_size, dim))
criterion = DCLLoss(
temperature=temperature,
gather_distributed=gather_distributed,
weight_fn=None,
)
loss0 = criterion(out0, out1)
loss1 = criterion(out1, out0)
assert loss0 > 0
assert loss0 == pytest.approx(loss1)

def test_dclloss_backprop(self, seed: int = 0) -> None:
torch.manual_seed(seed=seed)
out0 = torch.rand(3, 5)
out1 = torch.rand(3, 5)
layer = torch.nn.Linear(5, 5)
out0 = layer(out0)
out1 = layer(out1)
criterion = DCLLoss()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = criterion(out0, out1)
loss.backward()
optimizer.step()

def test_dclwloss_forward(self, seed: int = 0) -> None:
torch.manual_seed(seed=seed)
out0 = torch.rand(3, 5)
out1 = torch.rand(3, 5)
criterion = DCLWLoss()
loss0 = criterion(out0, out1)
loss1 = criterion(out1, out0)
assert loss0 > 0
assert loss0 == pytest.approx(loss1)

0 comments on commit 30b2e45

Please sign in to comment.