Typecheck part of models (#1430)
IgorSusmelj authored Nov 21, 2023
1 parent 8b6acf4 commit 514ffd7
Showing 8 changed files with 104 additions and 68 deletions.
24 changes: 14 additions & 10 deletions lightly/loss/memory_bank.py
@@ -3,9 +3,10 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import functools
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


class MemoryBankModule(torch.nn.Module):
@@ -26,8 +27,8 @@ class MemoryBankModule(torch.nn.Module):
>>> def __init__(self, memory_bank_size: int = 2 ** 16):
>>> super(MyLossFunction, self).__init__(memory_bank_size)
>>>
>>> def forward(self, output: torch.Tensor,
>>> labels: torch.Tensor = None):
>>> def forward(self, output: Tensor,
>>> labels: Tensor = None):
>>>
>>> output, negatives = super(
>>> MyLossFunction, self).forward(output)
@@ -55,7 +56,7 @@ def __init__(self, size: int = 2**16):
)

@torch.no_grad()
def _init_memory_bank(self, dim: int):
def _init_memory_bank(self, dim: int) -> None:
"""Initialize the memory bank if it's empty
Args:
@@ -67,12 +68,12 @@ def _init_memory_bank(self, dim: int):
# we could use register buffers like in the moco repo
# https://github.com/facebookresearch/moco but we don't
# want to pollute our checkpoints
self.bank = torch.randn(dim, self.size).type_as(self.bank)
self.bank = torch.nn.functional.normalize(self.bank, dim=0)
self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)

@torch.no_grad()
def _dequeue_and_enqueue(self, batch: torch.Tensor):
def _dequeue_and_enqueue(self, batch: Tensor) -> None:
"""Dequeue the oldest batch and add the latest one
Args:
@@ -91,8 +92,11 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor):
self.bank_ptr[0] = ptr + batch_size

def forward(
self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False
):
self,
output: Tensor,
labels: Optional[Tensor] = None,
update: bool = False,
) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]:
"""Query memory bank for additional negative samples
Args:
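Below is a hedged usage sketch (not part of the diff) of the typed forward above, following the subclassing pattern from the class docstring; MyLossFunction and its toy objective are purely illustrative.

from typing import Optional

import torch
from torch import Tensor

from lightly.loss.memory_bank import MemoryBankModule


class MyLossFunction(MemoryBankModule):
    """Illustrative loss that queries the memory bank for negatives."""

    def __init__(self, memory_bank_size: int = 2**16) -> None:
        super().__init__(size=memory_bank_size)

    def forward(self, output: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        # With a non-zero bank size the parent returns (output, negatives);
        # a type checker would want this Union narrowed before unpacking.
        output, negatives = super().forward(output, labels, update=True)
        if negatives is not None:
            # Toy score: mean similarity against the bank (a real loss does more).
            return (output @ negatives).mean()
        return output.sum()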
19 changes: 12 additions & 7 deletions lightly/models/_momentum.py
@@ -4,18 +4,23 @@
# All Rights Reserved

import copy
from typing import Iterable, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter


def _deactivate_requires_grad(params):
def _deactivate_requires_grad(params: Iterable[Parameter]) -> None:
"""Deactivates the requires_grad flag for all parameters."""
for param in params:
param.requires_grad = False


def _do_momentum_update(prev_params, params, m):
def _do_momentum_update(
prev_params: Iterable[Parameter], params: Iterable[Parameter], m: float
) -> None:
"""Updates the weights of the previous parameters."""
for prev_param, param in zip(prev_params, params):
prev_param.data = prev_param.data * m + param.data * (1.0 - m)
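For intuition, the momentum rule above is a plain exponential moving average of the online weights; a minimal standalone sketch with illustrative values:

import torch

m = 0.99
prev = torch.tensor([1.0, 0.0])  # weight of the momentum (EMA) encoder
curr = torch.tensor([0.0, 1.0])  # weight of the online encoder

# Same update as _do_momentum_update, applied to raw tensors.
prev = prev * m + curr * (1.0 - m)  # tensor([0.9900, 0.0100])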
@@ -42,7 +47,7 @@ class _MomentumEncoderMixin:
>>> # initialize momentum_backbone and momentum_projection_head
>>> self._init_momentum_encoder()
>>>
>>> def forward(self, x: torch.Tensor):
>>> def forward(self, x: Tensor):
>>>
>>> # do the momentum update
>>> self._momentum_update(0.999)
@@ -59,7 +64,7 @@ class _MomentumEncoderMixin:
momentum_backbone: nn.Module
momentum_projection_head: nn.Module

def _init_momentum_encoder(self):
def _init_momentum_encoder(self) -> None:
"""Initializes momentum backbone and a momentum projection head."""
assert self.backbone is not None
assert self.projection_head is not None
@@ -71,7 +76,7 @@ def _init_momentum_encoder(self):
_deactivate_requires_grad(self.momentum_projection_head.parameters())

@torch.no_grad()
def _momentum_update(self, m: float = 0.999):
def _momentum_update(self, m: float = 0.999) -> None:
"""Performs the momentum update for the backbone and projection head."""
_do_momentum_update(
self.momentum_backbone.parameters(),
@@ -85,14 +90,14 @@ def _momentum_update(self, m: float = 0.999):
)

@torch.no_grad()
def _batch_shuffle(self, batch: torch.Tensor):
def _batch_shuffle(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
"""Returns the shuffled batch and the indices to undo."""
batch_size = batch.shape[0]
shuffle = torch.randperm(batch_size, device=batch.device)
return batch[shuffle], shuffle

@torch.no_grad()
def _batch_unshuffle(self, batch: torch.Tensor, shuffle: torch.Tensor):
def _batch_unshuffle(self, batch: Tensor, shuffle: Tensor) -> Tensor:
"""Returns the unshuffled batch."""
unshuffle = torch.argsort(shuffle)
return batch[unshuffle]
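As a standalone sanity check of the two shuffle helpers above (dummy batch, not part of the diff):

import torch

batch = torch.randn(8, 3, 32, 32)
shuffle = torch.randperm(batch.shape[0], device=batch.device)
shuffled = batch[shuffle]                    # mirrors _batch_shuffle
restored = shuffled[torch.argsort(shuffle)]  # mirrors _batch_unshuffle
assert torch.equal(restored, batch)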
22 changes: 16 additions & 6 deletions lightly/models/batchnorm.py
@@ -3,8 +3,13 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
from torch import Tensor


class SplitBatchNorm(nn.BatchNorm2d):
@@ -21,27 +26,30 @@ class SplitBatchNorm(nn.BatchNorm2d):
"""

def __init__(self, num_features, num_splits, **kw):
def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None:
super().__init__(num_features, **kw)
self.num_splits = num_splits
# Register buffers
self.register_buffer(
"running_mean", torch.zeros(num_features * self.num_splits)
)
self.register_buffer("running_var", torch.ones(num_features * self.num_splits))

def train(self, mode=True):
def train(self, mode: bool = True) -> SplitBatchNorm:
# lazily collate stats when we are going to use them
if (self.training is True) and (mode is False):
assert self.running_mean is not None
self.running_mean = torch.mean(
self.running_mean.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
assert self.running_var is not None
self.running_var = torch.mean(
self.running_var.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)

return super().train(mode)

def forward(self, input):
def forward(self, input: Tensor) -> Tensor:
"""Computes the SplitBatchNorm on the input."""
# get input shape
N, C, H, W = input.shape
@@ -60,10 +68,12 @@ def forward(self, input):
self.eps,
).view(N, C, H, W)
else:
# We have to ignore the type errors here, because we know that running_mean
# and running_var are not None, but the type checker does not.
result = nn.functional.batch_norm(
input,
self.running_mean[: self.num_features],
self.running_var[: self.num_features],
self.running_mean[: self.num_features], # type: ignore[index]
self.running_var[: self.num_features], # type: ignore[index]
self.weight,
self.bias,
False,
@@ -74,7 +84,7 @@
return result


def get_norm_layer(num_features: int, num_splits: int, **kw):
def get_norm_layer(num_features: int, num_splits: int, **kw: Any) -> nn.Module:
"""Utility to switch between BatchNorm2d and SplitBatchNorm."""
if num_splits > 0:
return SplitBatchNorm(num_features, num_splits)
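A small hedged example of the helper above; the feature count, split count, and input shape are arbitrary:

import torch
from lightly.models.batchnorm import get_norm_layer

split_bn = get_norm_layer(num_features=64, num_splits=2)  # SplitBatchNorm
plain_bn = get_norm_layer(num_features=64, num_splits=0)  # plain BatchNorm2d

x = torch.randn(8, 64, 16, 16)  # batch size should be divisible by num_splits
out = split_bn(x)               # output keeps the input shape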
49 changes: 27 additions & 22 deletions lightly/models/modules/heads.py
@@ -7,6 +7,7 @@

import torch
import torch.nn as nn
from torch import Tensor

from lightly.models import utils

@@ -33,10 +34,10 @@ class ProjectionHead(nn.Module):

def __init__(
self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
):
) -> None:
super(ProjectionHead, self).__init__()

layers = []
layers: List[nn.Module] = []
for input_dim, output_dim, batch_norm, non_linearity in blocks:
use_bias = not bool(batch_norm)
layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
@@ -46,15 +47,16 @@ def __init__(
layers.append(non_linearity)
self.layers = nn.Sequential(*layers)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor) -> Tensor:
"""Computes one forward pass through the projection head.
Args:
x:
Input of shape bsz x num_ftrs.
"""
return self.layers(x)
projection: Tensor = self.layers(x)
return projection


class BarlowTwinsProjectionHead(ProjectionHead):
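For reference, a hedged sketch of the blocks argument annotated in this hunk; the dimensions are arbitrary:

import torch
import torch.nn as nn
from lightly.models.modules.heads import ProjectionHead

# Each block is (input_dim, output_dim, batch_norm, non_linearity).
head = ProjectionHead(
    [
        (512, 2048, nn.BatchNorm1d(2048), nn.ReLU()),
        (2048, 128, None, None),
    ]
)
z = head(torch.randn(4, 512))  # projection of shape (4, 128)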
@@ -324,16 +326,16 @@ class SMoGPrototypes(nn.Module):

def __init__(
self,
group_features: torch.Tensor,
group_features: Tensor,
beta: float,
):
super(SMoGPrototypes, self).__init__()
self.group_features = nn.Parameter(group_features, requires_grad=False)
self.beta = beta

def forward(
self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
self, x: Tensor, group_features: Tensor, temperature: float = 0.1
) -> Tensor:
"""Computes the logits for given model outputs and group features.
Args:
@@ -353,7 +355,7 @@ def forward(
logits = torch.mm(x, group_features.t())
return logits / temperature

def get_updated_group_features(self, x: torch.Tensor) -> None:
def get_updated_group_features(self, x: Tensor) -> Tensor:
"""Performs the synchronous momentum update of the group vectors.
Args:
@@ -370,23 +372,23 @@ def get_updated_group_features(self, x: torch.Tensor) -> None:
mask = assignments == assigned_class
group_features[assigned_class] = self.beta * self.group_features[
assigned_class
] + (1 - self.beta) * x[mask].mean(axis=0)
] + (1 - self.beta) * x[mask].mean(dim=0)

return group_features

def set_group_features(self, x: torch.Tensor) -> None:
def set_group_features(self, x: Tensor) -> None:
"""Sets the group features and asserts they don't require gradient."""
self.group_features.data = x.to(self.group_features.device)

@torch.no_grad()
def assign_groups(self, x: torch.Tensor) -> torch.LongTensor:
def assign_groups(self, x: Tensor) -> Tensor:
"""Assigns each representation in x to a group based on cosine similarity.
Args:
Tensor of shape bsz x dim.
Returns:
LongTensor of shape bsz indicating group assignments.
Tensor of shape bsz indicating group assignments.
"""
return torch.argmax(self.forward(x, self.group_features), dim=-1)
@@ -524,19 +526,21 @@ def __init__(
)
self.n_steps_frozen_prototypes = n_steps_frozen_prototypes

def forward(self, x, step=None) -> Union[torch.Tensor, List[torch.Tensor]]:
def forward(
self, x: Tensor, step: Optional[int] = None
) -> Union[Tensor, List[Tensor]]:
self._freeze_prototypes_if_required(step)
out = []
for layer in self.heads:
out.append(layer(x))
return out[0] if self._is_single_prototype else out

def normalize(self):
def normalize(self) -> None:
"""Normalizes the prototypes so that they are on the unit sphere."""
for layer in self.heads:
utils.normalize_weight(layer.weight)

def _freeze_prototypes_if_required(self, step):
def _freeze_prototypes_if_required(self, step: Optional[int] = None) -> None:
if self.n_steps_frozen_prototypes > 0:
if step is None:
raise ValueError(
@@ -601,22 +605,23 @@ def __init__(
)
self.apply(self._init_weights)
self.freeze_last_layer = freeze_last_layer
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, output_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
self.last_layer = nn.Linear(bottleneck_dim, output_dim, bias=False)
self.last_layer = nn.utils.weight_norm(self.last_layer)
# Tell mypy this is ok because fill_ is overloaded.
self.last_layer.weight_g.data.fill_(1) # type: ignore

# Option to normalize last layer.
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False

def cancel_last_layer_gradients(self, current_epoch: int):
def cancel_last_layer_gradients(self, current_epoch: int) -> None:
"""Cancel last layer gradients to stabilize the training."""
if current_epoch >= self.freeze_last_layer:
return
for param in self.last_layer.parameters():
param.grad = None

def _init_weights(self, module):
def _init_weights(self, module: nn.Module) -> None:
"""Initializes layers with a truncated normal distribution."""
if isinstance(module, nn.Linear):
utils._no_grad_trunc_normal(
@@ -629,7 +634,7 @@ def _init_weights(self, module):
if module.bias is not None:
nn.init.constant_(module.bias, 0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: Tensor) -> Tensor:
"""Computes one forward pass through the head."""
x = self.layers(x)
# l2 normalization
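Finally, a hedged sketch of the SMoGPrototypes calls typed above; the group count and feature dimension are arbitrary:

import torch
from lightly.models.modules.heads import SMoGPrototypes

prototypes = SMoGPrototypes(group_features=torch.randn(300, 128), beta=0.99)

x = torch.nn.functional.normalize(torch.randn(32, 128), dim=1)
logits = prototypes(x, prototypes.group_features, temperature=0.1)  # (32, 300)
groups = prototypes.assign_groups(x)  # (32,) index of the nearest group vector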
