Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typecheck SwAV loss #1759

Merged
merged 8 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lightly/loss/pmsn_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,6 @@ def _power_law_distribution(size: int, exponent: float, device: torch.device) ->
A power law distribution tensor summing up to 1.
"""
k = torch.arange(1, size + 1, device=device)
power_dist = k ** (-exponent)
power_dist = torch.tensor(k ** (-exponent))
power_dist = power_dist / power_dist.sum()
return power_dist
21 changes: 11 additions & 10 deletions lightly/loss/swav_loss.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import List
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


@torch.no_grad()
def sinkhorn(
out: torch.Tensor,
out: Tensor,
iterations: int = 3,
epsilon: float = 0.05,
gather_distributed: bool = False,
) -> torch.Tensor:
) -> Tensor:
"""Distributed sinkhorn algorithm.

As outlined in [0] and implemented in [1].
Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
self.sinkhorn_epsilon = sinkhorn_epsilon
self.sinkhorn_gather_distributed = sinkhorn_gather_distributed

def subloss(self, z: torch.Tensor, q: torch.Tensor):
def subloss(self, z: Tensor, q: Tensor) -> Tensor:
"""Calculates the cross entropy for the SwaV prediction problem.

Args:
Expand All @@ -131,10 +132,10 @@ def subloss(self, z: torch.Tensor, q: torch.Tensor):

def forward(
self,
high_resolution_outputs: List[torch.Tensor],
low_resolution_outputs: List[torch.Tensor],
queue_outputs: List[torch.Tensor] = None,
):
high_resolution_outputs: List[Tensor],
low_resolution_outputs: List[Tensor],
queue_outputs: Union[List[Tensor], None] = None,
) -> Tensor:
"""Computes the SwaV loss for a set of high and low resolution outputs.

- [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882
Expand All @@ -156,7 +157,7 @@ def forward(
n_crops = len(high_resolution_outputs) + len(low_resolution_outputs)

# Multi-crop iterations
loss = 0.0
loss = torch.tensor(0.0)
for i in range(len(high_resolution_outputs)):
# Compute codes of i-th high resolution crop
with torch.no_grad():
Expand All @@ -179,7 +180,7 @@ def forward(
q = q[: len(high_resolution_outputs[i])]

# Compute subloss for each pair of crops
subloss = 0.0
subloss = torch.tensor(0.0)
for v in range(len(high_resolution_outputs)):
if v != i:
subloss += self.subloss(high_resolution_outputs[v], q)
Expand Down
6 changes: 0 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,8 @@ exclude = '''(?x)(
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/pmsn_loss.py |
lightly/loss/swav_loss.py |
lightly/loss/negative_cosine_similarity.py |
lightly/loss/hypersphere_loss.py |
lightly/loss/msn_loss.py |
lightly/loss/dino_loss.py |
lightly/loss/sym_neg_cos_sim_loss.py |
lightly/loss/vicregl_loss.py |
Expand Down Expand Up @@ -248,17 +245,14 @@ exclude = '''(?x)(
tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
tests/loss/test_NegativeCosineSimilarity.py |
tests/loss/test_MSNLoss.py |
tests/loss/test_DINOLoss.py |
tests/loss/test_VICRegLLoss.py |
tests/loss/test_CO2Regularizer.py |
tests/loss/test_DCLLoss.py |
tests/loss/test_barlow_twins_loss.py |
tests/loss/test_SymNegCosineSimilarityLoss.py |
tests/loss/test_MemoryBank.py |
tests/loss/test_PMSNLoss.py |
tests/loss/test_HyperSphere.py |
tests/loss/test_SwaVLoss.py |
tests/core/test_Core.py |
tests/data/test_multi_view_collate.py |
tests/data/test_data_collate.py |
Expand Down
55 changes: 24 additions & 31 deletions tests/loss/test_MSNLoss.py → tests/loss/test_msn_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import unittest
from unittest import TestCase

import pytest
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -30,19 +27,16 @@ def test__gather_distributed_dist_not_available(
MSNLoss(gather_distributed=True)
mock_is_available.assert_called_once()


class TestMSNLossUnitTest(TestCase):
# Old tests in unittest style, please add new tests to TestMSNLoss using pytest.
def test__init__temperature(self) -> None:
MSNLoss(temperature=1.0)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
MSNLoss(temperature=0.0)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
MSNLoss(temperature=-1.0)

def test__init__sinkhorn_iterations(self) -> None:
MSNLoss(sinkhorn_iterations=0)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
MSNLoss(sinkhorn_iterations=-1)

def test__init__me_max_weight(self) -> None:
Expand All @@ -54,50 +48,49 @@ def test_prototype_probabilitiy(self) -> None:
queries = F.normalize(torch.rand((8, 10)), dim=1)
prototypes = F.normalize(torch.rand((4, 10)), dim=1)
prob = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.5)
self.assertEqual(prob.shape, (8, 4))
self.assertLessEqual(prob.max(), 1.0)
self.assertGreater(prob.min(), 0.0)
assert prob.shape == (8, 4)
assert prob.max() < 1.0
assert prob.min() > 0.0

# verify sharpening
prob1 = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.1)
# same prototypes should be assigned regardless of temperature
self.assertTrue(torch.all(prob.argmax(dim=1) == prob1.argmax(dim=1)))
assert torch.all(prob.argmax(dim=1) == prob1.argmax(dim=1))
# probabilities of selected prototypes should be higher for lower temperature
self.assertTrue(torch.all(prob.max(dim=1)[0] < prob1.max(dim=1)[0]))
assert torch.all(prob.max(dim=1)[0] < prob1.max(dim=1)[0])

def test_sharpen(self) -> None:
torch.manual_seed(0)
prob = torch.rand((8, 10))
p0 = msn_loss.sharpen(prob, temperature=0.5)
p1 = msn_loss.sharpen(prob, temperature=0.1)
# indices of max probabilities should be the same regardless of temperature
self.assertTrue(torch.all(p0.argmax(dim=1) == p1.argmax(dim=1)))
assert torch.all(p0.argmax(dim=1) == p1.argmax(dim=1))
# max probabilities should be higher for lower temperature
self.assertTrue(torch.all(p0.max(dim=1)[0] < p1.max(dim=1)[0]))
assert torch.all(p0.max(dim=1)[0] < p1.max(dim=1)[0])

def test_sinkhorn(self) -> None:
torch.manual_seed(0)
prob = torch.rand((8, 10))
out = msn_loss.sinkhorn(prob)
self.assertTrue(torch.all(prob != out))
assert torch.all(prob != out)

def test_sinkhorn_no_iter(self) -> None:
torch.manual_seed(0)
prob = torch.rand((8, 10))
out = msn_loss.sinkhorn(prob, iterations=0)
self.assertTrue(torch.all(prob == out))
assert torch.all(prob == out)

def test_forward(self) -> None:
@pytest.mark.parametrize("num_target_views", range(1, 4))
def test_forward(self, num_target_views: int) -> None:
torch.manual_seed(0)
for num_target_views in range(1, 4):
with self.subTest(num_views=num_target_views):
criterion = MSNLoss()
anchors = torch.rand((8 * num_target_views, 10))
targets = torch.rand((8, 10))
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
criterion = MSNLoss()
anchors = torch.rand((8 * num_target_views, 10))
targets = torch.rand((8, 10))
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda(self) -> None:
torch.manual_seed(0)
criterion = MSNLoss()
Expand All @@ -124,9 +117,9 @@ def test_backward(self) -> None:
optimizer.step()
weights_after = head.layers[0].weight.data
# backward pass should update weights
self.assertTrue(torch.any(weights_before != weights_after))
assert torch.any(weights_before != weights_after)

@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_backward_cuda(self) -> None:
torch.manual_seed(0)
head = MSNProjectionHead(5, 16, 6)
Expand All @@ -146,4 +139,4 @@ def test_backward_cuda(self) -> None:
optimizer.step()
weights_after = head.layers[0].weight.data
# backward pass should update weights
self.assertTrue(torch.any(weights_before != weights_after))
assert torch.any(weights_before != weights_after)
5 changes: 2 additions & 3 deletions tests/loss/test_PMSNLoss.py → tests/loss/test_pmsn_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
import unittest

import pytest
import torch
Expand Down Expand Up @@ -32,7 +31,7 @@ def test_forward(self) -> None:
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda(self) -> None:
torch.manual_seed(0)
criterion = PMSNLoss()
Expand Down Expand Up @@ -66,7 +65,7 @@ def test_forward(self) -> None:
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda(self) -> None:
torch.manual_seed(0)
criterion = PMSNCustomLoss(target_distribution=_uniform_distribution)
Expand Down
64 changes: 19 additions & 45 deletions tests/loss/test_SwaVLoss.py → tests/loss/test_swav_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import pytest
import torch
from pytest_mock import MockerFixture
Expand All @@ -24,10 +22,7 @@ def test__sinkhorn_gather_distributed_dist_not_available(
SwaVLoss(sinkhorn_gather_distributed=True)
mock_is_available.assert_called_once()


class TestSwaVLossUnitTest(unittest.TestCase):
# Old tests in unittest style, please add new tests to TestSwavLoss using pytest.
def test_forward_pass(self):
def test_forward_pass(self) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(32, 32) for i in range(n_high_res)]
Expand All @@ -36,34 +31,25 @@ def test_forward_pass(self):
for sinkhorn_iterations in range(3):
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n) for i in range(n_low_res)]
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

with self.subTest(
msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}"
):
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
self.assertGreater(0.5, loss.cpu().numpy())

