Implement W-MSE Loss
Implement W-MSE Transform
johnsutor committed Dec 17, 2023
1 parent 66ad1b4 commit e7f1bb9
Showing 4 changed files with 346 additions and 0 deletions.
178 changes: 178 additions & 0 deletions lightly/loss/wmse_loss.py
@@ -0,0 +1,178 @@
"""Code for W-MSE Loss, largely taken from https://github.com/htdt/self-supervised"""

from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    """Normalized MSE Loss as implemented in https://github.com/htdt/self-supervised.

    For L2-normalized vectors, ||x0 - x1||^2 = 2 - 2 * (x0 . x1), so this is a
    mean squared error on the unit hypersphere.
    """
    x0 = F.normalize(x0)
    x1 = F.normalize(x1)
    return 2 - 2 * (x0 * x1).sum(dim=-1).mean()


class Whitening2d(nn.Module):
    """Implementation of the whitening layer as described in [0].

    - [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf
    """

    running_mean: torch.Tensor
    running_variance: torch.Tensor

    def __init__(
        self,
        num_features: int,
        momentum: float = 0.01,
        track_running_stats: bool = True,
        eps: float = 0,
    ):
        super().__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.eps = eps

        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros([1, self.num_features, 1, 1])
            )
            self.register_buffer("running_variance", torch.eye(self.num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Treat each feature vector as a 1x1 "image" so that conv2d can be
        # used to apply the whitening matrix below.
        x = x.unsqueeze(2).unsqueeze(3)
        m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
        if not self.training and self.track_running_stats:  # for inference
            m = self.running_mean
        xn = x - m

        # Covariance matrix of the centered features.
        T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
        f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        eye = torch.eye(self.num_features).type(f_cov.type())

        if not self.training and self.track_running_stats:  # for inference
            f_cov = self.running_variance

        # Shrink the covariance towards the identity for numerical stability.
        f_cov_shrunk = (1 - self.eps) * f_cov + self.eps * eye

        # Whitening matrix: inverse of the Cholesky factor of the covariance.
        inv_sqrt = torch.linalg.solve_triangular(
            torch.linalg.cholesky(f_cov_shrunk), eye, upper=False
        )

        inv_sqrt = inv_sqrt.contiguous().view(
            self.num_features, self.num_features, 1, 1
        )

        decorrelated = F.conv2d(xn, inv_sqrt)

        if self.training and self.track_running_stats:
            # Exponential moving averages of the batch statistics.
            self.running_mean = torch.add(
                self.momentum * m.detach(),
                (1 - self.momentum) * self.running_mean,
                out=self.running_mean,
            )
            self.running_variance = torch.add(
                self.momentum * f_cov.detach(),
                (1 - self.momentum) * self.running_variance,
                out=self.running_variance,
            )

        return decorrelated.squeeze(2).squeeze(2)


class WMSELoss(torch.nn.Module):
    """Implementation of the W-MSE loss function [0].

    - [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf

    Examples:
        >>> # initialize loss function
        >>> loss_fn = WMSELoss(num_samples=2)
        >>> transform_fn = WMSETransform(num_samples=2)
        >>>
        >>> # generate the transformed samples
        >>> samples = transform_fn(image)
        >>>
        >>> # feed through the model
        >>> h = torch.cat([model(s) for s in samples])
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(h)
    """

    def __init__(
        self,
        embedding_dim: int = 64,
        momentum: float = 0.01,
        eps: float = 0.0,
        track_running_stats: bool = True,
        w_iter: int = 1,
        w_size: int = 128,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
        num_samples: int = 2,
    ):
        """Parameters as described in [0].

        Args:
            embedding_dim:
                Dimensionality of the embedding.
            momentum:
                Momentum for the running statistics of the whitening layer.
            eps:
                Shrinkage factor used to regularize the covariance matrix
                before inversion.
            track_running_stats:
                Whether to track running statistics.
            w_iter:
                Number of whitening iterations (random sub-batch partitions)
                over which the loss is averaged.
            w_size:
                Sub-batch size to use for whitening.
            loss_fn:
                Loss function applied to each pair of whitened views.
            num_samples:
                Number of samples generated by the transforms for each image.
        """
        super().__init__()
        self.whitening = Whitening2d(
            num_features=embedding_dim,
            momentum=momentum,
            eps=eps,
            track_running_stats=track_running_stats,
        )
        self.w_iter = w_iter
        self.w_size = w_size
        self.loss_f = loss_fn
        self.num_samples = num_samples
        self.num_pairs = num_samples * (num_samples - 1) // 2

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.shape[0] % self.num_samples != 0:
            raise RuntimeError("input batch size must be divisible by num_samples")

        # Per-view batch size; rows [i * bs, (i + 1) * bs) hold the i-th view.
        bs = input.shape[0] // self.num_samples

        if bs < self.w_size:
            raise ValueError("batch size must be greater than or equal to w_size")
        # Accumulate out-of-place so that gradients flow through the loss terms.
        loss = torch.tensor(0.0, device=input.device)

        for _ in range(self.w_iter):
            z = torch.empty_like(input)
            # Partition the batch into random sub-batches of size w_size and
            # whiten each sub-batch separately (requires bs % w_size == 0).
            perm = torch.randperm(bs).view(-1, self.w_size)
            for idx in perm:
                for i in range(self.num_samples):
                    z[idx + i * bs] = self.whitening(input[idx + i * bs])
            # Sum the pairwise loss over all pairs of whitened views.
            for i in range(self.num_samples - 1):
                for j in range(i + 1, self.num_samples):
                    x0 = z[i * bs : (i + 1) * bs]
                    x1 = z[j * bs : (j + 1) * bs]
                    loss = loss + self.loss_f(x0, x1)
        loss = loss / (self.w_iter * self.num_pairs)
        return loss
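
For reference, a minimal sketch of how Whitening2d and WMSELoss behave (not part of this commit; the shapes are illustrative and chosen to satisfy the checks in WMSELoss.forward):

# Hypothetical usage sketch, assuming the module path above.
import torch

from lightly.loss.wmse_loss import Whitening2d, WMSELoss

# Whitening2d decorrelates a batch of feature vectors: the empirical
# covariance of its output is approximately the identity matrix.
whitening = Whitening2d(num_features=16)
x = torch.randn(256, 16)
z = whitening(x)
zc = z - z.mean(dim=0)
cov = zc.T @ zc / (z.shape[0] - 1)
assert torch.allclose(cov, torch.eye(16), atol=1e-3)

# WMSELoss expects the embeddings of all views concatenated view-by-view:
# rows [0, bs) are view 0, rows [bs, 2 * bs) are view 1, and so on.
num_samples = 2
bs = 128  # per-view batch size; must be >= w_size (128 by default)
h = torch.randn(num_samples * bs, 64, requires_grad=True)

loss_fn = WMSELoss(num_samples=num_samples)
loss = loss_fn(h)
loss.backward()
assert h.grad is not None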
90 changes: 90 additions & 0 deletions lightly/transforms/wmse_transform.py
@@ -0,0 +1,90 @@
from typing import Dict, List

import torchvision.transforms as T

from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE


class WMSETransform(MultiViewTransform):
    """Implements the transformations for W-MSE [0].

    Input to this transform:
        PIL Image or Tensor.
    Output of this transform:
        List of Tensor of length num_samples.

    Applies the following augmentations by default:
        - Color jitter
        - Random gray scale
        - Random resized crop
        - Random horizontal flip
        - ImageNet normalization

    - [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf

    Attributes:
        num_samples:
            Number of views. Must be the same as num_samples in the WMSELoss.
        input_size:
            Size of the input image in pixels.
        cj_prob:
            Probability that color jitter is applied.
        cj_bright:
            How much to jitter brightness.
        cj_contrast:
            How much to jitter contrast.
        cj_sat:
            How much to jitter saturation.
        cj_hue:
            How much to jitter hue.
        min_scale:
            Minimum size of the randomized crop relative to the input_size.
        random_gray_scale:
            Probability that random gray scale is applied.
        hf_prob:
            Probability that horizontal flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        num_samples: int = 2,
        input_size: int = 224,
        cj_prob: float = 0.8,
        cj_bright: float = 0.4,
        cj_contrast: float = 0.4,
        cj_sat: float = 0.4,
        cj_hue: float = 0.1,
        min_scale: float = 0.2,
        random_gray_scale: float = 0.1,
        hf_prob: float = 0.5,
        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
    ):
        if num_samples < 1:
            raise ValueError("num_samples must be greater than or equal to 1")
        transform = T.Compose(
            [
                T.RandomApply(
                    [T.ColorJitter(cj_bright, cj_contrast, cj_sat, cj_hue)], p=cj_prob
                ),
                T.RandomGrayscale(p=random_gray_scale),
                T.RandomResizedCrop(
                    input_size,
                    scale=(min_scale, 1.0),
                    interpolation=3,  # 3 corresponds to PIL's bicubic interpolation
                ),
                T.RandomHorizontalFlip(p=hf_prob),
                T.ToTensor(),
                T.Normalize(mean=normalize["mean"], std=normalize["std"]),
            ]
        )
        super().__init__(transforms=[transform] * num_samples)
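
For reference, a minimal sketch of the transform in use (not part of this commit; the blank test image is a stand-in for real data):

# Hypothetical usage sketch for WMSETransform.
from PIL import Image

from lightly.transforms.wmse_transform import WMSETransform

transform = WMSETransform(num_samples=2)
image = Image.new("RGB", (224, 224))  # stand-in for a real training image

views = transform(image)  # list of 2 tensors, each of shape (3, 224, 224)
assert len(views) == 2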
57 changes: 57 additions & 0 deletions tests/loss/test_WMSELoss.py
@@ -0,0 +1,57 @@
import unittest

import torch

from lightly.loss.wmse_loss import WMSELoss


class TestWMSELoss(unittest.TestCase):
    def test_forward(self) -> None:
        # The per-view batch size must be at least w_size (128 by default) and
        # large enough for a full-rank covariance in the whitening layer.
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss()
        x = torch.randn(bs * num_samples, dim)

        loss = loss_fn(x)

        self.assertGreater(loss.item(), 0)

    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
    def test_forward_cuda(self) -> None:
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss().cuda()
        x = torch.randn(bs * num_samples, dim).cuda()

        loss = loss_fn(x)

        self.assertGreater(loss.item(), 0)

    def test_loss_value(self) -> None:
        """The loss is expected to be greater than zero for random inputs."""
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss()
        x = torch.randn(bs * num_samples, dim)

        loss = loss_fn(x)

        self.assertGreater(loss, 0)

    def test_num_samples_error(self) -> None:
        """An input batch size not divisible by num_samples is rejected."""
        with self.assertRaises(RuntimeError):
            loss_fn = WMSELoss(num_samples=3)
            x = torch.randn(5, 64)
            loss_fn(x)

    def test_w_size_error(self) -> None:
        """A per-view batch size smaller than w_size is rejected."""
        with self.assertRaises(ValueError):
            loss_fn = WMSELoss(w_size=5)
            x = torch.randn(4, 64)
            loss_fn(x)
21 changes: 21 additions & 0 deletions tests/transforms/test_wmse_transform.py
@@ -0,0 +1,21 @@
import pytest
from PIL import Image

from lightly.transforms.wmse_transform import WMSETransform


def test_raise_value_error() -> None:
    with pytest.raises(ValueError):
        WMSETransform(num_samples=0)


def test_num_views() -> None:
    multi_view_transform = WMSETransform(num_samples=3)
    assert len(multi_view_transform.transforms) == 3


def test_multi_view_on_pil_image() -> None:
    multi_view_transform = WMSETransform(num_samples=3)
    sample = Image.new("RGB", (100, 100))
    output = multi_view_transform(sample)
    assert len(output) == 3
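
Taken together, the transform and the loss compose as below; a hypothetical end-to-end sketch in which the tiny encoder, the noise images, and the toy shapes are all stand-ins, not part of this commit:

# Hypothetical end-to-end sketch combining WMSETransform and WMSELoss.
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from lightly.loss.wmse_loss import WMSELoss
from lightly.transforms.wmse_transform import WMSETransform

num_samples = 2
embedding_dim = 16
bs = 32  # per-view batch size; must be >= w_size and > embedding_dim

transform = WMSETransform(num_samples=num_samples, input_size=32)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, embedding_dim))
loss_fn = WMSELoss(embedding_dim=embedding_dim, num_samples=num_samples, w_size=bs)

# Random noise images stand in for a real dataset.
images = [
    Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
    for _ in range(bs)
]

# Each call to the transform yields num_samples augmented views per image.
views_per_image = [transform(img) for img in images]

# Batch the views so rows [0, bs) are view 0 and rows [bs, 2 * bs) are view 1.
batches = [
    torch.stack([views[i] for views in views_per_image]) for i in range(num_samples)
]
h = torch.cat([encoder(b) for b in batches])

loss = loss_fn(h)
print(float(loss))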
