From 903c3f1e7f70c8dc9ac9c2e47191a1ef7d5cde50 Mon Sep 17 00:00:00 2001
From: guarin
Date: Fri, 24 Nov 2023 10:36:44 +0000
Subject: [PATCH] Fix types

---
 lightly/models/modules/memory_bank.py         | 10 ++++++----
 lightly/utils/benchmarking/metric_callback.py |  8 ++++----
 tests/models/modules/test_memory_bank.py      |  8 ++++----
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/lightly/models/modules/memory_bank.py b/lightly/models/modules/memory_bank.py
index 9a1790564..c55516f36 100644
--- a/lightly/models/modules/memory_bank.py
+++ b/lightly/models/modules/memory_bank.py
@@ -4,7 +4,7 @@
 # All Rights Reserved
 
 import warnings
-from typing import Sequence, Union
+from typing import Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -72,11 +72,13 @@ def __init__(
         self.size = size_tuple
         self.gather_distributed = gather_distributed
         self.feature_dim_first = feature_dim_first
+        self.bank: Tensor
         self.register_buffer(
             "bank",
             tensor=torch.empty(size=self.size, dtype=torch.float),
             persistent=False,
         )
+        self.bank_ptr: Tensor
         self.register_buffer(
             "bank_ptr",
             tensor=torch.empty(1, dtype=torch.long),
@@ -97,7 +99,7 @@ def __init__(
         self._init_memory_bank(dim=None)
 
     @torch.no_grad()
-    def _init_memory_bank(self, dim: Union[Sequence[int], None]):
+    def _init_memory_bank(self, dim: Union[Sequence[int], None]) -> None:
         """Initialize the memory bank if it's empty.
 
         Args:
@@ -116,7 +118,7 @@ def _init_memory_bank(self, dim: Union[Sequence[int], None]):
         self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
 
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, batch: Tensor):
+    def _dequeue_and_enqueue(self, batch: Tensor) -> None:
         """Dequeue the oldest batch and add the latest one
 
         Args:
@@ -141,7 +143,7 @@ def forward(
         output: Tensor,
         labels: Union[Tensor, None] = None,
         update: bool = False,
-    ):
+    ) -> Tuple[Tensor, Union[Tensor, None]]:
         """Query memory bank for additional negative samples
 
         Args:
diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py
index b635b2fbb..8dd38f2e8 100644
--- a/lightly/utils/benchmarking/metric_callback.py
+++ b/lightly/utils/benchmarking/metric_callback.py
@@ -60,7 +60,7 @@ def _append_metrics(
         self, metrics_dict: Dict[str, List[float]], trainer: Trainer
     ) -> None:
         for name, value in trainer.callback_metrics.items():
-            if isinstance(value, float) or (
-                isinstance(value, Tensor) and value.numel() == 1
-            ):
-                metrics_dict.setdefault(name, []).append(float(value))
+            if isinstance(value, Tensor) and value.numel() > 1:
+                # Ignore non-scalar tensors.
+                continue
+            metrics_dict.setdefault(name, []).append(float(value))
diff --git a/tests/models/modules/test_memory_bank.py b/tests/models/modules/test_memory_bank.py
index 695c83e1b..0f6ccae53 100644
--- a/tests/models/modules/test_memory_bank.py
+++ b/tests/models/modules/test_memory_bank.py
@@ -8,11 +8,11 @@
 
 
 class TestNTXentLoss(unittest.TestCase):
-    def test_init__negative_size(self):
+    def test_init__negative_size(self) -> None:
         with self.assertRaises(ValueError):
             MemoryBankModule(size=-1)
 
-    def test_forward_easy(self):
+    def test_forward_easy(self) -> None:
         bsz = 3
         dim, size = 2, 9
         n = 33 * bsz
@@ -37,7 +37,7 @@ def test_forward_easy(self):
 
             ptr = (ptr + bsz) % size
 
-    def test_forward(self):
+    def test_forward(self) -> None:
         bsz = 3
         dim, size = 2, 10
         n = 33 * bsz
@@ -50,7 +50,7 @@ def test_forward_easy(self):
             _, _ = memory_bank(output)
 
     @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
-    def test_forward__cuda(self):
+    def test_forward__cuda(self) -> None:
         bsz = 3
         dim, size = 2, 10
         n = 33 * bsz
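
Side note (not part of the patch): a minimal usage sketch of the annotations added above. It assumes MemoryBankModule is importable from lightly.models.modules.memory_bank (per the file path in the diff) and that the constructor accepts a (num_entries, dim) size tuple as suggested by the size_tuple handling; the concrete sizes below are hypothetical.

    import torch

    from lightly.models.modules.memory_bank import MemoryBankModule

    # Hypothetical bank of 4096 entries with feature dimension 128.
    memory_bank = MemoryBankModule(size=(4096, 128))

    batch = torch.randn(32, 128)

    # forward is now annotated as Tuple[Tensor, Union[Tensor, None]]: the first
    # element is the query batch and the second is a Tensor or None per the new
    # annotation, so the result can be unpacked directly.
    # update=True presumably also enqueues the batch into the bank.
    out, negatives = memory_bank(batch, update=True)

The tests above unpack the same tuple with "_, _ = memory_bank(output)".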