From 27e9ae5418f70e627fb95519e7ec4a24ff0ec73a Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 14:30:23 +0100 Subject: [PATCH 1/7] Typecheck (P)MSN loss --- lightly/loss/pmsn_loss.py | 2 +- pyproject.toml | 4 -- .../{test_MSNLoss.py => test_msn_loss.py} | 55 ++++++++----------- .../{test_PMSNLoss.py => test_pmsn_loss.py} | 5 +- 4 files changed, 27 insertions(+), 39 deletions(-) rename tests/loss/{test_MSNLoss.py => test_msn_loss.py} (74%) rename tests/loss/{test_PMSNLoss.py => test_pmsn_loss.py} (95%) diff --git a/lightly/loss/pmsn_loss.py b/lightly/loss/pmsn_loss.py index e41a0a349..473eba7bd 100644 --- a/lightly/loss/pmsn_loss.py +++ b/lightly/loss/pmsn_loss.py @@ -177,6 +177,6 @@ def _power_law_distribution(size: int, exponent: float, device: torch.device) -> A power law distribution tensor summing up to 1. """ k = torch.arange(1, size + 1, device=device) - power_dist = k ** (-exponent) + power_dist = torch.tensor(k ** (-exponent)) power_dist = power_dist / power_dist.sum() return power_dist diff --git a/pyproject.toml b/pyproject.toml index d24032e98..fb3610404 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,11 +192,9 @@ exclude = '''(?x)( lightly/cli/train_cli.py | lightly/cli/_cli_simclr.py | lightly/cli/_helpers.py | - lightly/loss/pmsn_loss.py | lightly/loss/swav_loss.py | lightly/loss/negative_cosine_similarity.py | lightly/loss/hypersphere_loss.py | - lightly/loss/msn_loss.py | lightly/loss/dino_loss.py | lightly/loss/sym_neg_cos_sim_loss.py | lightly/loss/vicregl_loss.py | @@ -248,7 +246,6 @@ exclude = '''(?x)( tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py | tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py | tests/loss/test_NegativeCosineSimilarity.py | - tests/loss/test_MSNLoss.py | tests/loss/test_DINOLoss.py | tests/loss/test_VICRegLLoss.py | tests/loss/test_CO2Regularizer.py | @@ -256,7 +253,6 @@ exclude = '''(?x)( tests/loss/test_barlow_twins_loss.py | tests/loss/test_SymNegCosineSimilarityLoss.py | tests/loss/test_MemoryBank.py | - tests/loss/test_PMSNLoss.py | tests/loss/test_HyperSphere.py | tests/loss/test_SwaVLoss.py | tests/core/test_Core.py | diff --git a/tests/loss/test_MSNLoss.py b/tests/loss/test_msn_loss.py similarity index 74% rename from tests/loss/test_MSNLoss.py rename to tests/loss/test_msn_loss.py index 2cbb68357..a4b1101fb 100644 --- a/tests/loss/test_MSNLoss.py +++ b/tests/loss/test_msn_loss.py @@ -1,6 +1,3 @@ -import unittest -from unittest import TestCase - import pytest import torch import torch.nn.functional as F @@ -30,19 +27,16 @@ def test__gather_distributed_dist_not_available( MSNLoss(gather_distributed=True) mock_is_available.assert_called_once() - -class TestMSNLossUnitTest(TestCase): - # Old tests in unittest style, please add new tests to TestMSNLoss using pytest. def test__init__temperature(self) -> None: MSNLoss(temperature=1.0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): MSNLoss(temperature=0.0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): MSNLoss(temperature=-1.0) def test__init__sinkhorn_iterations(self) -> None: MSNLoss(sinkhorn_iterations=0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): MSNLoss(sinkhorn_iterations=-1) def test__init__me_max_weight(self) -> None: @@ -54,16 +48,16 @@ def test_prototype_probabilitiy(self) -> None: queries = F.normalize(torch.rand((8, 10)), dim=1) prototypes = F.normalize(torch.rand((4, 10)), dim=1) prob = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.5) - self.assertEqual(prob.shape, (8, 4)) - self.assertLessEqual(prob.max(), 1.0) - self.assertGreater(prob.min(), 0.0) + assert prob.shape == (8, 4) + assert prob.max() == 1.0 + assert prob.min() > 0.0 # verify sharpening prob1 = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.1) # same prototypes should be assigned regardless of temperature - self.assertTrue(torch.all(prob.argmax(dim=1) == prob1.argmax(dim=1))) + assert torch.all(prob.argmax(dim=1) == prob1.argmax(dim=1)) # probabilities of selected prototypes should be higher for lower temperature - self.assertTrue(torch.all(prob.max(dim=1)[0] < prob1.max(dim=1)[0])) + assert torch.all(prob.max(dim=1)[0] < prob1.max(dim=1)[0]) def test_sharpen(self) -> None: torch.manual_seed(0) @@ -71,33 +65,32 @@ def test_sharpen(self) -> None: p0 = msn_loss.sharpen(prob, temperature=0.5) p1 = msn_loss.sharpen(prob, temperature=0.1) # indices of max probabilities should be the same regardless of temperature - self.assertTrue(torch.all(p0.argmax(dim=1) == p1.argmax(dim=1))) + assert torch.all(p0.argmax(dim=1) == p1.argmax(dim=1)) # max probabilities should be higher for lower temperature - self.assertTrue(torch.all(p0.max(dim=1)[0] < p1.max(dim=1)[0])) + assert torch.all(p0.max(dim=1)[0] < p1.max(dim=1)[0]) def test_sinkhorn(self) -> None: torch.manual_seed(0) prob = torch.rand((8, 10)) out = msn_loss.sinkhorn(prob) - self.assertTrue(torch.all(prob != out)) + assert torch.all(prob != out) def test_sinkhorn_no_iter(self) -> None: torch.manual_seed(0) prob = torch.rand((8, 10)) out = msn_loss.sinkhorn(prob, iterations=0) - self.assertTrue(torch.all(prob == out)) + assert torch.all(prob == out) - def test_forward(self) -> None: + @pytest.mark.parametrize("num_target_views", range(1, 4)) + def test_forward(self, num_target_views: int) -> None: torch.manual_seed(0) - for num_target_views in range(1, 4): - with self.subTest(num_views=num_target_views): - criterion = MSNLoss() - anchors = torch.rand((8 * num_target_views, 10)) - targets = torch.rand((8, 10)) - prototypes = torch.rand((4, 10), requires_grad=True) - criterion(anchors, targets, prototypes) - - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + criterion = MSNLoss() + anchors = torch.rand((8 * num_target_views, 10)) + targets = torch.rand((8, 10)) + prototypes = torch.rand((4, 10), requires_grad=True) + criterion(anchors, targets, prototypes) + + @pytest.mark.skipif(not torch.cuda.is_available(), "No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = MSNLoss() @@ -124,9 +117,9 @@ def test_backward(self) -> None: optimizer.step() weights_after = head.layers[0].weight.data # backward pass should update weights - self.assertTrue(torch.any(weights_before != weights_after)) + assert torch.any(weights_before != weights_after) - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + @pytest.mark.skipif(not torch.cuda.is_available(), "No cuda") def test_backward_cuda(self) -> None: torch.manual_seed(0) head = MSNProjectionHead(5, 16, 6) @@ -146,4 +139,4 @@ def test_backward_cuda(self) -> None: optimizer.step() weights_after = head.layers[0].weight.data # backward pass should update weights - self.assertTrue(torch.any(weights_before != weights_after)) + assert torch.any(weights_before != weights_after) diff --git a/tests/loss/test_PMSNLoss.py b/tests/loss/test_pmsn_loss.py similarity index 95% rename from tests/loss/test_PMSNLoss.py rename to tests/loss/test_pmsn_loss.py index 1ddc8b84e..c1112afc9 100644 --- a/tests/loss/test_PMSNLoss.py +++ b/tests/loss/test_pmsn_loss.py @@ -1,5 +1,4 @@ import math -import unittest import pytest import torch @@ -32,7 +31,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + @pytest.mark.skipif(torch.cuda.is_available(), "No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNLoss() @@ -66,7 +65,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + @pytest.mark.skipif(torch.cuda.is_available(), "No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNCustomLoss(target_distribution=_uniform_distribution) From 26b59e126081f17c3db11e70afa891c64f3eb7da Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 14:38:43 +0100 Subject: [PATCH 2/7] Type check SwAV loss --- lightly/loss/swav_loss.py | 21 +++--- pyproject.toml | 2 - .../{test_SwaVLoss.py => test_swav_loss.py} | 64 ++++++------------- 3 files changed, 30 insertions(+), 57 deletions(-) rename tests/loss/{test_SwaVLoss.py => test_swav_loss.py} (56%) diff --git a/lightly/loss/swav_loss.py b/lightly/loss/swav_loss.py index bf4294812..084324419 100644 --- a/lightly/loss/swav_loss.py +++ b/lightly/loss/swav_loss.py @@ -1,18 +1,19 @@ -from typing import List +from typing import List, Union import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from torch import Tensor @torch.no_grad() def sinkhorn( - out: torch.Tensor, + out: Tensor, iterations: int = 3, epsilon: float = 0.05, gather_distributed: bool = False, -) -> torch.Tensor: +) -> Tensor: """Distributed sinkhorn algorithm. As outlined in [0] and implemented in [1]. @@ -113,7 +114,7 @@ def __init__( self.sinkhorn_epsilon = sinkhorn_epsilon self.sinkhorn_gather_distributed = sinkhorn_gather_distributed - def subloss(self, z: torch.Tensor, q: torch.Tensor): + def subloss(self, z: Tensor, q: Tensor) -> Tensor: """Calculates the cross entropy for the SwaV prediction problem. Args: @@ -131,10 +132,10 @@ def subloss(self, z: torch.Tensor, q: torch.Tensor): def forward( self, - high_resolution_outputs: List[torch.Tensor], - low_resolution_outputs: List[torch.Tensor], - queue_outputs: List[torch.Tensor] = None, - ): + high_resolution_outputs: List[Tensor], + low_resolution_outputs: List[Tensor], + queue_outputs: Union[List[Tensor], None] = None, + ) -> Tensor: """Computes the SwaV loss for a set of high and low resolution outputs. - [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 @@ -156,7 +157,7 @@ def forward( n_crops = len(high_resolution_outputs) + len(low_resolution_outputs) # Multi-crop iterations - loss = 0.0 + loss = torch.tensor(0.0) for i in range(len(high_resolution_outputs)): # Compute codes of i-th high resolution crop with torch.no_grad(): @@ -179,7 +180,7 @@ def forward( q = q[: len(high_resolution_outputs[i])] # Compute subloss for each pair of crops - subloss = 0.0 + subloss = torch.tensor(0.0) for v in range(len(high_resolution_outputs)): if v != i: subloss += self.subloss(high_resolution_outputs[v], q) diff --git a/pyproject.toml b/pyproject.toml index fb3610404..f1bc629d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,6 @@ exclude = '''(?x)( lightly/cli/train_cli.py | lightly/cli/_cli_simclr.py | lightly/cli/_helpers.py | - lightly/loss/swav_loss.py | lightly/loss/negative_cosine_similarity.py | lightly/loss/hypersphere_loss.py | lightly/loss/dino_loss.py | @@ -254,7 +253,6 @@ exclude = '''(?x)( tests/loss/test_SymNegCosineSimilarityLoss.py | tests/loss/test_MemoryBank.py | tests/loss/test_HyperSphere.py | - tests/loss/test_SwaVLoss.py | tests/core/test_Core.py | tests/data/test_multi_view_collate.py | tests/data/test_data_collate.py | diff --git a/tests/loss/test_SwaVLoss.py b/tests/loss/test_swav_loss.py similarity index 56% rename from tests/loss/test_SwaVLoss.py rename to tests/loss/test_swav_loss.py index c125dec74..dd3529734 100644 --- a/tests/loss/test_SwaVLoss.py +++ b/tests/loss/test_swav_loss.py @@ -1,5 +1,3 @@ -import unittest - import pytest import torch from pytest_mock import MockerFixture @@ -24,10 +22,7 @@ def test__sinkhorn_gather_distributed_dist_not_available( SwaVLoss(sinkhorn_gather_distributed=True) mock_is_available.assert_called_once() - -class TestSwaVLossUnitTest(unittest.TestCase): - # Old tests in unittest style, please add new tests to TestSwavLoss using pytest. - def test_forward_pass(self): + def test_forward_pass(self) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] @@ -36,34 +31,25 @@ def test_forward_pass(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n) for i in range(n_low_res)] + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 - with self.subTest( - msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" - ): - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - self.assertGreater(0.5, loss.cpu().numpy()) - - def test_forward_pass_queue(self): + def test_forward_pass_queue(self) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] - queue_length = 128 queue = [torch.eye(128, 32) for i in range(n_high_res)] for n_low_res in range(6): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n) for i in range(n_low_res)] + loss = criterion(high_res, low_res, queue) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 - with self.subTest( - msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" - ): - loss = criterion(high_res, low_res, queue) - # loss should be almost zero for unit matrix - self.assertGreater(0.5, loss.cpu().numpy()) - - def test_forward_pass_bsz_1(self): + def test_forward_pass_bsz_1(self) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(1, n) for i in range(n_high_res)] @@ -72,13 +58,9 @@ def test_forward_pass_bsz_1(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(1, n) for i in range(n_low_res)] + criterion(high_res, low_res) - with self.subTest( - msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" - ): - loss = criterion(high_res, low_res) - - def test_forward_pass_1d(self): + def test_forward_pass_1d(self) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(n, 1) for i in range(n_high_res)] @@ -87,16 +69,12 @@ def test_forward_pass_1d(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, 1) for i in range(n_low_res)] + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 - with self.subTest( - msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" - ): - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - self.assertGreater(0.5, loss.cpu().numpy()) - - @unittest.skipUnless(torch.cuda.is_available(), "skip") - def test_forward_pass_cuda(self): + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") + def test_forward_pass_cuda(self) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(n, n).cuda() for i in range(n_high_res)] @@ -105,10 +83,6 @@ def test_forward_pass_cuda(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)] - - with self.subTest( - msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" - ): - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - self.assertGreater(0.5, loss.cpu().numpy()) + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 From 6d63cd064cd31a7af67eefa45552c25b60fd0e8e Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 14:39:55 +0100 Subject: [PATCH 3/7] Fix skips --- tests/loss/test_msn_loss.py | 4 ++-- tests/loss/test_pmsn_loss.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/loss/test_msn_loss.py b/tests/loss/test_msn_loss.py index a4b1101fb..968fdd7da 100644 --- a/tests/loss/test_msn_loss.py +++ b/tests/loss/test_msn_loss.py @@ -90,7 +90,7 @@ def test_forward(self, num_target_views: int) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @pytest.mark.skipif(not torch.cuda.is_available(), "No cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = MSNLoss() @@ -119,7 +119,7 @@ def test_backward(self) -> None: # backward pass should update weights assert torch.any(weights_before != weights_after) - @pytest.mark.skipif(not torch.cuda.is_available(), "No cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") def test_backward_cuda(self) -> None: torch.manual_seed(0) head = MSNProjectionHead(5, 16, 6) diff --git a/tests/loss/test_pmsn_loss.py b/tests/loss/test_pmsn_loss.py index c1112afc9..0278fc580 100644 --- a/tests/loss/test_pmsn_loss.py +++ b/tests/loss/test_pmsn_loss.py @@ -31,7 +31,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @pytest.mark.skipif(torch.cuda.is_available(), "No cuda") + @pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNLoss() @@ -65,7 +65,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @pytest.mark.skipif(torch.cuda.is_available(), "No cuda") + @pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNCustomLoss(target_distribution=_uniform_distribution) From b73804d36a034831cb80f159a045963ef674ac31 Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 14:59:13 +0100 Subject: [PATCH 4/7] Fix typos --- tests/loss/test_msn_loss.py | 2 +- tests/loss/test_pmsn_loss.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/loss/test_msn_loss.py b/tests/loss/test_msn_loss.py index 968fdd7da..8ee34dc60 100644 --- a/tests/loss/test_msn_loss.py +++ b/tests/loss/test_msn_loss.py @@ -49,7 +49,7 @@ def test_prototype_probabilitiy(self) -> None: prototypes = F.normalize(torch.rand((4, 10)), dim=1) prob = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.5) assert prob.shape == (8, 4) - assert prob.max() == 1.0 + assert prob.max() < 1.0 assert prob.min() > 0.0 # verify sharpening diff --git a/tests/loss/test_pmsn_loss.py b/tests/loss/test_pmsn_loss.py index 0278fc580..0c8d6f60e 100644 --- a/tests/loss/test_pmsn_loss.py +++ b/tests/loss/test_pmsn_loss.py @@ -31,7 +31,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNLoss() @@ -65,7 +65,7 @@ def test_forward(self) -> None: prototypes = torch.rand((4, 10), requires_grad=True) criterion(anchors, targets, prototypes) - @pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") def test_forward_cuda(self) -> None: torch.manual_seed(0) criterion = PMSNCustomLoss(target_distribution=_uniform_distribution) From 85bbaa7eb4a0ef8cf48745006d044021aca88cfa Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 15:06:41 +0100 Subject: [PATCH 5/7] Typecheck negative cosine similarity --- pyproject.toml | 2 -- ...rity.py => test_negative_cosine_similarity.py} | 15 +++++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) rename tests/loss/{test_NegativeCosineSimilarity.py => test_negative_cosine_similarity.py} (61%) diff --git a/pyproject.toml b/pyproject.toml index fb3610404..7bd072e82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,7 +193,6 @@ exclude = '''(?x)( lightly/cli/_cli_simclr.py | lightly/cli/_helpers.py | lightly/loss/swav_loss.py | - lightly/loss/negative_cosine_similarity.py | lightly/loss/hypersphere_loss.py | lightly/loss/dino_loss.py | lightly/loss/sym_neg_cos_sim_loss.py | @@ -245,7 +244,6 @@ exclude = '''(?x)( tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py | tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py | tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py | - tests/loss/test_NegativeCosineSimilarity.py | tests/loss/test_DINOLoss.py | tests/loss/test_VICRegLLoss.py | tests/loss/test_CO2Regularizer.py | diff --git a/tests/loss/test_NegativeCosineSimilarity.py b/tests/loss/test_negative_cosine_similarity.py similarity index 61% rename from tests/loss/test_NegativeCosineSimilarity.py rename to tests/loss/test_negative_cosine_similarity.py index 44d29522a..2f8c58a23 100644 --- a/tests/loss/test_NegativeCosineSimilarity.py +++ b/tests/loss/test_negative_cosine_similarity.py @@ -1,12 +1,11 @@ -import unittest - +import pytest import torch from lightly.loss import NegativeCosineSimilarity -class TestNegativeCosineSimilarity(unittest.TestCase): - def test_forward_pass(self): +class TestNegativeCosineSimilarity: + def test_forward_pass(self) -> None: loss = NegativeCosineSimilarity() for bsz in range(1, 20): x0 = torch.randn((bsz, 32)) @@ -15,10 +14,10 @@ def test_forward_pass(self): # symmetry l1 = loss(x0, x1) l2 = loss(x1, x0) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) + assert l1 == pytest.approx(l2, abs=1e-5) - @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available") - def test_forward_pass_cuda(self): + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") + def test_forward_pass_cuda(self) -> None: loss = NegativeCosineSimilarity() for bsz in range(1, 20): x0 = torch.randn((bsz, 32)).cuda() @@ -27,4 +26,4 @@ def test_forward_pass_cuda(self): # symmetry l1 = loss(x0, x1) l2 = loss(x1, x0) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) + assert l1 == pytest.approx(l2, abs=1e-5) From a7eb3ab16f031b7420c315abab14de228e51f084 Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 16:04:07 +0100 Subject: [PATCH 6/7] Re Malte --- tests/loss/test_swav_loss.py | 81 +++++++++++++++++------------------- 1 file changed, 38 insertions(+), 43 deletions(-) diff --git a/tests/loss/test_swav_loss.py b/tests/loss/test_swav_loss.py index dd3529734..1d32b4aed 100644 --- a/tests/loss/test_swav_loss.py +++ b/tests/loss/test_swav_loss.py @@ -22,67 +22,62 @@ def test__sinkhorn_gather_distributed_dist_not_available( SwaVLoss(sinkhorn_gather_distributed=True) mock_is_available.assert_called_once() - def test_forward_pass(self) -> None: + @pytest.mark.parametrize("n_low_res", range(6)) + @pytest.mark.parametrize("sinkhorn_iterations", range(3)) + def test_forward_pass(self, n_low_res: int, sinkhorn_iterations: int) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] + criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) + low_res = [torch.eye(n, n) for i in range(n_low_res)] + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 - for n_low_res in range(6): - for sinkhorn_iterations in range(3): - criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) - low_res = [torch.eye(n, n) for i in range(n_low_res)] - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - assert loss.cpu().numpy() < 0.5 - - def test_forward_pass_queue(self) -> None: + @pytest.mark.parametrize("n_low_res", range(6)) + @pytest.mark.parametrize("sinkhorn_iterations", range(3)) + def test_forward_pass_queue(self, n_low_res: int, sinkhorn_iterations: int) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] queue = [torch.eye(128, 32) for i in range(n_high_res)] + criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) + low_res = [torch.eye(n, n) for i in range(n_low_res)] + loss = criterion(high_res, low_res, queue) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 - for n_low_res in range(6): - for sinkhorn_iterations in range(3): - criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) - low_res = [torch.eye(n, n) for i in range(n_low_res)] - loss = criterion(high_res, low_res, queue) - # loss should be almost zero for unit matrix - assert loss.cpu().numpy() < 0.5 - - def test_forward_pass_bsz_1(self) -> None: + @pytest.mark.parametrize("n_low_res", range(6)) + @pytest.mark.parametrize("sinkhorn_iterations", range(3)) + def test_forward_pass_bsz_1(self, n_low_res: int, sinkhorn_iterations: int) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(1, n) for i in range(n_high_res)] + criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) + low_res = [torch.eye(1, n) for i in range(n_low_res)] + criterion(high_res, low_res) - for n_low_res in range(6): - for sinkhorn_iterations in range(3): - criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) - low_res = [torch.eye(1, n) for i in range(n_low_res)] - criterion(high_res, low_res) - - def test_forward_pass_1d(self) -> None: + @pytest.mark.parametrize("n_low_res", range(6)) + @pytest.mark.parametrize("sinkhorn_iterations", range(3)) + def test_forward_pass_1d(self, n_low_res: int, sinkhorn_iterations: int) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(n, 1) for i in range(n_high_res)] - - for n_low_res in range(6): - for sinkhorn_iterations in range(3): - criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) - low_res = [torch.eye(n, 1) for i in range(n_low_res)] - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - assert loss.cpu().numpy() < 0.5 + criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) + low_res = [torch.eye(n, 1) for i in range(n_low_res)] + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") - def test_forward_pass_cuda(self) -> None: + @pytest.mark.parametrize("n_low_res", range(6)) + @pytest.mark.parametrize("sinkhorn_iterations", range(3)) + def test_forward_pass_cuda(self, n_low_res: int, sinkhorn_iterations: int) -> None: n = 32 n_high_res = 2 high_res = [torch.eye(n, n).cuda() for i in range(n_high_res)] - - for n_low_res in range(6): - for sinkhorn_iterations in range(3): - criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) - low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)] - loss = criterion(high_res, low_res) - # loss should be almost zero for unit matrix - assert loss.cpu().numpy() < 0.5 + criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) + low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)] + loss = criterion(high_res, low_res) + # loss should be almost zero for unit matrix + assert loss.cpu().numpy() < 0.5 From 1d2c7e604527975d50ae61faedbd69e203e9df5a Mon Sep 17 00:00:00 2001 From: Philipp Wirth Date: Fri, 27 Dec 2024 16:10:18 +0100 Subject: [PATCH 7/7] Re Malte --- tests/loss/test_negative_cosine_similarity.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/loss/test_negative_cosine_similarity.py b/tests/loss/test_negative_cosine_similarity.py index 2f8c58a23..fa51905b4 100644 --- a/tests/loss/test_negative_cosine_similarity.py +++ b/tests/loss/test_negative_cosine_similarity.py @@ -5,25 +5,25 @@ class TestNegativeCosineSimilarity: - def test_forward_pass(self) -> None: + @pytest.mark.parametrize("bsz", range(1, 20)) + def test_forward_pass(self, bsz: int) -> None: loss = NegativeCosineSimilarity() - for bsz in range(1, 20): - x0 = torch.randn((bsz, 32)) - x1 = torch.randn((bsz, 32)) + x0 = torch.randn((bsz, 32)) + x1 = torch.randn((bsz, 32)) - # symmetry - l1 = loss(x0, x1) - l2 = loss(x1, x0) - assert l1 == pytest.approx(l2, abs=1e-5) + # symmetry + l1 = loss(x0, x1) + l2 = loss(x1, x0) + assert l1 == pytest.approx(l2, abs=1e-5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") - def test_forward_pass_cuda(self) -> None: + @pytest.mark.parametrize("bsz", range(1, 20)) + def test_forward_pass_cuda(self, bsz: int) -> None: loss = NegativeCosineSimilarity() - for bsz in range(1, 20): - x0 = torch.randn((bsz, 32)).cuda() - x1 = torch.randn((bsz, 32)).cuda() + x0 = torch.randn((bsz, 32)).cuda() + x1 = torch.randn((bsz, 32)).cuda() - # symmetry - l1 = loss(x0, x1) - l2 = loss(x1, x0) - assert l1 == pytest.approx(l2, abs=1e-5) + # symmetry + l1 = loss(x0, x1) + l2 = loss(x1, x0) + assert l1 == pytest.approx(l2, abs=1e-5)