Implement W-MSE Transform
Showing 4 changed files with 346 additions and 0 deletions.
lightly/loss/wmse_loss.py
@@ -0,0 +1,178 @@
"""Code for W-MSE Loss, largely taken from https://github.com/htdt/self-supervised""" | ||
|
||
from typing import Callable, List | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: | ||
"""Normalized MSE Loss as implemented in https://github.com/htdt/self-supervised.""" | ||
x0 = F.normalize(x0) | ||
x1 = F.normalize(x1) | ||
return torch.sub(input=2, other=(x0 * x1).sum(dim=-1).mean(), alpha=2) | ||
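
# Note: for L2-normalized rows, ||a - b||^2 = 2 - 2 * <a, b>, so this loss is
# the mean squared Euclidean distance between the normalized embeddings.
# A quick equivalence check (a sketch, not part of the module):
#
#   a, b = F.normalize(x0), F.normalize(x1)
#   assert torch.allclose(norm_mse_loss(x0, x1), (a - b).pow(2).sum(dim=-1).mean())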

class Whitening2d(nn.Module):
    """
    Implementation of the whitening layer as described in [0]. Given the
    Cholesky factorization of the feature covariance, cov = L @ L.T, the layer
    multiplies the centered features by L^-1, which gives them (approximately)
    identity covariance.

    [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf
    """

    running_mean: torch.Tensor
    running_variance: torch.Tensor

    def __init__(
        self,
        num_features: int,
        momentum: float = 0.01,
        track_running_stats: bool = True,
        eps: float = 0,
    ):
        super(Whitening2d, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.eps = eps

        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros([1, self.num_features, 1, 1])
            )
            self.register_buffer("running_variance", torch.eye(self.num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Treat the (batch, features) input as a 4d tensor so that the
        # whitening matrix can be applied as a 1x1 convolution.
        x = x.unsqueeze(2).unsqueeze(3)
        m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
        if not self.training and self.track_running_stats:  # for inference
            m = self.running_mean
        xn = x - m

        T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
        f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        eye = torch.eye(self.num_features, dtype=f_cov.dtype, device=f_cov.device)

        if not self.training and self.track_running_stats:  # for inference
            f_cov = self.running_variance

        # Shrink the covariance towards the identity to keep the Cholesky
        # factorization numerically stable.
        f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

        # inv_sqrt = L^-1 where f_cov_shrinked = L @ L.T
        inv_sqrt = torch.linalg.solve_triangular(
            torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
        )

        inv_sqrt = inv_sqrt.contiguous().view(
            self.num_features, self.num_features, 1, 1
        )

        decorrelated = F.conv2d(xn, inv_sqrt)

        if self.training and self.track_running_stats:
            self.running_mean = torch.add(
                self.momentum * m.detach(),
                (1 - self.momentum) * self.running_mean,
                out=self.running_mean,
            )
            self.running_variance = torch.add(
                self.momentum * f_cov.detach(),
                (1 - self.momentum) * self.running_variance,
                out=self.running_variance,
            )

        return decorrelated.squeeze(2).squeeze(2)
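
# Quick sanity check (a sketch, not part of the module): in training mode the
# whitened output should have zero mean and approximately identity covariance.
#
#   w = Whitening2d(num_features=8, track_running_stats=False)
#   z = w(torch.randn(256, 8))
#   cov = z.t().mm(z) / (z.shape[0] - 1)
#   assert torch.allclose(cov, torch.eye(8), atol=1e-4)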

class WMSELoss(torch.nn.Module):
    """
    Implementation of the W-MSE loss function [0].

    - [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf

    Examples:

        >>> # initialize loss function
        >>> loss_fn = WMSELoss(num_samples=2)
        >>> transform_fn = WMSETransform(num_samples=2)
        >>>
        >>> # generate the transformed samples
        >>> samples = transform_fn(image)
        >>>
        >>> # feed through encoder head
        >>> h = torch.cat([model(s) for s in samples])
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(h)
    """

    def __init__(
        self,
        embedding_dim: int = 64,
        momentum: float = 0.01,
        eps: float = 0.0,
        track_running_stats: bool = True,
        w_iter: int = 1,
        w_size: int = 128,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
        num_samples: int = 2,
    ):
        """Parameters as described in [0].

        Args:
            embedding_dim:
                Dimensionality of the embedding.
            momentum:
                Momentum for the running statistics.
            eps:
                Shrinkage coefficient that mixes the covariance matrix with the
                identity for numerical stability.
            track_running_stats:
                Whether to track running statistics.
            w_iter:
                Number of whitening iterations; the whitening is repeated over
                freshly sampled sub-batches and the losses are averaged.
            w_size:
                Sub-batch size to use for whitening.
            loss_fn:
                Loss function applied to each pair of whitened views.
            num_samples:
                Number of samples generated by the transforms for each image.
        """
        super().__init__()
        self.whitening = Whitening2d(
            num_features=embedding_dim,
            momentum=momentum,
            eps=eps,
            track_running_stats=track_running_stats,
        )
        self.w_iter = w_iter
        self.w_size = w_size
        self.loss_f = loss_fn
        self.num_samples = num_samples
        # Number of distinct view pairs: num_samples choose 2.
        self.num_pairs = num_samples * (num_samples - 1) // 2

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.shape[0] % self.num_samples != 0:
            raise RuntimeError("input batch size must be divisible by num_samples")

        # Views are expected to be concatenated along the batch dimension:
        # input[0:bs] holds view 0, input[bs:2*bs] holds view 1, and so on.
        bs = input.shape[0] // self.num_samples

        if bs < self.w_size:
            raise ValueError("batch size must be greater than or equal to w_size")
        loss = torch.tensor(0.0, device=input.device, requires_grad=True)

        for _ in range(self.w_iter):
            z = torch.empty_like(input)
            # Whiten every view over random sub-batches of size w_size. Note
            # that bs must be a multiple of w_size for the view() to succeed.
            perm = torch.randperm(bs).view(-1, self.w_size)
            for idx in perm:
                for i in range(self.num_samples):
                    z[idx + i * bs] = self.whitening(input[idx + i * bs])
            # Accumulate the loss over all pairs of whitened views. The adds
            # are out-of-place: an in-place op on a leaf tensor that requires
            # grad would raise a RuntimeError.
            for i in range(self.num_samples - 1):
                for j in range(i + 1, self.num_samples):
                    x0 = z[i * bs : (i + 1) * bs]
                    x1 = z[j * bs : (j + 1) * bs]
                    loss = loss + self.loss_f(x0, x1)
        loss = loss / (self.w_iter * self.num_pairs)
        return loss
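
How the pieces fit together: the transform produces num_samples views per image, and WMSELoss expects those views concatenated along the batch dimension (rows 0..bs-1 hold view 0, rows bs..2*bs-1 hold view 1, and so on). A minimal training-step sketch, assuming a hypothetical `model` that maps image batches to 64-dimensional embeddings and a batch `images` of 128 PIL images; both names are placeholders, not part of this commit:

import torch
from lightly.loss.wmse_loss import WMSELoss
from lightly.transforms.wmse_transform import WMSETransform

transform = WMSETransform(num_samples=2)
criterion = WMSELoss(embedding_dim=64, w_size=128, num_samples=2)

views = [transform(img) for img in images]  # per image: [view_0, view_1]
# Group by view: first all view-0 tensors, then all view-1 tensors.
batch = torch.cat([torch.stack([v[i] for v in views]) for i in range(2)], dim=0)

h = model(batch)  # placeholder encoder; output shape (2 * len(images), 64)
loss = criterion(h)  # len(images) must be a multiple of w_size
loss.backward()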
lightly/transforms/wmse_transform.py
@@ -0,0 +1,90 @@
from typing import Dict, List

import torchvision.transforms as T

from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE


class WMSETransform(MultiViewTransform):
    """Implements the transformations for W-MSE [0].

    - [0] Whitening for Self-Supervised Representation Learning, 2021, https://arxiv.org/pdf/2007.06346.pdf

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of Tensor of length num_samples.

    Applies the following augmentations by default:
        - Color jitter
        - Random gray scale
        - Random resized crop
        - Random horizontal flip
        - ImageNet normalization

    Attributes:
        num_samples:
            Number of views. Must be the same as num_samples in the WMSELoss.
        input_size:
            Size of the input image in pixels.
        cj_prob:
            Probability that color jitter is applied.
        cj_bright:
            How much to jitter brightness.
        cj_contrast:
            How much to jitter contrast.
        cj_sat:
            How much to jitter saturation.
        cj_hue:
            How much to jitter hue.
        min_scale:
            Minimum size of the randomized crop relative to the input_size.
        random_gray_scale:
            Probability that random gray scale is applied.
        hf_prob:
            Probability that horizontal flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        num_samples: int = 2,
        input_size: int = 224,
        cj_prob: float = 0.8,
        cj_bright: float = 0.4,
        cj_contrast: float = 0.4,
        cj_sat: float = 0.4,
        cj_hue: float = 0.1,
        min_scale: float = 0.2,
        random_gray_scale: float = 0.1,
        hf_prob: float = 0.5,
        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
    ):
        if num_samples < 1:
            raise ValueError("num_samples must be greater than or equal to 1")
        transform = T.Compose(
            [
                T.RandomApply(
                    [T.ColorJitter(cj_bright, cj_contrast, cj_sat, cj_hue)], p=cj_prob
                ),
                T.RandomGrayscale(p=random_gray_scale),
                T.RandomResizedCrop(
                    input_size,
                    scale=(min_scale, 1.0),
                    interpolation=T.InterpolationMode.BICUBIC,
                ),
                T.RandomHorizontalFlip(p=hf_prob),
                T.ToTensor(),
                T.Normalize(mean=normalize["mean"], std=normalize["std"]),
            ]
        )
        # The same augmentation pipeline is applied num_samples times per image.
        super().__init__(transforms=[transform] * num_samples)
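
A quick usage sketch of the transform on its own (the 224x224 output size comes from the input_size default):

from PIL import Image
from lightly.transforms.wmse_transform import WMSETransform

transform = WMSETransform(num_samples=2)
image = Image.new("RGB", (256, 256))
views = transform(image)  # list of 2 tensors, each of shape (3, 224, 224)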
@@ -0,0 +1,57 @@
import unittest

import torch

from lightly.loss.wmse_loss import WMSELoss


class TestWMSELoss(unittest.TestCase):
    def test_forward(self) -> None:
        # WMSELoss defaults to num_samples=2 and w_size=128, so the per-view
        # batch size must be at least (and a multiple of) w_size.
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss()
        x = torch.randn(bs * num_samples, dim)

        loss = loss_fn(x)

        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
    def test_forward_cuda(self) -> None:
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss().cuda()
        x = torch.randn(bs * num_samples, dim).cuda()

        loss = loss_fn(x)

        print(loss)

    def test_loss_value(self) -> None:
        """The loss for random inputs should be positive."""
        bs = 128
        dim = 64
        num_samples = 2

        loss_fn = WMSELoss()
        x = torch.randn(bs * num_samples, dim)

        loss = loss_fn(x)

        self.assertGreater(loss, 0)

    def test_num_samples_error(self) -> None:
        with self.assertRaises(RuntimeError):
            loss_fn = WMSELoss(num_samples=3)
            x = torch.randn(5, 64)
            loss_fn(x)

    def test_w_size_error(self) -> None:
        with self.assertRaises(ValueError):
            loss_fn = WMSELoss(w_size=5)
            x = torch.randn(4, 64)
            loss_fn(x)
@@ -0,0 +1,21 @@
import pytest
from PIL import Image

from lightly.transforms.wmse_transform import WMSETransform


def test_raise_value_error() -> None:
    with pytest.raises(ValueError):
        WMSETransform(num_samples=0)


def test_num_views() -> None:
    multi_view_transform = WMSETransform(num_samples=3)
    assert len(multi_view_transform.transforms) == 3


def test_multi_view_on_pil_image() -> None:
    multi_view_transform = WMSETransform(num_samples=3)
    sample = Image.new("RGB", (100, 100))
    output = multi_view_transform(sample)
    assert len(output) == 3