Skip to content

Commit

Permalink
Merge branch 'philipp-type-check-pmsn-loss' into philipp-type-check-swav-loss
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Dec 27, 2024
2 parents 97d60ce + b73804d commit f400be5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/loss/test_msn_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_prototype_probabilitiy(self) -> None:
prototypes = F.normalize(torch.rand((4, 10)), dim=1)
prob = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.5)
assert prob.shape == (8, 4)
assert prob.max() == 1.0
assert prob.max() < 1.0
assert prob.min() > 0.0

# verify sharpening
Expand Down
4 changes: 2 additions & 2 deletions tests/loss/test_pmsn_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_forward(self) -> None:
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda(self) -> None:
torch.manual_seed(0)
criterion = PMSNLoss()
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_forward(self) -> None:
prototypes = torch.rand((4, 10), requires_grad=True)
criterion(anchors, targets, prototypes)

@pytest.mark.skipif(torch.cuda.is_available(), reason="No cuda")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_cuda(self) -> None:
torch.manual_seed(0)
criterion = PMSNCustomLoss(target_distribution=_uniform_distribution)
Expand Down

0 comments on commit f400be5

Please sign in to comment.