Commit

Re Guarin
philippmwirth committed Nov 21, 2023
1 parent b35f65b commit ed8ea4e
Showing 6 changed files with 45 additions and 41 deletions.
19 changes: 10 additions & 9 deletions lightly/loss/memory_bank.py
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple, Union

 import torch
+from torch import Tensor


 class MemoryBankModule(torch.nn.Module):
@@ -26,8 +27,8 @@ class MemoryBankModule(torch.nn.Module):
 >>> def __init__(self, memory_bank_size: int = 2 ** 16):
 >>>     super(MyLossFunction, self).__init__(memory_bank_size)
 >>>
->>> def forward(self, output: torch.Tensor,
->>>             labels: torch.Tensor = None):
+>>> def forward(self, output: Tensor,
+>>>             labels: Tensor = None):
 >>>
 >>>     output, negatives = super(
 >>>         MyLossFunction, self).forward(output)
@@ -67,12 +68,12 @@ def _init_memory_bank(self, dim: int) -> None:
         # we could use register buffers like in the moco repo
         # https://github.com/facebookresearch/moco but we don't
         # want to pollute our checkpoints
-        bank: torch.Tensor = torch.randn(dim, self.size).type_as(self.bank)
-        self.bank: torch.Tensor = torch.nn.functional.normalize(bank, dim=0)
-        self.bank_ptr: torch.Tensor = torch.zeros(1).type_as(self.bank_ptr)
+        bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
+        self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
+        self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)

     @torch.no_grad()
-    def _dequeue_and_enqueue(self, batch: torch.Tensor) -> None:
+    def _dequeue_and_enqueue(self, batch: Tensor) -> None:
         """Dequeue the oldest batch and add the latest one

         Args:
@@ -92,10 +93,10 @@ def _dequeue_and_enqueue(self, batch: Tensor) -> None:

     def forward(
         self,
-        output: torch.Tensor,
-        labels: Optional[torch.Tensor] = None,
+        output: Tensor,
+        labels: Optional[Tensor] = None,
         update: bool = False,
-    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+    ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]:
         """Query memory bank for additional negative samples

         Args:
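For context, a minimal runnable sketch of the subclassing pattern the docstring above describes, using the new Tensor import. The class name and the toy cosine objective are our own illustration, not part of this commit:

import torch
from torch import Tensor
from lightly.loss.memory_bank import MemoryBankModule


class ToyNegativeLoss(MemoryBankModule):
    """Toy loss: penalize similarity to negatives queried from the bank."""

    def __init__(self, memory_bank_size: int = 2**16):
        super().__init__(memory_bank_size)

    def forward(self, output: Tensor) -> Tensor:
        # The parent forward returns the batch together with the current bank
        # contents of shape (dim, size) to use as negatives; update=True also
        # enqueues the batch. Assumes size > 0, so negatives is not None.
        output, negatives = super().forward(output, update=True)
        output = torch.nn.functional.normalize(output, dim=1)
        return torch.mm(output, negatives).mean()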
9 changes: 4 additions & 5 deletions lightly/models/_momentum.py
@@ -8,6 +8,7 @@

 import torch
 import torch.nn as nn
+from torch import Tensor
 from torch.nn.parameter import Parameter

@@ -46,7 +47,7 @@ class _MomentumEncoderMixin:
 >>> # initialize momentum_backbone and momentum_projection_head
 >>> self._init_momentum_encoder()
 >>>
->>> def forward(self, x: torch.Tensor):
+>>> def forward(self, x: Tensor):
 >>>
 >>>     # do the momentum update
 >>>     self._momentum_update(0.999)
@@ -89,16 +90,14 @@ def _momentum_update(self, m: float = 0.999) -> None:
             )

     @torch.no_grad()
-    def _batch_shuffle(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _batch_shuffle(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
         """Returns the shuffled batch and the indices to undo."""
         batch_size = batch.shape[0]
         shuffle = torch.randperm(batch_size, device=batch.device)
         return batch[shuffle], shuffle

     @torch.no_grad()
-    def _batch_unshuffle(
-        self, batch: torch.Tensor, shuffle: torch.Tensor
-    ) -> torch.Tensor:
+    def _batch_unshuffle(self, batch: Tensor, shuffle: Tensor) -> Tensor:
         """Returns the unshuffled batch."""
         unshuffle = torch.argsort(shuffle)
         return batch[unshuffle]
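As a sanity check on the pair above: torch.argsort of a permutation yields its inverse, so shuffle followed by unshuffle is the identity. A standalone sketch, independent of the mixin:

import torch

batch = torch.randn(8, 3)
shuffle = torch.randperm(batch.shape[0])  # what _batch_shuffle draws
shuffled = batch[shuffle]
unshuffle = torch.argsort(shuffle)        # inverse permutation
assert torch.equal(shuffled[unshuffle], batch)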
15 changes: 9 additions & 6 deletions lightly/models/batchnorm.py
@@ -9,6 +9,7 @@

 import torch
 import torch.nn as nn
+from torch import Tensor


 class SplitBatchNorm(nn.BatchNorm2d):
@@ -25,12 +26,10 @@ class SplitBatchNorm(nn.BatchNorm2d):
     """

-    running_mean: torch.Tensor
-    running_var: torch.Tensor

     def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None:
         super().__init__(num_features, **kw)
         self.num_splits = num_splits
-        # Register buffers
         self.register_buffer(
             "running_mean", torch.zeros(num_features * self.num_splits)
         )
@@ -39,16 +38,18 @@ def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None:
     def train(self, mode: bool = True) -> SplitBatchNorm:
         # lazily collate stats when we are going to use them
         if (self.training is True) and (mode is False):
+            assert self.running_mean is not None
             self.running_mean = torch.mean(
                 self.running_mean.view(self.num_splits, self.num_features), dim=0
             ).repeat(self.num_splits)
+            assert self.running_var is not None
             self.running_var = torch.mean(
                 self.running_var.view(self.num_splits, self.num_features), dim=0
             ).repeat(self.num_splits)

         return super().train(mode)

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: Tensor) -> Tensor:
         """Computes the SplitBatchNorm on the input."""
         # get input shape
         N, C, H, W = input.shape
@@ -67,10 +68,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 self.eps,
             ).view(N, C, H, W)
         else:
+            # We have to ignore the type errors here, because we know that running_mean
+            # and running_var are not None, but the type checker does not.
             result = nn.functional.batch_norm(
                 input,
-                self.running_mean[: self.num_features],
-                self.running_var[: self.num_features],
+                self.running_mean[: self.num_features],  # type: ignore[index]
+                self.running_var[: self.num_features],  # type: ignore[index]
                 self.weight,
                 self.bias,
                 False,
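For context, a hedged usage sketch: SplitBatchNorm keeps num_splits copies of the running statistics to emulate per-device batch norm on a single machine (as in the MoCo repo); the sizes below are illustrative, not from the commit:

import torch
from lightly.models.batchnorm import SplitBatchNorm

bn = SplitBatchNorm(num_features=32, num_splits=4).train()
x = torch.randn(16, 32, 8, 8)  # batch of 16, normalized as 4 groups of 4
out = bn(x)
bn.eval()  # leaving train mode collates the per-split stats (see train() above)
out_eval = bn(x)
print(out.shape, out_eval.shape)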
23 changes: 12 additions & 11 deletions lightly/models/modules/heads.py
@@ -7,6 +7,7 @@

 import torch
 import torch.nn as nn
+from torch import Tensor

 from lightly.models import utils

@@ -46,15 +47,15 @@ def __init__(
             layers.append(non_linearity)
         self.layers = nn.Sequential(*layers)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Computes one forward pass through the projection head.

         Args:
             x:
                 Input of shape bsz x num_ftrs.

         """
-        projection: torch.Tensor = self.layers(x)
+        projection: Tensor = self.layers(x)
         return projection
@@ -325,16 +326,16 @@ class SMoGPrototypes(nn.Module):

     def __init__(
         self,
-        group_features: torch.Tensor,
+        group_features: Tensor,
         beta: float,
     ):
         super(SMoGPrototypes, self).__init__()
         self.group_features = nn.Parameter(group_features, requires_grad=False)
         self.beta = beta

     def forward(
-        self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1
-    ) -> torch.Tensor:
+        self, x: Tensor, group_features: Tensor, temperature: float = 0.1
+    ) -> Tensor:
         """Computes the logits for given model outputs and group features.

         Args:
@@ -354,7 +355,7 @@ def forward(
         logits = torch.mm(x, group_features.t())
         return logits / temperature

-    def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor:
+    def get_updated_group_features(self, x: Tensor) -> Tensor:
         """Performs the synchronous momentum update of the group vectors.

         Args:
@@ -375,12 +376,12 @@ def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor:

         return group_features

-    def set_group_features(self, x: torch.Tensor) -> None:
+    def set_group_features(self, x: Tensor) -> None:
         """Sets the group features and asserts they don't require gradient."""
         self.group_features.data = x.to(self.group_features.device)

     @torch.no_grad()
-    def assign_groups(self, x: torch.Tensor) -> torch.Tensor:
+    def assign_groups(self, x: Tensor) -> Tensor:
         """Assigns each representation in x to a group based on cosine similarity.

         Args:
@@ -526,8 +527,8 @@ def __init__(
         self.n_steps_frozen_prototypes = n_steps_frozen_prototypes

     def forward(
-        self, x: torch.Tensor, step: Optional[int] = None
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        self, x: Tensor, step: Optional[int] = None
+    ) -> Union[Tensor, List[Tensor]]:
         self._freeze_prototypes_if_required(step)
         out = []
         for layer in self.heads:
@@ -633,7 +634,7 @@ def _init_weights(self, module: nn.Module) -> None:
         if module.bias is not None:
             nn.init.constant_(module.bias, 0)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Computes one forward pass through the head."""
         x = self.layers(x)
         # l2 normalization
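A hedged sketch of how the SMoGPrototypes methods above fit together; the 300-group/128-dimension sizes are illustrative assumptions, not from this diff:

import torch
from lightly.models.modules.heads import SMoGPrototypes

prototypes = SMoGPrototypes(group_features=torch.rand(300, 128), beta=0.99)
x = torch.nn.functional.normalize(torch.randn(4, 128), dim=1)
assignments = prototypes.assign_groups(x)          # one group index per sample
logits = prototypes(x, prototypes.group_features)  # shape (4, 300), scaled by 1 / temperature
# momentum-update the group vectors, then write them back
prototypes.set_group_features(prototypes.get_updated_group_features(x))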
8 changes: 4 additions & 4 deletions lightly/models/modules/nn_memory_bank.py
@@ -6,6 +6,7 @@
 from typing import Optional

 import torch
+from torch import Tensor

 from lightly.loss.memory_bank import MemoryBankModule

@@ -43,12 +44,11 @@ def __init__(self, size: int = 2**16):
             raise ValueError(f"Memory bank size must be positive, got {size}.")
         super(NNMemoryBankModule, self).__init__(size)

-    def forward(
+    def forward(  # type: ignore[override] # TODO(Philipp, 11/23): Fix signature to match parent class.
         self,
-        output: torch.Tensor,
-        labels: Optional[torch.Tensor] = None,
+        output: Tensor,
         update: bool = False,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """Returns nearest neighbour of output tensor from memory bank

         Args:
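A short usage sketch of the nearest-neighbour lookup with the new signature (no labels argument); the embedding sizes are illustrative:

import torch
from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

memory_bank = NNMemoryBankModule(size=2**16)
embeddings = torch.randn(8, 128)
# Returns, per embedding, its closest entry in the bank; update=True
# also enqueues the current batch for future lookups.
neighbours = memory_bank(embeddings, update=True)
print(neighbours.shape)  # torch.Size([8, 128])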
12 changes: 6 additions & 6 deletions lightly/models/resnet.py
@@ -14,9 +14,9 @@

 from typing import List

-import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor

 from lightly.models.batchnorm import get_norm_layer

@@ -63,7 +63,7 @@ def __init__(
                 get_norm_layer(self.expansion * planes, num_splits),
             )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Forward pass through basic ResNet block.

         Args:
@@ -74,7 +74,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             Tensor of shape bsz x channels x W x H

         """

-        out: torch.Tensor = self.conv1(x)
+        out: Tensor = self.conv1(x)
         out = self.bn1(out)
         out = F.relu(out)

@@ -133,7 +133,7 @@ def __init__(
                 get_norm_layer(self.expansion * planes, num_splits),
             )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Forward pass through bottleneck ResNet block.

         Args:
@@ -144,7 +144,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             Tensor of shape bsz x channels x W x H

         """

-        out: torch.Tensor = self.conv1(x)
+        out: Tensor = self.conv1(x)
         out = self.bn1(out)
         out = F.relu(out)

@@ -224,7 +224,7 @@ def _make_layer(
         self.in_planes = planes * block.expansion
         return nn.Sequential(*layers)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Forward pass through ResNet.

         Args:
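As a hedged aside on the helper imported at the top of this file: judging from the call sites, get_norm_layer(num_features, num_splits) returns a plain nn.BatchNorm2d when num_splits is 0 and a SplitBatchNorm otherwise. That reading is our assumption from the diff, not a documented guarantee:

import torch
from lightly.models.batchnorm import get_norm_layer

norm = get_norm_layer(64, 0)   # expected: nn.BatchNorm2d (assumed)
split = get_norm_layer(64, 8)  # expected: SplitBatchNorm with 8 splits (assumed)
x = torch.randn(16, 64, 4, 4)
print(type(norm).__name__, type(split).__name__, norm(x).shape)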
