Guarin lig 3062 add mocov2 imagenet benchmark (#1291)
* Add MoCoV2 imagenet benchmark
* Add MoCoV2 imagenet benchmark results
* Move `memory_bank.py` from `lightly/loss` to `lightly/models/modules`
* Add gather_distributed option to MemoryBankModule
* Add option to specify feature dimension when creating MemoryBankModule
* Fix distributed batch shuffle
guarin authored Dec 14, 2023
1 parent 8ad126a commit 610f73e
Showing 31 changed files with 586 additions and 264 deletions.
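The memory bank API changes can be summarized in a short sketch (inferred from the diffs below; the exact signatures are illustrative rather than authoritative):

    from lightly.loss import NTXentLoss
    from lightly.models.modules.memory_bank import MemoryBankModule  # new location

    # The bank size is now a (num_entries, feature_dim) tuple instead of a plain int.
    bank = MemoryBankModule(size=(4096, 128), gather_distributed=True)

    # NTXentLoss forwards the same tuple to its internal memory bank.
    criterion = NTXentLoss(temperature=0.2, memory_bank_size=(65536, 128), gather_distributed=True)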
1 change: 1 addition & 0 deletions README.md
@@ -283,6 +283,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
| MoCoV2 | Res50 | 256 | 100 | 61.5 | 74.4 | 41.3 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
4 changes: 3 additions & 1 deletion benchmarks/imagenet/resnet50/main.py
@@ -11,6 +11,7 @@
import finetune_eval
import knn_eval
import linear_eval
import mocov2
import simclr
import swav
import torch
@@ -57,6 +58,7 @@
"dcl": {"model": dcl.DCL, "transform": dcl.transform},
"dclw": {"model": dclw.DCLW, "transform": dclw.transform},
"dino": {"model": dino.DINO, "transform": dino.transform},
"mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
"simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
"swav": {"model": swav.SwAV, "transform": swav.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
@@ -228,7 +230,7 @@ def pretrain(
logger=TensorBoardLogger(save_dir=str(log_dir), name="pretrain"),
precision=precision,
strategy="ddp_find_unused_parameters_true",
sync_batchnorm=True,
sync_batchnorm=accelerator != "cpu", # Sync batchnorm is not supported on CPU.
num_sanity_val_steps=0,
)

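For reference, the METHODS registry above is consumed roughly as follows (a sketch; only the registry shape and the MoCoV2 constructor arguments come from this diff, the surrounding CLI plumbing is assumed):

    method = "mocov2"
    model = METHODS[method]["model"](batch_size_per_device=256, num_classes=1000)
    transform = METHODS[method]["transform"]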
145 changes: 145 additions & 0 deletions benchmarks/imagenet/resnet50/mocov2.py
@@ -0,0 +1,145 @@
import copy
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torchvision.models import resnet50

from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import (
batch_shuffle,
batch_unshuffle,
get_weight_decay_parameters,
update_momentum,
)
from lightly.transforms import MoCoV2Transform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class MoCoV2(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

resnet = resnet50()
resnet.fc = Identity() # Ignore classification head
self.backbone = resnet
self.projection_head = MoCoProjectionHead()
self.query_backbone = copy.deepcopy(self.backbone)
self.query_projection_head = MoCoProjectionHead()
self.criterion = NTXentLoss(
temperature=0.2,
memory_bank_size=(65536, 128),
gather_distributed=True,
)

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

def forward(self, x: Tensor) -> Tensor:
return self.backbone(x)

@torch.no_grad()
def forward_key_encoder(self, x: Tensor) -> Tuple[Tensor, Tensor]:
x, shuffle = batch_shuffle(batch=x, distributed=self.trainer.num_devices > 1)
features = self.forward(x).flatten(start_dim=1)
projections = self.projection_head(features)
features = batch_unshuffle(
batch=features,
shuffle=shuffle,
distributed=self.trainer.num_devices > 1,
)
projections = batch_unshuffle(
batch=projections,
shuffle=shuffle,
distributed=self.trainer.num_devices > 1,
)
return features, projections

def forward_query_encoder(self, x: Tensor) -> Tensor:
features = self.query_backbone(x).flatten(start_dim=1)
projections = self.query_projection_head(features)
return projections

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
views, targets = batch[0], batch[1]

# Encode queries.
query_projections = self.forward_query_encoder(views[1])

# Momentum update. This happens between query and key encoding, following the
# original implementation from the authors:
# https://github.com/facebookresearch/moco/blob/5a429c00bb6d4efdf511bf31b6f01e064bf929ab/moco/builder.py#L142
update_momentum(self.query_backbone, self.backbone, m=0.999)
update_momentum(self.query_projection_head, self.projection_head, m=0.999)

# Encode keys.
key_features, key_projections = self.forward_key_encoder(views[0])
loss = self.criterion(query_projections, key_projections)
self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

# Online linear evaluation.
cls_loss, cls_log = self.online_classifier.training_step(
(key_features.detach(), targets), batch_idx
)
self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
# NOTE: The original implementation from the authors uses weight decay for all
# parameters.
params, params_no_weight_decay = get_weight_decay_parameters(
[self.query_backbone, self.query_projection_head]
)
optimizer = SGD(
[
{"name": "mocov2", "params": params},
{
"name": "mocov2_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
momentum=0.9,
weight_decay=1e-4,
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=0,
max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
return [optimizer], [scheduler]


transform = MoCoV2Transform()
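A quick sanity check for the module above (not part of the commit; the 128-dimensional projection size assumes the default MoCoProjectionHead, and methods that touch self.trainer are avoided because no Trainer is attached):

    import torch

    model = MoCoV2(batch_size_per_device=2, num_classes=10)
    views = [torch.randn(2, 3, 224, 224) for _ in range(2)]

    features = model(views[0])                           # (2, 2048) backbone features
    projections = model.forward_query_encoder(views[1])  # (2, 128) query projections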
4 changes: 2 additions & 2 deletions benchmarks/imagenet/resnet50/swav.py
@@ -8,9 +8,9 @@
from torch.nn import functional as F
from torchvision.models import resnet50

from lightly.loss.memory_bank import MemoryBankModule
from lightly.loss.swav_loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SwaVTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
@@ -40,7 +40,7 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
self.queues = ModuleList(
[
MemoryBankModule(
size=self.n_batches_in_queue * self.batch_size_per_device
size=(self.n_batches_in_queue * self.batch_size_per_device, 128)
)
for _ in range(CROP_COUNTS[0])
]
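The SwaV queues now carry an explicit feature dimension. The sizing arithmetic is roughly as follows (n_batches_in_queue is assumed here for illustration; it matches the 3840-entry queue in examples/pytorch/swav_queue.py below):

    n_batches_in_queue = 15  # assumed value
    batch_size_per_device = 256
    size = (n_batches_in_queue * batch_size_per_device, 128)  # (3840, 128): 3840 entries of 128-d features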
4 changes: 2 additions & 2 deletions docs/source/getting_started/advanced.rst
@@ -298,9 +298,9 @@ For more information check the documentation:
.. code-block:: python
# to create a NTXentLoss with a memory bank (like for MoCo) set the
# memory_bank_size parameter to a value > 0
# memory_bank_size parameter to a value > 0 and specify the feature dimension
from lightly.loss import NTXentLoss
criterion = NTXentLoss(memory_bank_size=4096)
criterion = NTXentLoss(memory_bank_size=(4096, 128))
# the memory bank is used automatically for every forward pass
y0, y1 = resnet_moco(x0, x1)
loss = criterion(y0, y1)
1 change: 1 addition & 0 deletions docs/source/getting_started/benchmarks.rst
@@ -31,6 +31,7 @@ Evaluation settings are based on the following papers:
"BarlowTwins", "Res50", "256", "100", "62.9", "84.3", "72.6", "90.9", "45.6", "73.9", "`link <https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"BYOL", "Res50", "256", "100", "62.4", "84.7", "74.0", "91.9", "45.6", "74.8", "`link <https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"DINO", "Res50", "128", "100", "68.2", "87.9", "72.5", "90.8", "49.9", "78.7", "`link <https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt>`_"
"MoCoV2", "Res50", "256", "100", "61.5", "84.1", "74.4", "91.8", "41.3", "71.9", "\-", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"SimCLR*", "Res50", "256", "100", "63.2", "85.2", "73.9", "91.9", "44.8", "73.9", "`link <https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"SimCLR* + DCL", "Res50", "256", "100", "65.1", "86.2", "73.5", "91.7", "49.6", "77.5", "`link <https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"SimCLR* + DCLW", "Res50", "256", "100", "64.5", "86.0", "73.2", "91.5", "48.5", "76.8", "`link <https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
5 changes: 2 additions & 3 deletions docs/source/getting_started/benchmarks/cifar10_benchmark.py
@@ -79,10 +79,9 @@
NegativeCosineSimilarity,
NTXentLoss,
SwaVLoss,
memory_bank,
)
from lightly.models import ResNetGenerator, modules, utils
from lightly.models.modules import heads
from lightly.models.modules import heads, memory_bank
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
@@ -800,7 +799,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
@@ -225,7 +225,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
self.criterion = NTXentLoss(temperature=0.07, memory_bank_size=memory_bank_size)
self.criterion = NTXentLoss(
temperature=0.07, memory_bank_size=(memory_bank_size, 128)
)

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -505,7 +507,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.projection_head = heads.NNCLRProjectionHead(feature_dim, 4096, 256)

self.criterion = NTXentLoss()
self.memory_bank = modules.NNMemoryBankModule(size=memory_bank_size)
self.memory_bank = modules.NNMemoryBankModule(size=(memory_bank_size, 256))

def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
11 changes: 6 additions & 5 deletions docs/source/getting_started/benchmarks/imagenette_benchmark.py
@@ -85,10 +85,9 @@
TiCoLoss,
VICRegLLoss,
VICRegLoss,
memory_bank,
)
from lightly.models import modules, utils
from lightly.models.modules import heads, masked_autoencoder
from lightly.models.modules import heads, masked_autoencoder, memory_bank
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
@@ -330,7 +329,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
self.criterion = NTXentLoss(
temperature=0.1, memory_bank_size=(memory_bank_size, 128)
)

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -1015,7 +1016,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
@@ -1320,7 +1321,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.prototypes = heads.SwaVPrototypes(128, 3000, 1)
self.start_queue_at_epoch = 15
self.queues = nn.ModuleList(
[memory_bank.MemoryBankModule(size=384) for _ in range(2)]
[memory_bank.MemoryBankModule(size=(384, 128)) for _ in range(2)]
) # Queue size reduced in order to work with a smaller dataset
self.criterion = SwaVLoss()

@@ -268,7 +268,7 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)

# Create the loss function with memory bank.
self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=4096)
self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))

def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
@@ -237,7 +237,9 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
self.criterion = NTXentLoss(
temperature=0.1, memory_bank_size=(memory_bank_size, 128)
)

def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
2 changes: 1 addition & 1 deletion examples/pytorch/moco.py
@@ -62,7 +62,7 @@ def forward_momentum(self, x):
num_workers=8,
)

criterion = NTXentLoss(memory_bank_size=4096)
criterion = NTXentLoss(memory_bank_size=(4096, 128))
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 10
2 changes: 1 addition & 1 deletion examples/pytorch/nnclr.py
@@ -38,7 +38,7 @@ def forward(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

memory_bank = NNMemoryBankModule(size=4096)
memory_bank = NNMemoryBankModule(size=(4096, 128))
memory_bank.to(device)

transform = SimCLRTransform(input_size=32)
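For context, the nearest-neighbour bank above is typically consumed like this (a sketch based on lightly's NNCLR example; the update flag is an assumption about the library API):

    # z0, z1 are projections of two views of the same images.
    z0 = memory_bank(z0, update=False)  # swap projections for their stored nearest neighbours
    z1 = memory_bank(z1, update=True)   # same lookup, and enqueue the current batch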
4 changes: 2 additions & 2 deletions examples/pytorch/smog.py
@@ -9,13 +9,13 @@
from sklearn.cluster import KMeans
from torch import nn

from lightly.loss.memory_bank import MemoryBankModule
from lightly.models import utils
from lightly.models.modules.heads import (
SMoGPredictionHead,
SMoGProjectionHead,
SMoGPrototypes,
)
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.smog_transform import SMoGTransform


@@ -80,7 +80,7 @@ def forward_momentum(self, x):

# memory bank because we reset the group features every 300 iterations
memory_bank_size = 300 * batch_size
memory_bank = MemoryBankModule(size=memory_bank_size)
memory_bank = MemoryBankModule(size=(memory_bank_size, 128))

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
6 changes: 4 additions & 2 deletions examples/pytorch/swav_queue.py
@@ -7,8 +7,8 @@
from torch import nn

from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.swav_transform import SwaVTransform


@@ -20,7 +20,9 @@ def __init__(self, backbone):
self.prototypes = SwaVPrototypes(128, 512, 1)

self.start_queue_at_epoch = 2
self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
self.queues = nn.ModuleList(
[MemoryBankModule(size=(3840, 128)) for _ in range(2)]
)

def forward(self, high_resolution, low_resolution, epoch):
self.prototypes.normalize()
2 changes: 1 addition & 1 deletion examples/pytorch_lightning/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

self.criterion = NTXentLoss(memory_bank_size=4096)
self.criterion = NTXentLoss(memory_bank_size=(4096, 128))

def forward(self, x):
query = self.backbone(x).flatten(start_dim=1)