-
Notifications
You must be signed in to change notification settings - Fork 290
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add MMCRLoss * Add MMCR Transform
- Loading branch information
Showing 6 changed files with 243 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.linalg import svd | ||
|
||
|
||
class MMCRLoss(nn.Module):
    """Implementation of the loss function from MMCR [0] using Manifold Capacity.

    All hyperparameters are set to the default values from the paper for ImageNet.

    - [0]: Efficient Coding of Natural Images using Maximum Manifold Capacity
      Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf

    Attributes:
        lmda:
            Non-negative weight of the per-sample singular value term.

    Examples:
        >>> # initialize loss function
        >>> loss_fn = MMCRLoss()
        >>> transform = MMCRTransform(k=2)
        >>>
        >>> # transform images, then feed through encoder and projector
        >>> x = transform(x)
        >>> online = online_network(x)
        >>> momentum = momentum_network(x)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(online, momentum)
    """

    def __init__(self, lmda: float = 5e-3):
        """Creates the loss module.

        Args:
            lmda:
                Weight of the per-sample singular value term. Must be
                greater than or equal to 0.

        Raises:
            ValueError: If lmda is negative.
        """
        super().__init__()
        if lmda < 0:
            raise ValueError("lmda must be greater than or equal to 0")

        self.lmda = lmda

    def forward(self, online: torch.Tensor, momentum: torch.Tensor) -> torch.Tensor:
        """Computes the MMCR loss for one batch.

        Args:
            online:
                Output of the online network for the current batch. Expected to be
                of shape (batch_size, k, embedding_size), where k represents the
                number of randomly augmented views for each sample.
            momentum:
                Output of the momentum network for the current batch. Expected to be
                of shape (batch_size, k, embedding_size), where k represents the
                number of randomly augmented views for each sample.

        Returns:
            Scalar tensor: the negated sum of the singular values of the centroid
            matrix plus lmda times the sum of the singular values of every
            per-sample view matrix, averaged over the batch.

        Raises:
            AssertionError: If online and momentum differ in shape.
        """
        # Kept as an assert (not a ValueError) so the exception type seen by
        # existing callers and tests does not change.
        assert (
            online.shape == momentum.shape
        ), "online and momentum need to have the same shape"

        batch_size = online.shape[0]

        # Stack the views of both networks along the view axis and average
        # over it to obtain one centroid per sample.
        z = torch.cat([online, momentum], dim=1)  # B x 2k x D
        c = torch.mean(z, dim=1)  # B x D

        # Only the singular values are needed. svdvals skips computing the
        # (unused) singular vectors that a full svd would materialise.
        s_z = torch.linalg.svdvals(z)  # batched over the first dimension
        s_c = torch.linalg.svdvals(c)

        loss = -1.0 * torch.sum(s_c) + self.lmda * torch.sum(s_z) / batch_size

        return loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