def test_forward_pass_queue(self):
def test_forward_pass_queue(self) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(32, 32) for i in range(n_high_res)]
queue_length = 128
queue = [torch.eye(128, 32) for i in range(n_high_res)]

for n_low_res in range(6):
for sinkhorn_iterations in range(3):
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n) for i in range(n_low_res)]
loss = criterion(high_res, low_res, queue)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

with self.subTest(
msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}"
):
loss = criterion(high_res, low_res, queue)
# loss should be almost zero for unit matrix
self.assertGreater(0.5, loss.cpu().numpy())

def test_forward_pass_bsz_1(self):
def test_forward_pass_bsz_1(self) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(1, n) for i in range(n_high_res)]
Expand All @@ -72,13 +58,9 @@ def test_forward_pass_bsz_1(self):
for sinkhorn_iterations in range(3):
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(1, n) for i in range(n_low_res)]
criterion(high_res, low_res)

with self.subTest(
msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}"
):
loss = criterion(high_res, low_res)

def test_forward_pass_1d(self):
def test_forward_pass_1d(self) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(n, 1) for i in range(n_high_res)]
Expand All @@ -87,16 +69,12 @@ def test_forward_pass_1d(self):
for sinkhorn_iterations in range(3):
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, 1) for i in range(n_low_res)]
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

with self.subTest(
msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}"
):
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
self.assertGreater(0.5, loss.cpu().numpy())

@unittest.skipUnless(torch.cuda.is_available(), "skip")
def test_forward_pass_cuda(self):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_pass_cuda(self) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(n, n).cuda() for i in range(n_high_res)]
Expand All @@ -105,10 +83,6 @@ def test_forward_pass_cuda(self):
for sinkhorn_iterations in range(3):
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)]

with self.subTest(
msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}"
):
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
self.assertGreater(0.5, loss.cpu().numpy())
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5