From ed8ea4ef777087e25935702dc0c6bd2cd5dcef8c Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Tue, 21 Nov 2023 11:02:15 +0000 Subject: [PATCH] Re Guarin --- lightly/loss/memory_bank.py | 19 ++++++++++--------- lightly/models/_momentum.py | 9 ++++----- lightly/models/batchnorm.py | 15 +++++++++------ lightly/models/modules/heads.py | 23 ++++++++++++----------- lightly/models/modules/nn_memory_bank.py | 8 ++++---- lightly/models/resnet.py | 12 ++++++------ 6 files changed, 45 insertions(+), 41 deletions(-) diff --git a/lightly/loss/memory_bank.py b/lightly/loss/memory_bank.py index ccdb857c9..998917715 100644 --- a/lightly/loss/memory_bank.py +++ b/lightly/loss/memory_bank.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple, Union import torch +from torch import Tensor class MemoryBankModule(torch.nn.Module): @@ -26,8 +27,8 @@ class MemoryBankModule(torch.nn.Module): >>> def __init__(self, memory_bank_size: int = 2 ** 16): >>> super(MyLossFunction, self).__init__(memory_bank_size) >>> - >>> def forward(self, output: torch.Tensor, - >>> labels: torch.Tensor = None): + >>> def forward(self, output: Tensor, + >>> labels: Tensor = None): >>> >>> output, negatives = super( >>> MyLossFunction, self).forward(output) @@ -67,12 +68,12 @@ def _init_memory_bank(self, dim: int) -> None: # we could use register buffers like in the moco repo # https://github.com/facebookresearch/moco but we don't # want to pollute our checkpoints - bank: torch.Tensor = torch.randn(dim, self.size).type_as(self.bank) - self.bank: torch.Tensor = torch.nn.functional.normalize(bank, dim=0) - self.bank_ptr: torch.Tensor = torch.zeros(1).type_as(self.bank_ptr) + bank: Tensor = torch.randn(dim, self.size).type_as(self.bank) + self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0) + self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr) @torch.no_grad() - def _dequeue_and_enqueue(self, batch: torch.Tensor) -> None: + def _dequeue_and_enqueue(self, batch: Tensor) -> None: """Dequeue the oldest batch and add the latest one Args: @@ -92,10 +93,10 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor) -> None: def forward( self, - output: torch.Tensor, - labels: Optional[torch.Tensor] = None, + output: Tensor, + labels: Optional[Tensor] = None, update: bool = False, - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]: """Query memory bank for additional negative samples Args: diff --git a/lightly/models/_momentum.py b/lightly/models/_momentum.py index c7beed30f..94192cc24 100644 --- a/lightly/models/_momentum.py +++ b/lightly/models/_momentum.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +from torch import Tensor from torch.nn.parameter import Parameter @@ -46,7 +47,7 @@ class _MomentumEncoderMixin: >>> # initialize momentum_backbone and momentum_projection_head >>> self._init_momentum_encoder() >>> - >>> def forward(self, x: torch.Tensor): + >>> def forward(self, x: Tensor): >>> >>> # do the momentum update >>> self._momentum_update(0.999) @@ -89,16 +90,14 @@ def _momentum_update(self, m: float = 0.999) -> None: ) @torch.no_grad() - def _batch_shuffle(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _batch_shuffle(self, batch: Tensor) -> Tuple[Tensor, Tensor]: """Returns the shuffled batch and the indices to undo.""" batch_size = batch.shape[0] shuffle = torch.randperm(batch_size, device=batch.device) return batch[shuffle], shuffle @torch.no_grad() - def _batch_unshuffle( - self, batch: torch.Tensor, shuffle: 
torch.Tensor - ) -> torch.Tensor: + def _batch_unshuffle(self, batch: Tensor, shuffle: Tensor) -> Tensor: """Returns the unshuffled batch.""" unshuffle = torch.argsort(shuffle) return batch[unshuffle] diff --git a/lightly/models/batchnorm.py b/lightly/models/batchnorm.py index e8f7c22a6..7f05fe48e 100644 --- a/lightly/models/batchnorm.py +++ b/lightly/models/batchnorm.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from torch import Tensor class SplitBatchNorm(nn.BatchNorm2d): @@ -25,12 +26,10 @@ class SplitBatchNorm(nn.BatchNorm2d): """ - running_mean: torch.Tensor - running_var: torch.Tensor - def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None: super().__init__(num_features, **kw) self.num_splits = num_splits + # Register buffers self.register_buffer( "running_mean", torch.zeros(num_features * self.num_splits) ) @@ -39,16 +38,18 @@ def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None: def train(self, mode: bool = True) -> SplitBatchNorm: # lazily collate stats when we are going to use them if (self.training is True) and (mode is False): + assert self.running_mean is not None self.running_mean = torch.mean( self.running_mean.view(self.num_splits, self.num_features), dim=0 ).repeat(self.num_splits) + assert self.running_var is not None self.running_var = torch.mean( self.running_var.view(self.num_splits, self.num_features), dim=0 ).repeat(self.num_splits) return super().train(mode) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: Tensor) -> Tensor: """Computes the SplitBatchNorm on the input.""" # get input shape N, C, H, W = input.shape @@ -67,10 +68,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.eps, ).view(N, C, H, W) else: + # We have to ignore the type errors here, because we know that running_mean + # and running_var are not None, but the type checker does not. result = nn.functional.batch_norm( input, - self.running_mean[: self.num_features], - self.running_var[: self.num_features], + self.running_mean[: self.num_features], # type: ignore[index] + self.running_var[: self.num_features], # type: ignore[index] self.weight, self.bias, False, diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py index 5325b7464..541b49a29 100644 --- a/lightly/models/modules/heads.py +++ b/lightly/models/modules/heads.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from torch import Tensor from lightly.models import utils @@ -46,7 +47,7 @@ def __init__( layers.append(non_linearity) self.layers = nn.Sequential(*layers) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Computes one forward pass through the projection head. Args: @@ -54,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Input of shape bsz x num_ftrs. """ - projection: torch.Tensor = self.layers(x) + projection: Tensor = self.layers(x) return projection @@ -325,7 +326,7 @@ class SMoGPrototypes(nn.Module): def __init__( self, - group_features: torch.Tensor, + group_features: Tensor, beta: float, ): super(SMoGPrototypes, self).__init__() @@ -333,8 +334,8 @@ def __init__( self.beta = beta def forward( - self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1 - ) -> torch.Tensor: + self, x: Tensor, group_features: Tensor, temperature: float = 0.1 + ) -> Tensor: """Computes the logits for given model outputs and group features. 
Args: @@ -354,7 +355,7 @@ def forward( logits = torch.mm(x, group_features.t()) return logits / temperature - def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor: + def get_updated_group_features(self, x: Tensor) -> Tensor: """Performs the synchronous momentum update of the group vectors. Args: @@ -375,12 +376,12 @@ def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor: return group_features - def set_group_features(self, x: torch.Tensor) -> None: + def set_group_features(self, x: Tensor) -> None: """Sets the group features and asserts they don't require gradient.""" self.group_features.data = x.to(self.group_features.device) @torch.no_grad() - def assign_groups(self, x: torch.Tensor) -> torch.Tensor: + def assign_groups(self, x: Tensor) -> Tensor: """Assigns each representation in x to a group based on cosine similarity. Args: @@ -526,8 +527,8 @@ def __init__( self.n_steps_frozen_prototypes = n_steps_frozen_prototypes def forward( - self, x: torch.Tensor, step: Optional[int] = None - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, x: Tensor, step: Optional[int] = None + ) -> Union[Tensor, List[Tensor]]: self._freeze_prototypes_if_required(step) out = [] for layer in self.heads: @@ -633,7 +634,7 @@ def _init_weights(self, module: nn.Module) -> None: if module.bias is not None: nn.init.constant_(module.bias, 0) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Computes one forward pass through the head.""" x = self.layers(x) # l2 normalization diff --git a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py index a045bca59..59abb65a5 100644 --- a/lightly/models/modules/nn_memory_bank.py +++ b/lightly/models/modules/nn_memory_bank.py @@ -6,6 +6,7 @@ from typing import Optional import torch +from torch import Tensor from lightly.loss.memory_bank import MemoryBankModule @@ -43,12 +44,11 @@ def __init__(self, size: int = 2**16): raise ValueError(f"Memory bank size must be positive, got {size}.") super(NNMemoryBankModule, self).__init__(size) - def forward( + def forward( # type: ignore[override] # TODO(Philipp, 11/23): Fix signature to match parent class. self, - output: torch.Tensor, - labels: Optional[torch.Tensor] = None, + output: Tensor, update: bool = False, - ) -> torch.Tensor: + ) -> Tensor: """Returns nearest neighbour of output tensor from memory bank Args: diff --git a/lightly/models/resnet.py b/lightly/models/resnet.py index 93abfae10..6bff53626 100644 --- a/lightly/models/resnet.py +++ b/lightly/models/resnet.py @@ -14,9 +14,9 @@ from typing import List -import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from lightly.models.batchnorm import get_norm_layer @@ -63,7 +63,7 @@ def __init__( get_norm_layer(self.expansion * planes, num_splits), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Forward pass through basic ResNet block. Args: @@ -74,7 +74,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Tensor of shape bsz x channels x W x H """ - out: torch.Tensor = self.conv1(x) + out: Tensor = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -133,7 +133,7 @@ def __init__( get_norm_layer(self.expansion * planes, num_splits), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Forward pass through bottleneck ResNet block. 
Args: @@ -144,7 +144,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Tensor of shape bsz x channels x W x H """ - out: torch.Tensor = self.conv1(x) + out: Tensor = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -224,7 +224,7 @@ def _make_layer( self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Forward pass through ResNet. Args:
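
For context on the memory-bank signatures touched above, a rough usage sketch follows. It assumes the lightly version this diff targets (the bank size is still a plain int, and NNMemoryBankModule.forward no longer accepts a labels argument); it is an illustration of the changed interfaces, not the canonical API.

    import torch
    from lightly.loss.memory_bank import MemoryBankModule
    from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

    # The bank buffers are initialised lazily on the first forward pass.
    bank = MemoryBankModule(size=2**16)
    projections = torch.randn(128, 64)
    # forward returns the (unchanged) output together with the current bank of
    # negatives; update=True enqueues the new batch and dequeues the oldest one.
    output, negatives = bank(projections, update=True)

    # NNMemoryBankModule replaces each projection by its nearest neighbour from
    # the bank (NNCLR-style); note the overridden forward takes no labels.
    nn_bank = NNMemoryBankModule(size=2**16)
    nearest = nn_bank(projections, update=True)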
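
The _batch_shuffle / _batch_unshuffle pair reformatted above is a simple permutation round trip: unshuffling with argsort of the stored permutation restores the original batch order. A minimal, self-contained sketch in plain PyTorch (mirroring, not importing, the mixin methods):

    from typing import Tuple

    import torch
    from torch import Tensor

    def batch_shuffle(batch: Tensor) -> Tuple[Tensor, Tensor]:
        # Shuffle the batch with a random permutation and return the
        # permutation so the original order can be restored later.
        shuffle = torch.randperm(batch.shape[0], device=batch.device)
        return batch[shuffle], shuffle

    def batch_unshuffle(batch: Tensor, shuffle: Tensor) -> Tensor:
        # Invert the permutation with argsort.
        unshuffle = torch.argsort(shuffle)
        return batch[unshuffle]

    batch = torch.randn(8, 32)
    shuffled, shuffle = batch_shuffle(batch)
    assert torch.equal(batch_unshuffle(shuffled, shuffle), batch)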
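
Likewise, for the SplitBatchNorm changes: the layer keeps num_splits sets of running statistics, which is typically used to emulate the per-device batch-norm behaviour of multi-GPU training on a single device (MoCo-style). A small sketch, assuming the constructor shown in this diff; the batch size must be divisible by num_splits in training mode:

    import torch
    from lightly.models.batchnorm import SplitBatchNorm

    norm = SplitBatchNorm(num_features=32, num_splits=4)

    # Training mode: the batch is viewed as num_splits sub-batches and each
    # sub-batch is normalised with its own set of running statistics.
    x = torch.randn(8, 32, 16, 16)
    out = norm(x)

    # Eval mode: train(False) averages the per-split statistics, and forward
    # then uses the first num_features entries of the running buffers.
    norm.eval()
    out_eval = norm(x)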