Skip to content

Commit

Permalink
Typecheck negative cosine similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Dec 27, 2024
1 parent f400be5 commit 38e8adf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ exclude = '''(?x)(
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/negative_cosine_similarity.py |
lightly/loss/hypersphere_loss.py |
lightly/loss/dino_loss.py |
lightly/loss/sym_neg_cos_sim_loss.py |
Expand Down Expand Up @@ -244,7 +243,6 @@ exclude = '''(?x)(
tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
tests/loss/test_NegativeCosineSimilarity.py |
tests/loss/test_DINOLoss.py |
tests/loss/test_VICRegLLoss.py |
tests/loss/test_CO2Regularizer.py |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import unittest

import pytest
import torch

from lightly.loss import NegativeCosineSimilarity


class TestNegativeCosineSimilarity(unittest.TestCase):
def test_forward_pass(self):
class TestNegativeCosineSimilarity:
def test_forward_pass(self) -> None:
loss = NegativeCosineSimilarity()
for bsz in range(1, 20):
x0 = torch.randn((bsz, 32))
Expand All @@ -15,10 +14,10 @@ def test_forward_pass(self):
# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0)
assert l1 == pytest.approx(l2, abs=1e-5)

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available")
def test_forward_pass_cuda(self):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_pass_cuda(self) -> None:
loss = NegativeCosineSimilarity()
for bsz in range(1, 20):
x0 = torch.randn((bsz, 32)).cuda()
Expand All @@ -27,4 +26,4 @@ def test_forward_pass_cuda(self):
# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0)
assert l1 == pytest.approx(l2, abs=1e-5)

0 comments on commit 38e8adf

Please sign in to comment.