Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typecheck negative cosine similarity #1760

Merged
merged 14 commits into from
Dec 27, 2024
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions lightly/loss/swav_loss.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import List
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


@torch.no_grad()
def sinkhorn(
out: torch.Tensor,
out: Tensor,
iterations: int = 3,
epsilon: float = 0.05,
gather_distributed: bool = False,
) -> torch.Tensor:
) -> Tensor:
"""Distributed sinkhorn algorithm.
As outlined in [0] and implemented in [1].
@@ -113,7 +114,7 @@ def __init__(
self.sinkhorn_epsilon = sinkhorn_epsilon
self.sinkhorn_gather_distributed = sinkhorn_gather_distributed

def subloss(self, z: torch.Tensor, q: torch.Tensor):
def subloss(self, z: Tensor, q: Tensor) -> Tensor:
"""Calculates the cross entropy for the SwaV prediction problem.
Args:
@@ -131,10 +132,10 @@ def subloss(self, z: torch.Tensor, q: torch.Tensor):

def forward(
self,
high_resolution_outputs: List[torch.Tensor],
low_resolution_outputs: List[torch.Tensor],
queue_outputs: List[torch.Tensor] = None,
):
high_resolution_outputs: List[Tensor],
low_resolution_outputs: List[Tensor],
queue_outputs: Union[List[Tensor], None] = None,
) -> Tensor:
"""Computes the SwaV loss for a set of high and low resolution outputs.
- [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882
@@ -156,7 +157,7 @@ def forward(
n_crops = len(high_resolution_outputs) + len(low_resolution_outputs)

# Multi-crop iterations
loss = 0.0
loss = torch.tensor(0.0)
for i in range(len(high_resolution_outputs)):
# Compute codes of i-th high resolution crop
with torch.no_grad():
@@ -179,7 +180,7 @@ def forward(
q = q[: len(high_resolution_outputs[i])]

# Compute subloss for each pair of crops
subloss = 0.0
subloss = torch.tensor(0.0)
for v in range(len(high_resolution_outputs)):
if v != i:
subloss += self.subloss(high_resolution_outputs[v], q)
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -192,8 +192,6 @@ exclude = '''(?x)(
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/swav_loss.py |
lightly/loss/negative_cosine_similarity.py |
lightly/loss/hypersphere_loss.py |
lightly/loss/dino_loss.py |
lightly/loss/sym_neg_cos_sim_loss.py |
@@ -245,7 +243,6 @@ exclude = '''(?x)(
tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
tests/loss/test_NegativeCosineSimilarity.py |
tests/loss/test_DINOLoss.py |
tests/loss/test_VICRegLLoss.py |
tests/loss/test_CO2Regularizer.py |
@@ -254,7 +251,6 @@ exclude = '''(?x)(
tests/loss/test_SymNegCosineSimilarityLoss.py |
tests/loss/test_MemoryBank.py |
tests/loss/test_HyperSphere.py |
tests/loss/test_SwaVLoss.py |
tests/core/test_Core.py |
tests/data/test_multi_view_collate.py |
tests/data/test_data_collate.py |
30 changes: 0 additions & 30 deletions tests/loss/test_NegativeCosineSimilarity.py

This file was deleted.

114 changes: 0 additions & 114 deletions tests/loss/test_SwaVLoss.py

This file was deleted.

29 changes: 29 additions & 0 deletions tests/loss/test_negative_cosine_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import torch

from lightly.loss import NegativeCosineSimilarity


class TestNegativeCosineSimilarity:
@pytest.mark.parametrize("bsz", range(1, 20))
def test_forward_pass(self, bsz: int) -> None:
loss = NegativeCosineSimilarity()
x0 = torch.randn((bsz, 32))
x1 = torch.randn((bsz, 32))

# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
assert l1 == pytest.approx(l2, abs=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
@pytest.mark.parametrize("bsz", range(1, 20))
def test_forward_pass_cuda(self, bsz: int) -> None:
loss = NegativeCosineSimilarity()
x0 = torch.randn((bsz, 32)).cuda()
x1 = torch.randn((bsz, 32)).cuda()

# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
assert l1 == pytest.approx(l2, abs=1e-5)
83 changes: 83 additions & 0 deletions tests/loss/test_swav_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
import torch
from pytest_mock import MockerFixture
from torch import distributed as dist

from lightly.loss import SwaVLoss


class TestNTXentLoss:
def test__sinkhorn_gather_distributed(self, mocker: MockerFixture) -> None:
mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
SwaVLoss(sinkhorn_gather_distributed=True)
mock_is_available.assert_called_once()

def test__sinkhorn_gather_distributed_dist_not_available(
self, mocker: MockerFixture
) -> None:
mock_is_available = mocker.patch.object(
dist, "is_available", return_value=False
)
with pytest.raises(ValueError):
SwaVLoss(sinkhorn_gather_distributed=True)
mock_is_available.assert_called_once()

@pytest.mark.parametrize("n_low_res", range(6))
@pytest.mark.parametrize("sinkhorn_iterations", range(3))
def test_forward_pass(self, n_low_res: int, sinkhorn_iterations: int) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(32, 32) for i in range(n_high_res)]
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n) for i in range(n_low_res)]
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

@pytest.mark.parametrize("n_low_res", range(6))
@pytest.mark.parametrize("sinkhorn_iterations", range(3))
def test_forward_pass_queue(self, n_low_res: int, sinkhorn_iterations: int) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(32, 32) for i in range(n_high_res)]
queue = [torch.eye(128, 32) for i in range(n_high_res)]
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n) for i in range(n_low_res)]
loss = criterion(high_res, low_res, queue)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

@pytest.mark.parametrize("n_low_res", range(6))
@pytest.mark.parametrize("sinkhorn_iterations", range(3))
def test_forward_pass_bsz_1(self, n_low_res: int, sinkhorn_iterations: int) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(1, n) for i in range(n_high_res)]
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(1, n) for i in range(n_low_res)]
criterion(high_res, low_res)

@pytest.mark.parametrize("n_low_res", range(6))
@pytest.mark.parametrize("sinkhorn_iterations", range(3))
def test_forward_pass_1d(self, n_low_res: int, sinkhorn_iterations: int) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(n, 1) for i in range(n_high_res)]
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, 1) for i in range(n_low_res)]
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
@pytest.mark.parametrize("n_low_res", range(6))
@pytest.mark.parametrize("sinkhorn_iterations", range(3))
def test_forward_pass_cuda(self, n_low_res: int, sinkhorn_iterations: int) -> None:
n = 32
n_high_res = 2
high_res = [torch.eye(n, n).cuda() for i in range(n_high_res)]
criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)]
loss = criterion(high_res, low_res)
# loss should be almost zero for unit matrix
assert loss.cpu().numpy() < 0.5