diff --git a/README.md b/README.md
index 70afc9418..4f61264cd 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
+| MoCoV2 | Res50 | 256 | 100 | 61.5 | 74.4 | 41.3 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py
index 8eb3e7f65..34ce68f12 100644
--- a/benchmarks/imagenet/resnet50/main.py
+++ b/benchmarks/imagenet/resnet50/main.py
@@ -11,6 +11,7 @@
import finetune_eval
import knn_eval
import linear_eval
+import mocov2
import simclr
import swav
import torch
@@ -57,6 +58,7 @@
"dcl": {"model": dcl.DCL, "transform": dcl.transform},
"dclw": {"model": dclw.DCLW, "transform": dclw.transform},
"dino": {"model": dino.DINO, "transform": dino.transform},
+ "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
"simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
"swav": {"model": swav.SwAV, "transform": swav.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
@@ -228,7 +230,7 @@ def pretrain(
logger=TensorBoardLogger(save_dir=str(log_dir), name="pretrain"),
precision=precision,
strategy="ddp_find_unused_parameters_true",
- sync_batchnorm=True,
+ sync_batchnorm=accelerator != "cpu", # Sync batchnorm is not supported on CPU.
num_sanity_val_steps=0,
)
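
> Reviewer note (not part of the patch): a minimal sketch of how the registry above dispatches a method. The dict name `METHODS` and the surrounding CLI glue are assumptions based on this hunk; only the entries themselves are shown in the diff.

```python
# Hypothetical dispatch sketch; METHODS is assumed to be the registry name.
import mocov2

METHODS = {
    "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
}

method = METHODS["mocov2"]
model = method["model"](batch_size_per_device=256, num_classes=1000)  # signature from mocov2.py
transform = method["transform"]  # a MoCoV2Transform instance
```
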
diff --git a/benchmarks/imagenet/resnet50/mocov2.py b/benchmarks/imagenet/resnet50/mocov2.py
new file mode 100644
index 000000000..850165c2b
--- /dev/null
+++ b/benchmarks/imagenet/resnet50/mocov2.py
@@ -0,0 +1,145 @@
+import copy
+from typing import List, Tuple
+
+import torch
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import Identity
+from torch.optim import SGD
+from torchvision.models import resnet50
+
+from lightly.loss import NTXentLoss
+from lightly.models.modules import MoCoProjectionHead
+from lightly.models.utils import (
+ batch_shuffle,
+ batch_unshuffle,
+ get_weight_decay_parameters,
+ update_momentum,
+)
+from lightly.transforms import MoCoV2Transform
+from lightly.utils.benchmarking import OnlineLinearClassifier
+from lightly.utils.scheduler import CosineWarmupScheduler
+
+
+class MoCoV2(LightningModule):
+ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
+ super().__init__()
+ self.save_hyperparameters()
+ self.batch_size_per_device = batch_size_per_device
+
+ resnet = resnet50()
+ resnet.fc = Identity() # Ignore classification head
+ self.backbone = resnet
+ self.projection_head = MoCoProjectionHead()
+ self.query_backbone = copy.deepcopy(self.backbone)
+ self.query_projection_head = MoCoProjectionHead()
+ self.criterion = NTXentLoss(
+ temperature=0.2,
+ memory_bank_size=(65536, 128),
+ gather_distributed=True,
+ )
+
+ self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.backbone(x)
+
+ @torch.no_grad()
+ def forward_key_encoder(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ x, shuffle = batch_shuffle(batch=x, distributed=self.trainer.num_devices > 1)
+ features = self.forward(x).flatten(start_dim=1)
+ projections = self.projection_head(features)
+ features = batch_unshuffle(
+ batch=features,
+ shuffle=shuffle,
+ distributed=self.trainer.num_devices > 1,
+ )
+ projections = batch_unshuffle(
+ batch=projections,
+ shuffle=shuffle,
+ distributed=self.trainer.num_devices > 1,
+ )
+ return features, projections
+
+ def forward_query_encoder(self, x: Tensor) -> Tensor:
+ features = self.query_backbone(x).flatten(start_dim=1)
+ projections = self.query_projection_head(features)
+ return projections
+
+ def training_step(
+ self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+ ) -> Tensor:
+ views, targets = batch[0], batch[1]
+
+ # Encode queries.
+ query_projections = self.forward_query_encoder(views[1])
+
+ # Momentum update. This happens between query and key encoding, following the
+ # original implementation from the authors:
+ # https://github.com/facebookresearch/moco/blob/5a429c00bb6d4efdf511bf31b6f01e064bf929ab/moco/builder.py#L142
+ update_momentum(self.query_backbone, self.backbone, m=0.999)
+ update_momentum(self.query_projection_head, self.projection_head, m=0.999)
+
+ # Encode keys.
+ key_features, key_projections = self.forward_key_encoder(views[0])
+ loss = self.criterion(query_projections, key_projections)
+ self.log(
+ "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
+ )
+
+ # Online linear evaluation.
+ cls_loss, cls_log = self.online_classifier.training_step(
+ (key_features.detach(), targets), batch_idx
+ )
+ self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
+ return loss + cls_loss
+
+ def validation_step(
+ self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
+ ) -> Tensor:
+ images, targets = batch[0], batch[1]
+ features = self.forward(images).flatten(start_dim=1)
+ cls_loss, cls_log = self.online_classifier.validation_step(
+ (features.detach(), targets), batch_idx
+ )
+ self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
+ return cls_loss
+
+ def configure_optimizers(self):
+ # Don't use weight decay for batch norm, bias parameters, and classification
+ # head to improve performance.
+ # NOTE: The original implementation from the authors uses weight decay for all
+ # parameters.
+ params, params_no_weight_decay = get_weight_decay_parameters(
+ [self.query_backbone, self.query_projection_head]
+ )
+ optimizer = SGD(
+ [
+ {"name": "mocov2", "params": params},
+ {
+ "name": "mocov2_no_weight_decay",
+ "params": params_no_weight_decay,
+ "weight_decay": 0.0,
+ },
+ {
+ "name": "online_classifier",
+ "params": self.online_classifier.parameters(),
+ "weight_decay": 0.0,
+ },
+ ],
+ lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
+ momentum=0.9,
+ weight_decay=1e-4,
+ )
+ scheduler = {
+ "scheduler": CosineWarmupScheduler(
+ optimizer=optimizer,
+ warmup_epochs=0,
+ max_epochs=int(self.trainer.estimated_stepping_batches),
+ ),
+ "interval": "step",
+ }
+ return [optimizer], [scheduler]
+
+
+transform = MoCoV2Transform()
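
> Reviewer note (not part of the patch): a minimal smoke-test sketch for the new module, assuming a single process and no Lightning `Trainer` attached. `forward_key_encoder` needs `self.trainer` for the distributed flag, so the key path here calls the momentum networks directly, and everything runs under `no_grad` so the memory bank (which gathers across devices on update) is left untouched.

```python
import torch

# Shape check only; not a training loop.
model = MoCoV2(batch_size_per_device=32, num_classes=10)
images = torch.randn(4, 3, 224, 224)

with torch.no_grad():
    queries = model.forward_query_encoder(images)         # (4, 128) projections
    # Key path without batch shuffling (forward_key_encoder needs self.trainer).
    keys = model.projection_head(model.backbone(images))  # (4, 128) projections
    loss = model.criterion(queries, keys)
print(loss.item())
```
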
diff --git a/benchmarks/imagenet/resnet50/swav.py b/benchmarks/imagenet/resnet50/swav.py
index 6c6bf9ade..c98f61361 100644
--- a/benchmarks/imagenet/resnet50/swav.py
+++ b/benchmarks/imagenet/resnet50/swav.py
@@ -8,9 +8,9 @@
from torch.nn import functional as F
from torchvision.models import resnet50
-from lightly.loss.memory_bank import MemoryBankModule
from lightly.loss.swav_loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SwaVTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
@@ -40,7 +40,7 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
self.queues = ModuleList(
[
MemoryBankModule(
- size=self.n_batches_in_queue * self.batch_size_per_device
+ size=(self.n_batches_in_queue * self.batch_size_per_device, 128)
)
for _ in range(CROP_COUNTS[0])
]
diff --git a/docs/source/getting_started/advanced.rst b/docs/source/getting_started/advanced.rst
index 5be80439c..ede0f0464 100644
--- a/docs/source/getting_started/advanced.rst
+++ b/docs/source/getting_started/advanced.rst
@@ -298,9 +298,9 @@ For more information check the documentation:
.. code-block:: python
# to create a NTXentLoss with a memory bank (like for MoCo) set the
- # memory_bank_size parameter to a value > 0
+ # memory_bank_size parameter to a value > 0 and specify the feature dimension
from lightly.loss import NTXentLoss
- criterion = NTXentLoss(memory_bank_size=4096)
+ criterion = NTXentLoss(memory_bank_size=(4096, 128))
# the memory bank is used automatically for every forward pass
y0, y1 = resnet_moco(x0, x1)
loss = criterion(y0, y1)
diff --git a/docs/source/getting_started/benchmarks.rst b/docs/source/getting_started/benchmarks.rst
index 1947e6e9c..b22e32c8b 100644
--- a/docs/source/getting_started/benchmarks.rst
+++ b/docs/source/getting_started/benchmarks.rst
@@ -31,6 +31,7 @@ Evaluation settings are based on the following papers:
"BarlowTwins", "Res50", "256", "100", "62.9", "84.3", "72.6", "90.9", "45.6", "73.9", "`link `_", "`link `_"
"BYOL", "Res50", "256", "100", "62.4", "84.7", "74.0", "91.9", "45.6", "74.8", "`link `_", "`link `_"
"DINO", "Res50", "128", "100", "68.2", "87.9", "72.5", "90.8", "49.9", "78.7", "`link `_", "`link `_"
+ "MoCoV2", "Res50", "256", "100", "61.5", "84.1", "74.4", "91.8", "41.3", "71.9", "\-", "`link `_"
"SimCLR*", "Res50", "256", "100", "63.2", "85.2", "73.9", "91.9", "44.8", "73.9", "`link `_", "`link `_"
"SimCLR* + DCL", "Res50", "256", "100", "65.1", "86.2", "73.5", "91.7", "49.6", "77.5", "`link `_", "`link `_"
"SimCLR* + DCLW", "Res50", "256", "100", "64.5", "86.0", "73.2", "91.5", "48.5", "76.8", "`link `_", "`link `_"
diff --git a/docs/source/getting_started/benchmarks/cifar10_benchmark.py b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
index 509c008cf..37aeb4ac6 100644
--- a/docs/source/getting_started/benchmarks/cifar10_benchmark.py
+++ b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
@@ -79,10 +79,9 @@
NegativeCosineSimilarity,
NTXentLoss,
SwaVLoss,
- memory_bank,
)
from lightly.models import ResNetGenerator, modules, utils
-from lightly.models.modules import heads
+from lightly.models.modules import heads, memory_bank
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
@@ -800,7 +799,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
- self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+ self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
diff --git a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
index 3dcc1beb2..810b9e0e3 100644
--- a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
@@ -225,7 +225,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)
# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.07, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+ temperature=0.07, memory_bank_size=(memory_bank_size, 128)
+ )
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -505,7 +507,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.projection_head = heads.NNCLRProjectionHead(feature_dim, 4096, 256)
self.criterion = NTXentLoss()
- self.memory_bank = modules.NNMemoryBankModule(size=memory_bank_size)
+ self.memory_bank = modules.NNMemoryBankModule(size=(memory_bank_size, 256))
def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
index b5c5e8e71..ee9cc0171 100644
--- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
@@ -85,10 +85,9 @@
TiCoLoss,
VICRegLLoss,
VICRegLoss,
- memory_bank,
)
from lightly.models import modules, utils
-from lightly.models.modules import heads, masked_autoencoder
+from lightly.models.modules import heads, masked_autoencoder, memory_bank
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
@@ -330,7 +329,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)
# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+ temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+ )
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -1015,7 +1016,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
- self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+ self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
@@ -1320,7 +1321,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.prototypes = heads.SwaVPrototypes(128, 3000, 1)
self.start_queue_at_epoch = 15
self.queues = nn.ModuleList(
- [memory_bank.MemoryBankModule(size=384) for _ in range(2)]
+ [memory_bank.MemoryBankModule(size=(384, 128)) for _ in range(2)]
) # Queue size reduced in order to work with a smaller dataset
self.criterion = SwaVLoss()
diff --git a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
index 7e572aa1b..fc68889ad 100644
--- a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
+++ b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
@@ -268,7 +268,7 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)
# Create the loss function with memory bank.
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=4096)
+ self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))
def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
diff --git a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
index f97992607..399b41956 100644
--- a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
+++ b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
@@ -237,7 +237,9 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)
# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+ temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+ )
def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
diff --git a/examples/pytorch/moco.py b/examples/pytorch/moco.py
index 961337d8f..613dd4a20 100644
--- a/examples/pytorch/moco.py
+++ b/examples/pytorch/moco.py
@@ -62,7 +62,7 @@ def forward_momentum(self, x):
num_workers=8,
)
-criterion = NTXentLoss(memory_bank_size=4096)
+criterion = NTXentLoss(memory_bank_size=(4096, 128))
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)
epochs = 10
diff --git a/examples/pytorch/nnclr.py b/examples/pytorch/nnclr.py
index c3745235a..530e18153 100644
--- a/examples/pytorch/nnclr.py
+++ b/examples/pytorch/nnclr.py
@@ -38,7 +38,7 @@ def forward(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
-memory_bank = NNMemoryBankModule(size=4096)
+memory_bank = NNMemoryBankModule(size=(4096, 128))
memory_bank.to(device)
transform = SimCLRTransform(input_size=32)
diff --git a/examples/pytorch/smog.py b/examples/pytorch/smog.py
index a4f81c8ed..0e59758f7 100644
--- a/examples/pytorch/smog.py
+++ b/examples/pytorch/smog.py
@@ -9,13 +9,13 @@
from sklearn.cluster import KMeans
from torch import nn
-from lightly.loss.memory_bank import MemoryBankModule
from lightly.models import utils
from lightly.models.modules.heads import (
SMoGPredictionHead,
SMoGProjectionHead,
SMoGPrototypes,
)
+from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.smog_transform import SMoGTransform
@@ -80,7 +80,7 @@ def forward_momentum(self, x):
# memory bank because we reset the group features every 300 iterations
memory_bank_size = 300 * batch_size
-memory_bank = MemoryBankModule(size=memory_bank_size)
+memory_bank = MemoryBankModule(size=(memory_bank_size, 128))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
diff --git a/examples/pytorch/swav_queue.py b/examples/pytorch/swav_queue.py
index 232cf3b55..a3a148482 100644
--- a/examples/pytorch/swav_queue.py
+++ b/examples/pytorch/swav_queue.py
@@ -7,8 +7,8 @@
from torch import nn
from lightly.loss import SwaVLoss
-from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.swav_transform import SwaVTransform
@@ -20,7 +20,9 @@ def __init__(self, backbone):
self.prototypes = SwaVPrototypes(128, 512, 1)
self.start_queue_at_epoch = 2
- self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+ self.queues = nn.ModuleList(
+ [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+ )
def forward(self, high_resolution, low_resolution, epoch):
self.prototypes.normalize()
diff --git a/examples/pytorch_lightning/moco.py b/examples/pytorch_lightning/moco.py
index 3e6e1d973..7aa32dbbf 100644
--- a/examples/pytorch_lightning/moco.py
+++ b/examples/pytorch_lightning/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)
- self.criterion = NTXentLoss(memory_bank_size=4096)
+ self.criterion = NTXentLoss(memory_bank_size=(4096, 128))
def forward(self, x):
query = self.backbone(x).flatten(start_dim=1)
diff --git a/examples/pytorch_lightning/nnclr.py b/examples/pytorch_lightning/nnclr.py
index 2f859a111..0a3c462cc 100644
--- a/examples/pytorch_lightning/nnclr.py
+++ b/examples/pytorch_lightning/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = NNCLRProjectionHead(512, 512, 128)
self.prediction_head = NNCLRPredictionHead(128, 512, 128)
- self.memory_bank = NNMemoryBankModule(size=4096)
+ self.memory_bank = NNMemoryBankModule(size=(4096, 128))
self.criterion = NTXentLoss()
diff --git a/examples/pytorch_lightning/swav_queue.py b/examples/pytorch_lightning/swav_queue.py
index 0a2d40ab4..2aad1c44a 100644
--- a/examples/pytorch_lightning/swav_queue.py
+++ b/examples/pytorch_lightning/swav_queue.py
@@ -8,8 +8,8 @@
from torch import nn
from lightly.loss import SwaVLoss
-from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.swav_transform import SwaVTransform
@@ -21,7 +21,9 @@ def __init__(self):
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, 512, 1)
self.start_queue_at_epoch = 2
- self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+ self.queues = nn.ModuleList(
+ [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+ )
self.criterion = SwaVLoss()
def training_step(self, batch, batch_idx):
diff --git a/examples/pytorch_lightning_distributed/moco.py b/examples/pytorch_lightning_distributed/moco.py
index 03b9698ac..6a59a096a 100644
--- a/examples/pytorch_lightning_distributed/moco.py
+++ b/examples/pytorch_lightning_distributed/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)
- self.criterion = NTXentLoss(memory_bank_size=4096)
+ self.criterion = NTXentLoss(memory_bank_size=(4096, 128))
def forward(self, x):
query = self.backbone(x).flatten(start_dim=1)
diff --git a/examples/pytorch_lightning_distributed/nnclr.py b/examples/pytorch_lightning_distributed/nnclr.py
index 9c02e2834..d925265bd 100644
--- a/examples/pytorch_lightning_distributed/nnclr.py
+++ b/examples/pytorch_lightning_distributed/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = NNCLRProjectionHead(512, 512, 128)
self.prediction_head = NNCLRPredictionHead(128, 512, 128)
- self.memory_bank = NNMemoryBankModule(size=4096)
+ self.memory_bank = NNMemoryBankModule(size=(4096, 128))
self.criterion = NTXentLoss()
diff --git a/lightly/loss/memory_bank.py b/lightly/loss/memory_bank.py
index 998917715..3c01a7082 100644
--- a/lightly/loss/memory_bank.py
+++ b/lightly/loss/memory_bank.py
@@ -1,131 +1,2 @@
-""" Memory Bank Wrapper """
-
-# Copyright (c) 2020. Lightly AG and its affiliates.
-# All Rights Reserved
-
-from typing import Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-
-
-class MemoryBankModule(torch.nn.Module):
- """Memory bank implementation
-
- This is a parent class to all loss functions implemented by the lightly
- Python package. This way, any loss can be used with a memory bank if
- desired.
-
- Attributes:
- size:
- Number of keys the memory bank can store. If set to 0,
- memory bank is not used.
-
- Examples:
- >>> class MyLossFunction(MemoryBankModule):
- >>>
- >>> def __init__(self, memory_bank_size: int = 2 ** 16):
- >>> super(MyLossFunction, self).__init__(memory_bank_size)
- >>>
- >>> def forward(self, output: Tensor,
- >>> labels: Tensor = None):
- >>>
- >>> output, negatives = super(
- >>> MyLossFunction, self).forward(output)
- >>>
- >>> if negatives is not None:
- >>> # evaluate loss with negative samples
- >>> else:
- >>> # evaluate loss without negative samples
-
- """
-
- def __init__(self, size: int = 2**16):
- super(MemoryBankModule, self).__init__()
-
- if size < 0:
- msg = f"Illegal memory bank size {size}, must be non-negative."
- raise ValueError(msg)
-
- self.size = size
- self.register_buffer(
- "bank", tensor=torch.empty(0, dtype=torch.float), persistent=False
- )
- self.register_buffer(
- "bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False
- )
-
- @torch.no_grad()
- def _init_memory_bank(self, dim: int) -> None:
- """Initialize the memory bank if it's empty
-
- Args:
- dim:
- The dimension of the which are stored in the bank.
-
- """
- # create memory bank
- # we could use register buffers like in the moco repo
- # https://github.com/facebookresearch/moco but we don't
- # want to pollute our checkpoints
- bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
- self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
- self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)
-
- @torch.no_grad()
- def _dequeue_and_enqueue(self, batch: Tensor) -> None:
- """Dequeue the oldest batch and add the latest one
-
- Args:
- batch:
- The latest batch of keys to add to the memory bank.
-
- """
- batch_size = batch.shape[0]
- ptr = int(self.bank_ptr)
-
- if ptr + batch_size >= self.size:
- self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
- self.bank_ptr[0] = 0
- else:
- self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
- self.bank_ptr[0] = ptr + batch_size
-
- def forward(
- self,
- output: Tensor,
- labels: Optional[Tensor] = None,
- update: bool = False,
- ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]:
- """Query memory bank for additional negative samples
-
- Args:
- output:
- The output of the model.
- labels:
- Should always be None, will be ignored.
-
- Returns:
- The output if the memory bank is of size 0, otherwise the output
- and the entries from the memory bank.
-
- """
-
- # no memory bank, return the output
- if self.size == 0:
- return output, None
-
- _, dim = output.shape
-
- # initialize the memory bank if it is not already done
- if self.bank.nelement() == 0:
- self._init_memory_bank(dim)
-
- # query and update memory bank
- bank = self.bank.clone().detach()
-
- # only update memory bank if we later do backward pass (gradient)
- if update:
- self._dequeue_and_enqueue(output)
-
- return output, bank
+# For backwards compatibility as memory_bank module was previously in loss module.
+from lightly.models.modules.memory_bank import MemoryBankModule
diff --git a/lightly/loss/ntx_ent_loss.py b/lightly/loss/ntx_ent_loss.py
index b94211ec7..7f0c1d429 100644
--- a/lightly/loss/ntx_ent_loss.py
+++ b/lightly/loss/ntx_ent_loss.py
@@ -3,11 +3,13 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import Sequence, Union
+
import torch
from torch import distributed as torch_dist
from torch import nn
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.utils import dist
@@ -25,11 +27,18 @@ class NTXentLoss(MemoryBankModule):
temperature:
Scale logits by the inverse of the temperature.
memory_bank_size:
- Number of negative samples to store in the memory bank.
- Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
+            Size of the memory bank as (num_features, dim) tuple. num_features is the
+            number of negative samples stored in the memory bank. If num_features is 0,
+ the memory bank is disabled. Use 0 for SimCLR. For MoCo we typically use
+ numbers like 4096 or 65536.
+ Deprecated: If only a single integer is passed, it is interpreted as the
+ number of features and the feature dimension is inferred from the first
+ batch stored in the memory bank. Leaving out the feature dimension might
+ lead to errors in distributed training.
gather_distributed:
If True then negatives from all gpus are gathered before the
- loss calculation. This flag has no effect if memory_bank_size > 0.
+ loss calculation. If a memory bank is used and gather_distributed is True,
+ then tensors from all gpus are gathered before the memory bank is updated.
Raises:
ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
@@ -55,10 +64,10 @@ class NTXentLoss(MemoryBankModule):
def __init__(
self,
temperature: float = 0.5,
- memory_bank_size: int = 0,
+ memory_bank_size: Union[int, Sequence[int]] = 0,
gather_distributed: bool = False,
):
- super(NTXentLoss, self).__init__(size=memory_bank_size)
+ super().__init__(size=memory_bank_size, gather_distributed=gather_distributed)
self.temperature = temperature
self.gather_distributed = gather_distributed
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
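
> Reviewer note (not part of the patch): for reference, the two ways the reworked `memory_bank_size` argument can now be passed. The tuple form is what the rest of this patch migrates callers to; the int form hits the deprecation path added in `memory_bank.py`.

```python
from lightly.loss import NTXentLoss

# New style: bank size and feature dimension are explicit, so the bank is
# allocated eagerly with shape (4096, 128).
criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))

# Deprecated style: a bare int still works but emits a UserWarning, and the
# feature dimension is only inferred from the first batch that is stored.
criterion_legacy = NTXentLoss(temperature=0.1, memory_bank_size=4096)
```
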
diff --git a/lightly/loss/regularizer/co2.py b/lightly/loss/regularizer/co2.py
index 2145cf04a..e61a7a302 100644
--- a/lightly/loss/regularizer/co2.py
+++ b/lightly/loss/regularizer/co2.py
@@ -3,9 +3,11 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import Sequence, Union
+
import torch
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
class CO2Regularizer(MemoryBankModule):
@@ -19,15 +21,19 @@ class CO2Regularizer(MemoryBankModule):
t_consistency:
Temperature used during softmax calculations.
memory_bank_size:
- Number of negative samples to store in the memory bank.
- Use 0 to use the second batch for negative samples.
+ Size of the memory bank as (num_features, dim) tuple. num_features is the
+ number of negatives stored in the bank. If set to 0, the memory bank is
+ disabled. Deprecated: If only a single integer is passed, it is interpreted
+ as the number of features and the feature dimension is inferred from the
+ first batch stored in the memory bank. Leaving out the feature dimension
+ might lead to errors in distributed training.
Examples:
>>> # initialize loss function for MoCo
- >>> loss_fn = NTXentLoss(memory_bank_size=4096)
+ >>> loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
>>>
>>> # initialize CO2 regularizer
- >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=4096)
+ >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=(4096, 128))
>>>
-    >>> # generate two random trasnforms of images
+    >>> # generate two random transforms of images
>>> t0 = transforms(images)
@@ -42,7 +48,10 @@ class CO2Regularizer(MemoryBankModule):
"""
def __init__(
- self, alpha: float = 1, t_consistency: float = 0.05, memory_bank_size: int = 0
+ self,
+ alpha: float = 1,
+ t_consistency: float = 0.05,
+ memory_bank_size: Union[int, Sequence[int]] = 0,
):
super(CO2Regularizer, self).__init__(size=memory_bank_size)
-        # try-catch the KLDivLoss construction for backwards compatability
+        # try-catch the KLDivLoss construction for backwards compatibility
diff --git a/lightly/models/modules/memory_bank.py b/lightly/models/modules/memory_bank.py
new file mode 100644
index 000000000..db7139b8c
--- /dev/null
+++ b/lightly/models/modules/memory_bank.py
@@ -0,0 +1,177 @@
+""" Memory Bank Wrapper """
+
+# Copyright (c) 2020. Lightly AG and its affiliates.
+# All Rights Reserved
+
+import warnings
+from typing import Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from lightly.models import utils
+
+
+class MemoryBankModule(Module):
+ """Memory bank implementation
+
+ This is a parent class to all loss functions implemented by the lightly
+ Python package. This way, any loss can be used with a memory bank if
+ desired.
+
+ Attributes:
+ size:
+ Size of the memory bank as (num_features, dim) tuple. If num_features is 0
+ then the memory bank is disabled. Deprecated: If only a single integer is
+ passed, it is interpreted as the number of features and the feature
+ dimension is inferred from the first batch stored in the memory bank.
+ Leaving out the feature dimension might lead to errors in distributed
+ training.
+ gather_distributed:
+ If True then negatives from all gpus are gathered before the memory bank
+ is updated. This results in more frequent updates of the memory bank and
+ keeps the memory bank contents independent of the number of gpus. But it has
+ the drawback that synchronization between processes is required and
+ diversity of the memory bank content is reduced.
+ feature_dim_first:
+ If True, the memory bank returns features with shape (dim, num_features).
+ If False, the memory bank returns features with shape (num_features, dim).
+
+ Examples:
+ >>> class MyLossFunction(MemoryBankModule):
+ >>>
+ >>> def __init__(self, memory_bank_size: Tuple[int, int] = (2 ** 16, 128)):
+ >>> super().__init__(memory_bank_size)
+ >>>
+ >>> def forward(self, output: Tensor, labels: Union[Tensor, None] = None):
+ >>> output, negatives = super().forward(output)
+ >>>
+ >>> if negatives is not None:
+ >>> # evaluate loss with negative samples
+ >>> else:
+ >>> # evaluate loss without negative samples
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]] = 65536,
+ gather_distributed: bool = False,
+ feature_dim_first: bool = True,
+ ):
+ super().__init__()
+ size_tuple = (size,) if isinstance(size, int) else tuple(size)
+
+ if any(x < 0 for x in size_tuple):
+ raise ValueError(
+ f"Illegal memory bank size {size}, all entries must be non-negative."
+ )
+
+ self.size = size_tuple
+ self.gather_distributed = gather_distributed
+ self.feature_dim_first = feature_dim_first
+ self.bank: Tensor
+ self.register_buffer(
+ "bank",
+ tensor=torch.empty(size=size_tuple, dtype=torch.float),
+ persistent=False,
+ )
+ self.bank_ptr: Tensor
+ self.register_buffer(
+ "bank_ptr",
+ tensor=torch.empty(1, dtype=torch.long),
+ persistent=False,
+ )
+
+ if isinstance(size, int) and size > 0:
+ warnings.warn(
+ (
+ f"Memory bank size 'size={size}' does not specify feature "
+ "dimension. It is recommended to set the feature dimension with "
+ "'size=(n, dim)' when creating the memory bank. Distributed "
+ "training might fail if the feature dimension is not set."
+ ),
+ UserWarning,
+ )
+ elif len(size_tuple) > 1:
+ self._init_memory_bank(size=size_tuple)
+
+ @torch.no_grad()
+ def _init_memory_bank(self, size: Tuple[int, ...]) -> None:
+ """Initialize the memory bank.
+
+ Args:
+ size:
+ Size of the memory bank as (num_features, dim) tuple.
+
+ """
+ self.bank = torch.randn(size).type_as(self.bank)
+ self.bank = torch.nn.functional.normalize(self.bank, dim=-1)
+ self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, batch: Tensor) -> None:
+ """Dequeue the oldest batch and add the latest one
+
+ Args:
+ batch:
+ The latest batch of keys to add to the memory bank.
+
+ """
+ if self.gather_distributed:
+ batch = utils.concat_all_gather(batch)
+
+ batch_size = batch.shape[0]
+ ptr = int(self.bank_ptr)
+ if ptr + batch_size >= self.size[0]:
+ self.bank[ptr:] = batch[: self.size[0] - ptr].detach()
+ self.bank_ptr.zero_()
+ else:
+ self.bank[ptr : ptr + batch_size] = batch.detach()
+ self.bank_ptr[0] = ptr + batch_size
+
+ def forward(
+ self,
+ output: Tensor,
+ labels: Union[Tensor, None] = None,
+ update: bool = False,
+ ) -> Tuple[Tensor, Union[Tensor, None]]:
+ """Query memory bank for additional negative samples
+
+ Args:
+ output:
+ The output of the model.
+ labels:
+ Should always be None, will be ignored.
+ update:
+ If True, the memory bank will be updated with the current output.
+
+ Returns:
+ The output if the memory bank is of size 0, otherwise the output
+ and the entries from the memory bank. Entries from the memory bank have
+ shape (dim, num_features) if feature_dim_first is True and
+ (num_features, dim) otherwise.
+
+ """
+
+ # no memory bank, return the output
+ if self.size[0] == 0:
+ return output, None
+
+ # Initialize the memory bank if it is not already done.
+ if self.bank.ndim == 1:
+ dim = output.shape[1:]
+ self._init_memory_bank(size=(*self.size, *dim))
+
+ # query and update memory bank
+ bank = self.bank.clone().detach()
+ if self.feature_dim_first:
+ # swap bank size and feature dimension for backwards compatibility
+ bank = bank.transpose(0, -1)
+
+ # only update memory bank if we later do backward pass (gradient)
+ if update:
+ self._dequeue_and_enqueue(output)
+
+ return output, bank
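
> Reviewer note (not part of the patch): a short usage sketch of the new module, mirroring the tests added below. `forward` returns the bank state *before* the current batch is enqueued, and `feature_dim_first` controls the orientation of the returned tensor.

```python
import torch

from lightly.models.modules.memory_bank import MemoryBankModule

bank_module = MemoryBankModule(size=(5, 2), feature_dim_first=False)
x = torch.randn(3, 2)
out, bank = bank_module(x, update=True)
assert out.tolist() == x.tolist()       # input is passed through unchanged
assert bank.shape == (5, 2)             # (num_features, dim)
assert bank_module.bank[:3].tolist() == x.tolist()  # x enqueued after the query

# The default feature_dim_first=True transposes the returned bank to
# (dim, num_features), which is the layout NTXentLoss consumes.
dim_first = MemoryBankModule(size=(5, 2))
_, bank_t = dim_first(torch.randn(3, 2), update=True)
assert bank_t.shape == (2, 5)
```
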
diff --git a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py
index 59abb65a5..8e5f1cd31 100644
--- a/lightly/models/modules/nn_memory_bank.py
+++ b/lightly/models/modules/nn_memory_bank.py
@@ -3,12 +3,12 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved
-from typing import Optional
+from typing import Sequence, Union
import torch
from torch import Tensor
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
class NNMemoryBankModule(MemoryBankModule):
@@ -22,13 +22,18 @@ class NNMemoryBankModule(MemoryBankModule):
Attributes:
size:
- Number of keys the memory bank can store.
+ Size of the memory bank as (num_features, dim) tuple. If num_features is 0
+ then the memory bank is disabled. Deprecated: If only a single integer is
+ passed, it is interpreted as the number of features and the feature
+ dimension is inferred from the first batch stored in the memory bank.
+ Leaving out the feature dimension might lead to errors in distributed
+ training.
Examples:
>>> model = NNCLR(backbone)
>>> criterion = NTXentLoss(temperature=0.1)
>>>
- >>> nn_replacer = NNmemoryBankModule(size=2 ** 16)
+    >>> nn_replacer = NNMemoryBankModule(size=(2 ** 16, 128))
>>>
>>> # forward pass
>>> (z0, p0), (z1, p1) = model(x0, x1)
@@ -39,9 +44,7 @@ class NNMemoryBankModule(MemoryBankModule):
"""
- def __init__(self, size: int = 2**16):
- if size <= 0:
- raise ValueError(f"Memory bank size must be positive, got {size}.")
+ def __init__(self, size: Union[int, Sequence[int]] = 2**16):
super(NNMemoryBankModule, self).__init__(size)
def forward( # type: ignore[override] # TODO(Philipp, 11/23): Fix signature to match parent class.
diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index bd7c4e26d..ee96c067c 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -95,7 +95,7 @@ def concat_all_gather(x: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- """Shuffles batch over multiple gpus.
+ """Shuffles batch over multiple devices.
This code was taken and adapted from here:
https://github.com/facebookresearch/moco.
@@ -109,26 +109,25 @@ def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.
input batch and shuffle is an index to restore the original order.
"""
- # gather from all gpus
+ # gather from all devices
batch_size_this = batch.shape[0]
batch_gather = concat_all_gather(batch)
batch_size_all = batch_gather.shape[0]
- num_gpus = batch_size_all // batch_size_this
+ num_devices = batch_size_all // batch_size_this
# random shuffle index
- idx_shuffle = torch.randperm(batch_size_all).cuda()
+ idx_shuffle = torch.randperm(batch_size_all, device=batch.device)
- # broadcast to all gpus
+ # broadcast to all devices
dist.broadcast(idx_shuffle, src=0)
# index for restoring
shuffle = torch.argsort(idx_shuffle)
- # shuffled index for this gpu
- gpu_idx = dist.get_rank()
- idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
-
+ # shuffled index for this device
+ rank = dist.get_rank()
+ idx_this = idx_shuffle.view(num_devices, -1)[rank]
return batch_gather[idx_this], shuffle
@@ -136,7 +135,7 @@ def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.
def batch_unshuffle_distributed(
batch: torch.Tensor, shuffle: torch.Tensor
) -> torch.Tensor:
- """Undo batch shuffle over multiple gpus.
+ """Undo batch shuffle over multiple devices.
This code was taken and adapted from here:
https://github.com/facebookresearch/moco.
@@ -151,17 +150,16 @@ def batch_unshuffle_distributed(
The unshuffled tensor.
"""
- # gather from all gpus
+ # gather from all devices
batch_size_this = batch.shape[0]
batch_gather = concat_all_gather(batch)
batch_size_all = batch_gather.shape[0]
- num_gpus = batch_size_all // batch_size_this
+ num_devices = batch_size_all // batch_size_this
-    # restored index for this gpu
+    # restored index for this device
- gpu_idx = dist.get_rank()
- idx_this = shuffle.view(num_gpus, -1)[gpu_idx]
-
+ rank = dist.get_rank()
+ idx_this = shuffle.view(num_devices, -1)[rank]
return batch_gather[idx_this]
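
> Reviewer note (not part of the patch): the shuffle/unshuffle pair that `mocov2.py` relies on also has a single-process path; a quick round-trip sketch, assuming `distributed=False`:

```python
import torch

from lightly.models.utils import batch_shuffle, batch_unshuffle

x = torch.randn(8, 128)
shuffled, shuffle = batch_shuffle(batch=x, distributed=False)
restored = batch_unshuffle(batch=shuffled, shuffle=shuffle, distributed=False)
assert torch.equal(restored, x)  # unshuffling restores the original order
```
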
diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py
index b635b2fbb..ffc9747c6 100644
--- a/lightly/utils/benchmarking/metric_callback.py
+++ b/lightly/utils/benchmarking/metric_callback.py
@@ -60,7 +60,7 @@ def _append_metrics(
self, metrics_dict: Dict[str, List[float]], trainer: Trainer
) -> None:
for name, value in trainer.callback_metrics.items():
- if isinstance(value, float) or (
- isinstance(value, Tensor) and value.numel() == 1
- ):
- metrics_dict.setdefault(name, []).append(float(value))
+ if isinstance(value, Tensor) and value.numel() != 1:
+ # Skip non-scalar tensors.
+ continue
+ metrics_dict.setdefault(name, []).append(float(value))
diff --git a/tests/loss/test_CO2Regularizer.py b/tests/loss/test_CO2Regularizer.py
index 7b56efcb1..28918eedd 100644
--- a/tests/loss/test_CO2Regularizer.py
+++ b/tests/loss/test_CO2Regularizer.py
@@ -18,7 +18,7 @@ def test_forward_pass_no_memory_bank(self):
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0)
def test_forward_pass_memory_bank(self):
- reg = CO2Regularizer(memory_bank_size=4096)
+ reg = CO2Regularizer(memory_bank_size=(4096, 32))
for bsz in range(1, 20):
batch_1 = torch.randn((bsz, 32))
batch_2 = torch.randn((bsz, 32))
@@ -44,7 +44,7 @@ def test_forward_pass_cuda_memory_bank(self):
if not torch.cuda.is_available():
return
- reg = CO2Regularizer(memory_bank_size=4096)
+ reg = CO2Regularizer(memory_bank_size=(4096, 32))
for bsz in range(1, 20):
batch_1 = torch.randn((bsz, 32)).cuda()
batch_2 = torch.randn((bsz, 32)).cuda()
diff --git a/tests/loss/test_MemoryBank.py b/tests/loss/test_MemoryBank.py
deleted file mode 100644
index dc1df7ca0..000000000
--- a/tests/loss/test_MemoryBank.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import unittest
-
-import torch
-
-from lightly.loss.memory_bank import MemoryBankModule
-
-
-class TestNTXentLoss(unittest.TestCase):
- def test_init__negative_size(self):
- with self.assertRaises(ValueError):
- MemoryBankModule(size=-1)
-
- def test_forward_easy(self):
- bsz = 3
- dim, size = 2, 9
- n = 33 * bsz
- memory_bank = MemoryBankModule(size=size)
-
- ptr = 0
- for i in range(0, n, bsz):
- output = torch.randn(2 * bsz, dim)
- output.requires_grad = True
- out0, out1 = output[:bsz], output[bsz:]
-
- _, curr_memory_bank = memory_bank(out1, update=True)
- next_memory_bank = memory_bank.bank
-
- curr_diff = out0.T - curr_memory_bank[:, ptr : ptr + bsz]
- next_diff = out1.T - next_memory_bank[:, ptr : ptr + bsz]
-
- # the current memory bank should not hold the batch yet
- self.assertGreater(curr_diff.norm(), 1e-5)
- # the "next" memory bank should hold the batch
- self.assertGreater(1e-5, next_diff.norm())
-
- ptr = (ptr + bsz) % size
-
- def test_forward(self):
- bsz = 3
- dim, size = 2, 10
- n = 33 * bsz
- memory_bank = MemoryBankModule(size=size)
-
- for i in range(0, n, bsz):
- # see if there are any problems when the bank size
- # is no multiple of the batch size
- output = torch.randn(bsz, dim)
- _, _ = memory_bank(output)
-
- @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
- def test_forward__cuda(self):
- bsz = 3
- dim, size = 2, 10
- n = 33 * bsz
- memory_bank = MemoryBankModule(size=size)
- device = torch.device("cuda")
- memory_bank.to(device=device)
-
- for i in range(0, n, bsz):
- # see if there are any problems when the bank size
- # is no multiple of the batch size
- output = torch.randn(bsz, dim, device=device)
- _, _ = memory_bank(output)
diff --git a/tests/models/modules/test_memory_bank.py b/tests/models/modules/test_memory_bank.py
new file mode 100644
index 000000000..0f6ccae53
--- /dev/null
+++ b/tests/models/modules/test_memory_bank.py
@@ -0,0 +1,161 @@
+import re
+import unittest
+
+import pytest
+import torch
+
+from lightly.models.modules.memory_bank import MemoryBankModule
+
+
+class TestMemoryBankModule(unittest.TestCase):
+ def test_init__negative_size(self) -> None:
+ with self.assertRaises(ValueError):
+ MemoryBankModule(size=-1)
+
+ def test_forward_easy(self) -> None:
+ bsz = 3
+ dim, size = 2, 9
+ n = 33 * bsz
+ memory_bank = MemoryBankModule(size=size)
+
+ ptr = 0
+ for i in range(0, n, bsz):
+ output = torch.randn(2 * bsz, dim)
+ output.requires_grad = True
+ out0, out1 = output[:bsz], output[bsz:]
+
+ _, curr_memory_bank = memory_bank(out1, update=True)
+ next_memory_bank = memory_bank.bank.transpose(0, -1)
+
+ curr_diff = out0.T - curr_memory_bank[:, ptr : ptr + bsz]
+ next_diff = out1.T - next_memory_bank[:, ptr : ptr + bsz]
+
+ # the current memory bank should not hold the batch yet
+ self.assertGreater(curr_diff.norm(), 1e-5)
+ # the "next" memory bank should hold the batch
+ self.assertGreater(1e-5, next_diff.norm())
+
+ ptr = (ptr + bsz) % size
+
+ def test_forward(self) -> None:
+ bsz = 3
+ dim, size = 2, 10
+ n = 33 * bsz
+ memory_bank = MemoryBankModule(size=size)
+
+ for i in range(0, n, bsz):
+ # see if there are any problems when the bank size
+ # is no multiple of the batch size
+ output = torch.randn(bsz, dim)
+ _, _ = memory_bank(output)
+
+ @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
+ def test_forward__cuda(self) -> None:
+ bsz = 3
+ dim, size = 2, 10
+ n = 33 * bsz
+ memory_bank = MemoryBankModule(size=size)
+ device = torch.device("cuda")
+ memory_bank.to(device=device)
+
+ for i in range(0, n, bsz):
+ # see if there are any problems when the bank size
+ # is no multiple of the batch size
+ output = torch.randn(bsz, dim, device=device)
+ _, _ = memory_bank(output)
+
+
+class TestMemoryBank:
+ def test_init__negative_size(self) -> None:
+ with pytest.raises(
+ ValueError,
+ match="Illegal memory bank size -1, all entries must be non-negative.",
+ ):
+ MemoryBankModule(size=-1)
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape(
+ "Illegal memory bank size (10, -1), all entries must be non-negative."
+ ),
+ ):
+ MemoryBankModule(size=(10, -1))
+
+ def test_init__no_dim_warning(self) -> None:
+ with pytest.warns(
+ UserWarning,
+ match=re.escape(
+ "Memory bank size 'size=10' does not specify feature "
+ "dimension. It is recommended to set the feature dimension with "
+ "'size=(n, dim)' when creating the memory bank. Distributed "
+ "training might fail if the feature dimension is not set."
+ ),
+ ):
+ MemoryBankModule(size=10)
+
+ def test_forward(self) -> None:
+ torch.manual_seed(0)
+ memory_bank = MemoryBankModule(size=(5, 2), feature_dim_first=False)
+ x0 = torch.randn(3, 2)
+ out0, bank0 = memory_bank(x0, update=True)
+ # Verify that output is same as input.
+ assert out0.tolist() == x0.tolist()
+ # Verify that memory bank was initialized and has correct shape.
+ assert bank0.shape == (5, 2)
+ assert memory_bank.bank.shape == (5, 2)
+ # Verify that output bank does not contain features from x0.
+ assert bank0[:3].tolist() != x0.tolist()
+ # Verify that memory bank was updated.
+ assert memory_bank.bank[:3].tolist() == x0.tolist()
+
+ x1 = torch.randn(3, 2)
+ out1, bank1 = memory_bank(x1, update=True)
+ # Verify that output is same as input.
+ assert out1.tolist() == x1.tolist()
+ # Verify that output bank contains features from x0.
+ assert bank1[:3].tolist() == x0.tolist()
+ # Verify that output bank does not contain features from x1.
+ assert bank1[3:].tolist() != x1[:2].tolist()
+ # Verify that memory bank was updated.
+ assert memory_bank.bank[:3].tolist() == x0.tolist()
+ assert memory_bank.bank[3:].tolist() == x1[:2].tolist()
+
+ # At this point the memory bank is full.
+ # Adding more features will start overwriting the bank from the beginning.
+
+ x2 = torch.randn(3, 2)
+ out2, bank2 = memory_bank(x2, update=True)
+ # Verify that output is same as input.
+ assert out2.tolist() == x2.tolist()
+ # Verify that output bank contains features from x0 and x1.
+ assert bank2[:3].tolist() == x0.tolist()
+ assert bank2[3:].tolist() == x1[:2].tolist()
+ # Verify that memory bank is overwritten.
+ assert memory_bank.bank[:3].tolist() == x2.tolist()
+
+ def test_forward__no_dim(self) -> None:
+ torch.manual_seed(0)
+ # Only specify size but not feature dimension.
+ memory_bank = MemoryBankModule(size=5, feature_dim_first=False)
+ x0 = torch.randn(3, 2)
+ out0, bank0 = memory_bank(x0, update=True)
+ # Verify that output is same as input.
+ assert out0.tolist() == x0.tolist()
+ # Verify that memory bank was initialized and has correct shape.
+ assert bank0.shape == (5, 2)
+ assert memory_bank.bank.shape == (5, 2)
+ # Verify that output bank does not contain features from x0.
+ assert bank0[:3].tolist() != x0.tolist()
+ # Verify that memory bank was updated.
+ assert memory_bank.bank[:3].tolist() == x0.tolist()
+
+ def test_forward__dim_first(self) -> None:
+ torch.manual_seed(0)
+ memory_bank = MemoryBankModule(size=(5, 2), feature_dim_first=True)
+ x0 = torch.randn(3, 2)
+ out0, bank0 = memory_bank(x0, update=True)
+ assert bank0.shape == (2, 5)
+ x1 = torch.randn(3, 2)
+ out1, bank1 = memory_bank(x1, update=True)
+ assert bank1.shape == (2, 5)
+ assert bank1[:, :3].tolist() == x0.T.tolist()
diff --git a/tests/models/test_ModelsNNCLR.py b/tests/models/test_ModelsNNCLR.py
index 3760faef3..ac8eb05e8 100644
--- a/tests/models/test_ModelsNNCLR.py
+++ b/tests/models/test_ModelsNNCLR.py
@@ -115,7 +115,7 @@ def test_memory_bank(self):
model = NNCLR(get_backbone(resnet), **config).to(device)
for nn_size in [2**3, 2**8]:
- nn_replacer = NNMemoryBankModule(size=nn_size)
+ nn_replacer = NNMemoryBankModule(size=(nn_size, config["out_dim"]))
with torch.no_grad():
for i in range(10):