from lightly.transforms.byol_transform import BYOLView1Transform | ||
from lightly.transforms.multi_view_transform import MultiViewTransform | ||
from lightly.transforms.utils import IMAGENET_NORMALIZE | ||
|
||
|
||
class MMCRTransform(MultiViewTransform):
    """Multi-view transform for MMCR [0], built on the BYOL [1] augmentations.

    Every one of the k views is produced by the same BYOLView1Transform
    instance, configured from the constructor arguments.

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of tensors of length k.

    Applies the following augmentations by default:
        - Random resized crop
        - Random horizontal flip
        - Color jitter
        - Random gray scale
        - Gaussian blur
        - Solarization
        - ImageNet normalization

    Please refer to the BYOL implementation for additional details.

    - [0]: Efficient Coding of Natural Images using Maximum Manifold Capacity
      Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf
    - [1]: Bootstrap Your Own Latent, 2020, https://arxiv.org/pdf/2006.07733.pdf

    Attributes:
        k: Number of views.
        transform: The transform to apply to each view.
    """

    def __init__(
        self,
        k: int = 8,
        input_size: int = 224,
        cj_prob: float = 0.8,
        cj_strength: float = 1.0,
        cj_bright: float = 0.4,
        cj_contrast: float = 0.4,
        cj_sat: float = 0.2,
        cj_hue: float = 0.1,
        min_scale: float = 0.08,
        random_gray_scale: float = 0.2,
        gaussian_blur: float = 1.0,
        solarization_prob: float = 0.0,
        kernel_size: Optional[float] = None,
        sigmas: Tuple[float, float] = (0.1, 2),
        vf_prob: float = 0.0,
        hf_prob: float = 0.5,
        rr_prob: float = 0.0,
        rr_degrees: Union[None, float, Tuple[float, float]] = None,
        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
    ):
        # Reject an empty view list before building any transform.
        if k < 1:
            raise ValueError("k must be greater than or equal to 1")

        # All augmentation settings are forwarded unchanged to the BYOL view.
        view_kwargs = dict(
            input_size=input_size,
            cj_prob=cj_prob,
            cj_strength=cj_strength,
            cj_bright=cj_bright,
            cj_contrast=cj_contrast,
            cj_sat=cj_sat,
            cj_hue=cj_hue,
            min_scale=min_scale,
            random_gray_scale=random_gray_scale,
            gaussian_blur=gaussian_blur,
            solarization_prob=solarization_prob,
            kernel_size=kernel_size,
            sigmas=sigmas,
            vf_prob=vf_prob,
            hf_prob=hf_prob,
            rr_prob=rr_prob,
            rr_degrees=rr_degrees,
            normalize=normalize,
        )
        view_transform = BYOLView1Transform(**view_kwargs)

        # The same transform object is reused for each of the k views.
        super().__init__(transforms=[view_transform for _ in range(k)])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from lightly.loss.mmcr_loss import MMCRLoss | ||
|
||
|
||
class testMMCRLoss(unittest.TestCase):
    """Unit tests for MMCRLoss."""

    # Input dimensions shared by all tests: samples, views per sample and
    # embedding size.
    BATCH_SIZE = 3
    NUM_VIEWS = 32
    DIM = 128

    def test_forward(self) -> None:
        """The loss of random inputs is a finite scalar."""
        loss_fn = MMCRLoss()
        online = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM)
        momentum = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM)

        loss = loss_fn(online, momentum)

        self.assertEqual(loss.shape, torch.Size([]))
        self.assertTrue(torch.isfinite(loss).item())

    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
    def test_forward_cuda(self) -> None:
        """The loss of random CUDA inputs is a finite scalar on the GPU."""
        loss_fn = MMCRLoss()
        online = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM).cuda()
        momentum = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM).cuda()

        loss = loss_fn(online, momentum)

        self.assertEqual(loss.shape, torch.Size([]))
        self.assertTrue(loss.is_cuda)
        self.assertTrue(torch.isfinite(loss).item())

    def test_loss_value(self) -> None:
        """If all values are zero, the loss should be zero."""
        loss_fn = MMCRLoss()
        online = torch.zeros(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM)
        momentum = torch.zeros(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM)

        loss = loss_fn(online, momentum)

        # assertEqual reports the actual value on failure, unlike assertTrue.
        self.assertEqual(loss.item(), 0.0)

    def test_lambda_value_error(self) -> None:
        """If lambda is negative, a ValueError should be raised."""
        with self.assertRaises(ValueError):
            MMCRLoss(lmda=-1)

    def test_shape_assertion_forward(self) -> None:
        """Inputs with different shapes trigger the shape assertion."""
        loss_fn = MMCRLoss()
        online = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM)
        momentum = torch.randn(self.BATCH_SIZE, self.NUM_VIEWS, self.DIM + 1)

        with self.assertRaises(AssertionError):
            loss_fn(online, momentum)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pytest | ||
from PIL import Image | ||
|
||
from lightly.transforms.mmcr_transform import MMCRTransform | ||
|
||
|
||
def test_raise_value_error() -> None:
    """Requesting zero views must be rejected with a ValueError."""
    invalid_num_views = 0
    with pytest.raises(ValueError):
        MMCRTransform(k=invalid_num_views)
|
||
|
||
def test_num_views() -> None:
    """The transform holds exactly one sub-transform per requested view."""
    num_views = 3
    transform = MMCRTransform(k=num_views)
    assert len(transform.transforms) == num_views
|
||
|
||
def test_multi_view_on_pil_image() -> None:
    """Applying the transform to a PIL image yields one output per view."""
    num_views = 3
    transform = MMCRTransform(k=num_views)
    image = Image.new("RGB", (100, 100))
    views = transform(image)
    assert len(views) == num_views