
Commit f052cc2
Typecheck rest of loss tests (#1768)
philippmwirth authored Dec 30, 2024
1 parent 84aa00b commit f052cc2
Showing 4 changed files with 17 additions and 28 deletions.
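The commit applies one mechanical pattern to every test file: drop the unittest.TestCase base class, replace self.assert* calls with plain assert statements, and swap unittest's skip and exception-checking helpers for their pytest equivalents. A minimal before/after sketch of that pattern (the class and test names here are illustrative, not from the diff):

# Before: unittest style, as the loss tests were written previously.
import unittest

import pytest


class TestExampleUnittest(unittest.TestCase):
    def test_value(self) -> None:
        self.assertTrue(1 + 1 == 2)
        with self.assertRaises(ValueError):
            raise ValueError("bad input")


# After: plain pytest style, as applied throughout this commit.
class TestExamplePytest:
    def test_value(self) -> None:
        assert 1 + 1 == 2
        with pytest.raises(ValueError):
            raise ValueError("bad input")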
1 change: 0 additions & 1 deletion pyproject.toml
@@ -236,7 +236,6 @@ exclude = '''(?x)(
     tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
     tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
     tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
-    tests/loss/test_MemoryBank.py |
     tests/core/test_Core.py |
     tests/data/test_multi_view_collate.py |
     tests/data/test_data_collate.py |
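For context: this hunk drops tests/loss/test_MemoryBank.py from the exclude regex in pyproject.toml, which, given the commit title, is assumed to be the mypy configuration's exclude list; from this commit on, that file is type-checked along with the rest of the test suite. A quick local verification, assuming mypy is installed, is `mypy tests/loss/test_MemoryBank.py`.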
22 changes: 8 additions & 14 deletions tests/loss/test_MMCR_loss.py
@@ -1,11 +1,10 @@
-import unittest
-
+import pytest
 import torch
 
 from lightly.loss.mmcr_loss import MMCRLoss
 
 
-class testMMCRLoss(unittest.TestCase):
+class TestMMCRLoss:
     def test_forward(self) -> None:
         bs = 3
         dim = 128
@@ -15,11 +14,9 @@ def test_forward(self) -> None:
         online = torch.randn(bs, k, dim)
         momentum = torch.randn(bs, k, dim)
 
-        loss = loss_fn(online, momentum)
-
-        print(loss)
+        loss_fn(online, momentum)
 
-    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
     def test_forward_cuda(self) -> None:
         bs = 3
         dim = 128
@@ -29,9 +26,7 @@ def test_forward_cuda(self) -> None:
         online = torch.randn(bs, k, dim).cuda()
         momentum = torch.randn(bs, k, dim).cuda()
 
-        loss = loss_fn(online, momentum)
-
-        print(loss)
+        loss_fn(online, momentum)
 
     def test_loss_value(self) -> None:
         """If all values are zero, the loss should be zero."""
@@ -44,12 +39,11 @@ def test_loss_value(self) -> None:
         momentum = torch.zeros(bs, k, dim)
 
         loss = loss_fn(online, momentum)
-
-        self.assertTrue(loss == 0)
+        assert loss == 0.0
 
     def test_lambda_value_error(self) -> None:
         """If lambda is negative, a ValueError should be raised."""
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             MMCRLoss(lmda=-1)
 
     def test_shape_assertion_forward(self) -> None:
@@ -61,5 +55,5 @@ def test_shape_assertion_forward(self) -> None:
         online = torch.randn(bs, k, dim)
         momentum = torch.randn(bs, k, dim + 1)
 
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             loss_fn(online, momentum)
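For reference, a minimal sketch of the MMCRLoss call pattern these tests exercise; the view count k = 5 and the default constructor arguments are assumptions, since those lines are collapsed in the diff, but the error conditions match the tests above:

import torch
from lightly.loss.mmcr_loss import MMCRLoss

loss_fn = MMCRLoss()  # lmda must be non-negative, else __init__ raises ValueError

bs, k, dim = 3, 5, 128              # k = 5 is an assumed number of views
online = torch.randn(bs, k, dim)    # (batch, views, embedding dim)
momentum = torch.randn(bs, k, dim)  # must match online's shape, else an assertion fails
loss = loss_fn(online, momentum)    # scalar tensor; zero when both inputs are all zeros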
7 changes: 3 additions & 4 deletions tests/loss/test_emp_ssl_loss.py
@@ -1,11 +1,10 @@
-import unittest
-
+import pytest
 import torch
 
 from lightly.loss.emp_ssl_loss import EMPSSLLoss
 
 
-class testEMPSSSLLoss(unittest.TestCase):
+class testEMPSSSLLoss:
     def test_forward(self) -> None:
         bs = 512
         dim = 128
@@ -16,7 +15,7 @@ def test_forward(self) -> None:
 
         loss_fn(x)
 
-    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
     def test_forward_cuda(self) -> None:
         bs = 512
         dim = 128
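Both EMP-SSL hunks repeat the decorator swap seen in the other files. The two forms below are equivalent ways to skip a test when CUDA is missing; pytest also honors unittest's SkipTest, so the change is stylistic consistency rather than behavior (the function names here are illustrative):

import unittest

import pytest
import torch


# unittest style: run only when CUDA is available.
@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
def test_forward_cuda_old() -> None:
    assert torch.cuda.is_available()


# pytest style used after this commit: skip when CUDA is missing.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda_new() -> None:
    assert torch.cuda.is_available()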
15 changes: 6 additions & 9 deletions tests/loss/test_WMSELoss.py → tests/loss/test_wmse_loss.py
@@ -1,5 +1,3 @@
-import unittest
-
 import pytest
 import torch
 
@@ -11,7 +9,7 @@
     pytest.skip("torch.linalg.solve_triangular not available", allow_module_level=True)
 
 
-class testWMSELoss(unittest.TestCase):
+class TestWMSELoss:
     def test_forward(self) -> None:
         bs = 512
         dim = 128
@@ -22,7 +20,7 @@ def test_forward(self) -> None:
 
         loss_fn(x)
 
-    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
     def test_forward_cuda(self) -> None:
         bs = 512
         dim = 128
@@ -43,21 +41,20 @@ def test_loss_value(self) -> None:
         x = torch.randn(bs * num_samples, dim)
 
         loss = loss_fn(x)
-
-        self.assertGreater(loss, 0)
+        assert loss > 0
 
     def test_embedding_dim_error(self) -> None:
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             WMSELoss(embedding_dim=2, w_size=2)
 
     def test_num_samples_error(self) -> None:
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
            loss_fn = WMSELoss(num_samples=3)
             x = torch.randn(5, 128)
             loss_fn(x)
 
     def test_w_size_error(self) -> None:
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             loss_fn = WMSELoss(w_size=5)
             x = torch.randn(4, 128)
             loss_fn(x)
