Skip to content

Commit

Permalink
Fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Nov 24, 2023
1 parent 247b173 commit 903c3f1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
10 changes: 6 additions & 4 deletions lightly/models/modules/memory_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# All Rights Reserved

import warnings
from typing import Sequence, Union
from typing import Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -72,11 +72,13 @@ def __init__(
self.size = size_tuple
self.gather_distributed = gather_distributed
self.feature_dim_first = feature_dim_first
self.bank: Tensor
self.register_buffer(
"bank",
tensor=torch.empty(size=self.size, dtype=torch.float),
persistent=False,
)
self.bank_ptr: Tensor
self.register_buffer(
"bank_ptr",
tensor=torch.empty(1, dtype=torch.long),
Expand All @@ -97,7 +99,7 @@ def __init__(
self._init_memory_bank(dim=None)

@torch.no_grad()
def _init_memory_bank(self, dim: Union[Sequence[int], None]):
def _init_memory_bank(self, dim: Union[Sequence[int], None]) -> None:
"""Initialize the memory bank if it's empty.
Args:
Expand All @@ -116,7 +118,7 @@ def _init_memory_bank(self, dim: Union[Sequence[int], None]):
self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)

@torch.no_grad()
def _dequeue_and_enqueue(self, batch: Tensor):
def _dequeue_and_enqueue(self, batch: Tensor) -> None:
"""Dequeue the oldest batch and add the latest one
Args:
Expand All @@ -141,7 +143,7 @@ def forward(
output: Tensor,
labels: Union[Tensor, None] = None,
update: bool = False,
):
) -> Tuple[Tensor, Union[Tensor, None]]:
"""Query memory bank for additional negative samples
Args:
Expand Down
8 changes: 4 additions & 4 deletions lightly/utils/benchmarking/metric_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _append_metrics(
self, metrics_dict: Dict[str, List[float]], trainer: Trainer
) -> None:
for name, value in trainer.callback_metrics.items():
if isinstance(value, float) or (
isinstance(value, Tensor) and value.numel() == 1
):
metrics_dict.setdefault(name, []).append(float(value))
if isinstance(value, Tensor) and value.numel() > 1:
# Ignore non-scalar tensors.
continue

Check warning on line 65 in lightly/utils/benchmarking/metric_callback.py

View check run for this annotation

Codecov / codecov/patch

lightly/utils/benchmarking/metric_callback.py#L65

Added line #L65 was not covered by tests
metrics_dict.setdefault(name, []).append(float(value))
8 changes: 4 additions & 4 deletions tests/models/modules/test_memory_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@


class TestNTXentLoss(unittest.TestCase):
def test_init__negative_size(self):
def test_init__negative_size(self) -> None:
with self.assertRaises(ValueError):
MemoryBankModule(size=-1)

def test_forward_easy(self):
def test_forward_easy(self) -> None:
bsz = 3
dim, size = 2, 9
n = 33 * bsz
Expand All @@ -37,7 +37,7 @@ def test_forward_easy(self):

ptr = (ptr + bsz) % size

def test_forward(self):
def test_forward(self) -> None:
bsz = 3
dim, size = 2, 10
n = 33 * bsz
Expand All @@ -50,7 +50,7 @@ def test_forward(self):
_, _ = memory_bank(output)

@unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
def test_forward__cuda(self):
def test_forward__cuda(self) -> None:
bsz = 3
dim, size = 2, 10
n = 33 * bsz
Expand Down

0 comments on commit 903c3f1

Please sign in to comment.