Skip to content

Commit

Permalink
Add MMCR Loss and Transform (#1446)
Browse files Browse the repository at this point in the history
* Add MMCRLoss
* Add MMCR Transform
  • Loading branch information
johnsutor authored Dec 13, 2023
1 parent 66ad1b4 commit f3fd4a3
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 0 deletions.
1 change: 1 addition & 0 deletions lightly/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.dino_loss import DINOLoss
from lightly.loss.mmcr_loss import MMCRLoss
from lightly.loss.msn_loss import MSNLoss
from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity
from lightly.loss.ntx_ent_loss import NTXentLoss
Expand Down
66 changes: 66 additions & 0 deletions lightly/loss/mmcr_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from torch.linalg import svd


class MMCRLoss(nn.Module):
    """Implementation of the loss function from MMCR [0] using Manifold Capacity.

    All hyperparameters are set to the default values from the paper for ImageNet.

    - [0]: Efficient Coding of Natural Images using Maximum Manifold Capacity
      Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf

    Attributes:
        lmda:
            Non-negative weight applied to the batch-averaged sum of the
            per-sample singular values.

    Examples:
        >>> # initialize loss function
        >>> loss_fn = MMCRLoss()
        >>> transform = MMCRTransform(k=2)
        >>>
        >>> # transform images, then feed through encoder and projector
        >>> x = transform(x)
        >>> online = online_network(x)
        >>> momentum = momentum_network(x)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(online, momentum)
    """

    def __init__(self, lmda: float = 5e-3):
        """Initializes the loss.

        Args:
            lmda:
                Weight of the per-sample singular value term. Must be >= 0.

        Raises:
            ValueError: If lmda is negative.
        """
        super().__init__()
        if lmda < 0:
            raise ValueError("lmda must be greater than or equal to 0")

        self.lmda = lmda

    def forward(self, online: torch.Tensor, momentum: torch.Tensor) -> torch.Tensor:
        """Computes the MMCR loss for a batch of multi-view embeddings.

        Args:
            online:
                Output of the online network for the current batch. Expected to be
                of shape (batch_size, k, embedding_size), where k represents the
                number of randomly augmented views for each sample.
            momentum:
                Output of the momentum network for the current batch. Expected to be
                of shape (batch_size, k, embedding_size), where k represents the
                number of randomly augmented views for each sample.

        Returns:
            Scalar tensor equal to the negated sum of the singular values of the
            centroid matrix plus lmda times the sum of all per-sample singular
            values divided by the batch size.

        Raises:
            AssertionError: If online and momentum do not have the same shape.
        """
        # Raised explicitly rather than via an `assert` statement so the check
        # is not stripped when Python runs with -O. Without it, inputs that only
        # differ in k would be concatenated silently. The exception type is kept
        # for backward compatibility with existing callers.
        if online.shape != momentum.shape:
            raise AssertionError("online and momentum need to have the same shape")

        batch_size = online.shape[0]

        # Stack the views of both networks along the view axis and average them
        # to obtain one centroid per sample.
        z = torch.cat([online, momentum], dim=1)
        c = torch.mean(z, dim=1)  # B x D

        # Only the singular values are needed: svdvals avoids materializing the
        # singular vectors that a full SVD would compute and discard.
        s_z = torch.linalg.svdvals(z)
        s_c = torch.linalg.svdvals(c)

        loss = -1.0 * torch.sum(s_c) + self.lmda * torch.sum(s_z) / batch_size

        return loss
1 change: 1 addition & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from lightly.transforms.gaussian_blur import GaussianBlur
from lightly.transforms.jigsaw import Jigsaw
from lightly.transforms.mae_transform import MAETransform
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
from lightly.transforms.pirl_transform import PIRLTransform
Expand Down
89 changes: 89 additions & 0 deletions lightly/transforms/mmcr_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Dict, List, Optional, Tuple, Union

from lightly.transforms.byol_transform import BYOLView1Transform
from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE


class MMCRTransform(MultiViewTransform):
    """Multi-view transform for MMCR [0], built on the BYOL [1] view transform.

    A single BYOLView1Transform is configured from the constructor arguments
    and registered k times, so every input yields k views.

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of tensors of length k.

    Applies the following augmentations by default:
        - Random resized crop
        - Random horizontal flip
        - Color jitter
        - Random gray scale
        - Gaussian blur
        - Solarization (default probability is 0.0)
        - ImageNet normalization

    Please refer to the BYOL implementation for additional details.

    - [0]: Efficient Coding of Natural Images using Maximum Manifold Capacity
      Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf
    - [1]: Bootstrap Your Own Latent, 2020, https://arxiv.org/pdf/2006.07733.pdf

    Args:
        k:
            Number of views. Must be at least 1.

    All remaining arguments are passed unchanged to BYOLView1Transform; see
    that class for their meaning.

    Raises:
        ValueError: If k is smaller than 1.
    """

    def __init__(
        self,
        k: int = 8,
        input_size: int = 224,
        cj_prob: float = 0.8,
        cj_strength: float = 1.0,
        cj_bright: float = 0.4,
        cj_contrast: float = 0.4,
        cj_sat: float = 0.2,
        cj_hue: float = 0.1,
        min_scale: float = 0.08,
        random_gray_scale: float = 0.2,
        gaussian_blur: float = 1.0,
        solarization_prob: float = 0.0,
        kernel_size: Optional[float] = None,
        sigmas: Tuple[float, float] = (0.1, 2),
        vf_prob: float = 0.0,
        hf_prob: float = 0.5,
        rr_prob: float = 0.0,
        rr_degrees: Union[None, float, Tuple[float, float]] = None,
        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
    ):
        # Reject an empty view list before building any transform.
        if k < 1:
            raise ValueError("k must be greater than or equal to 1")

        view_transform = BYOLView1Transform(
            input_size=input_size,
            cj_prob=cj_prob,
            cj_strength=cj_strength,
            cj_bright=cj_bright,
            cj_contrast=cj_contrast,
            cj_sat=cj_sat,
            cj_hue=cj_hue,
            min_scale=min_scale,
            random_gray_scale=random_gray_scale,
            gaussian_blur=gaussian_blur,
            solarization_prob=solarization_prob,
            kernel_size=kernel_size,
            sigmas=sigmas,
            vf_prob=vf_prob,
            hf_prob=hf_prob,
            rr_prob=rr_prob,
            rr_degrees=rr_degrees,
            normalize=normalize,
        )

        # The same transform instance is shared by all k view slots.
        shared_views = [view_transform for _ in range(k)]
        super().__init__(transforms=shared_views)
65 changes: 65 additions & 0 deletions tests/loss/test_MMCR_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest

import torch

from lightly.loss.mmcr_loss import MMCRLoss


class testMMCRLoss(unittest.TestCase):
    """Unit tests for MMCRLoss."""

    def test_forward(self) -> None:
        """The loss of random embeddings is a finite scalar."""
        bs = 3
        dim = 128
        k = 32

        loss_fn = MMCRLoss()
        online = torch.randn(bs, k, dim)
        momentum = torch.randn(bs, k, dim)

        loss = loss_fn(online, momentum)

        # Assert on the result instead of printing it so the test can fail.
        self.assertEqual(loss.shape, torch.Size([]))
        self.assertTrue(torch.isfinite(loss).item())

    @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
    def test_forward_cuda(self) -> None:
        """The loss of random CUDA embeddings is a finite scalar on the GPU."""
        bs = 3
        dim = 128
        k = 32

        loss_fn = MMCRLoss()
        online = torch.randn(bs, k, dim).cuda()
        momentum = torch.randn(bs, k, dim).cuda()

        loss = loss_fn(online, momentum)

        self.assertTrue(loss.is_cuda)
        self.assertEqual(loss.shape, torch.Size([]))
        self.assertTrue(torch.isfinite(loss).item())

    def test_loss_value(self) -> None:
        """If all values are zero, the loss should be zero."""
        bs = 3
        dim = 128
        k = 32

        loss_fn = MMCRLoss()
        online = torch.zeros(bs, k, dim)
        momentum = torch.zeros(bs, k, dim)

        loss = loss_fn(online, momentum)

        # assertEqual reports the actual value on failure, unlike assertTrue.
        self.assertEqual(loss.item(), 0)

    def test_lambda_value_error(self) -> None:
        """If lambda is negative, a ValueError should be raised."""
        with self.assertRaises(ValueError):
            MMCRLoss(lmda=-1)

    def test_shape_assertion_forward(self) -> None:
        """If the input shapes differ, an AssertionError should be raised."""
        bs = 3
        dim = 128
        k = 32

        loss_fn = MMCRLoss()
        online = torch.randn(bs, k, dim)
        momentum = torch.randn(bs, k, dim + 1)

        with self.assertRaises(AssertionError):
            loss_fn(online, momentum)
21 changes: 21 additions & 0 deletions tests/transforms/test_mmcr_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
from PIL import Image

from lightly.transforms.mmcr_transform import MMCRTransform


def test_raise_value_error() -> None:
    """Constructing the transform with zero views raises a ValueError."""
    invalid_k = 0
    with pytest.raises(ValueError):
        MMCRTransform(k=invalid_k)


def test_num_views() -> None:
    """The transform registers one per-view transform for each requested view."""
    num_views = 3
    per_view_transforms = MMCRTransform(k=num_views).transforms
    assert len(per_view_transforms) == 3


def test_multi_view_on_pil_image() -> None:
    """Applying the transform to a PIL image returns one output per view."""
    transform = MMCRTransform(k=3)
    image = Image.new("RGB", (100, 100))
    views = transform(image)
    assert len(views) == 3

0 comments on commit f3fd4a3

Please sign in to comment.