diff --git a/README.md b/README.md
index 70afc9418..4f61264cd 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
 | BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
+| MoCoV2 | Res50 | 256 | 100 | 61.5 | 74.4 | 41.3 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py
index 8eb3e7f65..34ce68f12 100644
--- a/benchmarks/imagenet/resnet50/main.py
+++ b/benchmarks/imagenet/resnet50/main.py
@@ -11,6 +11,7 @@
 import finetune_eval
 import knn_eval
 import linear_eval
+import mocov2
 import simclr
 import swav
 import torch
@@ -57,6 +58,7 @@
     "dcl": {"model": dcl.DCL, "transform": dcl.transform},
     "dclw": {"model": dclw.DCLW, "transform": dclw.transform},
     "dino": {"model": dino.DINO, "transform": dino.transform},
+    "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
     "simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
     "swav": {"model": swav.SwAV, "transform": swav.transform},
     "vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
@@ -228,7 +230,7 @@ def pretrain(
         logger=TensorBoardLogger(save_dir=str(log_dir), name="pretrain"),
         precision=precision,
         strategy="ddp_find_unused_parameters_true",
-        sync_batchnorm=True,
+        sync_batchnorm=accelerator != "cpu",  # Sync batchnorm is not supported on CPU.
         num_sanity_val_steps=0,
     )
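Editorial note: the registry entry above is how main.py reaches the new model; it looks the requested method up in METHODS and instantiates the model and transform from the same dict entry. A hedged, self-contained sketch of that dispatch, using only names visible in this diff (the argument values are placeholders, and the CLI plumbing itself is not part of the patch):

```python
# Sketch of the METHODS dispatch; assumes it runs next to mocov2.py.
import mocov2

METHODS = {"mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform}}

method = METHODS["mocov2"]
model = method["model"](batch_size_per_device=256, num_classes=1000)
train_transform = method["transform"]  # the module-level MoCoV2Transform
```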
diff --git a/benchmarks/imagenet/resnet50/mocov2.py b/benchmarks/imagenet/resnet50/mocov2.py
new file mode 100644
index 000000000..850165c2b
--- /dev/null
+++ b/benchmarks/imagenet/resnet50/mocov2.py
@@ -0,0 +1,145 @@
+import copy
+from typing import List, Tuple
+
+import torch
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import Identity
+from torch.optim import SGD
+from torchvision.models import resnet50
+
+from lightly.loss import NTXentLoss
+from lightly.models.modules import MoCoProjectionHead
+from lightly.models.utils import (
+    batch_shuffle,
+    batch_unshuffle,
+    get_weight_decay_parameters,
+    update_momentum,
+)
+from lightly.transforms import MoCoV2Transform
+from lightly.utils.benchmarking import OnlineLinearClassifier
+from lightly.utils.scheduler import CosineWarmupScheduler
+
+
+class MoCoV2(LightningModule):
+    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
+        super().__init__()
+        self.save_hyperparameters()
+        self.batch_size_per_device = batch_size_per_device
+
+        resnet = resnet50()
+        resnet.fc = Identity()  # Ignore classification head
+        self.backbone = resnet
+        self.projection_head = MoCoProjectionHead()
+        self.query_backbone = copy.deepcopy(self.backbone)
+        self.query_projection_head = MoCoProjectionHead()
+        self.criterion = NTXentLoss(
+            temperature=0.2,
+            memory_bank_size=(65536, 128),
+            gather_distributed=True,
+        )
+
+        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.backbone(x)
+
+    @torch.no_grad()
+    def forward_key_encoder(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        x, shuffle = batch_shuffle(batch=x, distributed=self.trainer.num_devices > 1)
+        features = self.forward(x).flatten(start_dim=1)
+        projections = self.projection_head(features)
+        features = batch_unshuffle(
+            batch=features,
+            shuffle=shuffle,
+            distributed=self.trainer.num_devices > 1,
+        )
+        projections = batch_unshuffle(
+            batch=projections,
+            shuffle=shuffle,
+            distributed=self.trainer.num_devices > 1,
+        )
+        return features, projections
+
+    def forward_query_encoder(self, x: Tensor) -> Tensor:
+        features = self.query_backbone(x).flatten(start_dim=1)
+        projections = self.query_projection_head(features)
+        return projections
+
+    def training_step(
+        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+    ) -> Tensor:
+        views, targets = batch[0], batch[1]
+
+        # Encode queries.
+        query_projections = self.forward_query_encoder(views[1])
+
+        # Momentum update. This happens between query and key encoding, following the
+        # original implementation from the authors:
+        # https://github.com/facebookresearch/moco/blob/5a429c00bb6d4efdf511bf31b6f01e064bf929ab/moco/builder.py#L142
+        update_momentum(self.query_backbone, self.backbone, m=0.999)
+        update_momentum(self.query_projection_head, self.projection_head, m=0.999)
+
+        # Encode keys.
+        key_features, key_projections = self.forward_key_encoder(views[0])
+        loss = self.criterion(query_projections, key_projections)
+        self.log(
+            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
+        )
+
+        # Online linear evaluation.
+        cls_loss, cls_log = self.online_classifier.training_step(
+            (key_features.detach(), targets), batch_idx
+        )
+        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
+        return loss + cls_loss
+
+    def validation_step(
+        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
+    ) -> Tensor:
+        images, targets = batch[0], batch[1]
+        features = self.forward(images).flatten(start_dim=1)
+        cls_loss, cls_log = self.online_classifier.validation_step(
+            (features.detach(), targets), batch_idx
+        )
+        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
+        return cls_loss
+
+    def configure_optimizers(self):
+        # Don't use weight decay for batch norm, bias parameters, and classification
+        # head to improve performance.
+        # NOTE: The original implementation from the authors uses weight decay for all
+        # parameters.
+        params, params_no_weight_decay = get_weight_decay_parameters(
+            [self.query_backbone, self.query_projection_head]
+        )
+        optimizer = SGD(
+            [
+                {"name": "mocov2", "params": params},
+                {
+                    "name": "mocov2_no_weight_decay",
+                    "params": params_no_weight_decay,
+                    "weight_decay": 0.0,
+                },
+                {
+                    "name": "online_classifier",
+                    "params": self.online_classifier.parameters(),
+                    "weight_decay": 0.0,
+                },
+            ],
+            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
+            momentum=0.9,
+            weight_decay=1e-4,
+        )
+        scheduler = {
+            "scheduler": CosineWarmupScheduler(
+                optimizer=optimizer,
+                warmup_epochs=0,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
+            ),
+            "interval": "step",
+        }
+        return [optimizer], [scheduler]
+
+
+transform = MoCoV2Transform()
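Editorial note: the learning rate in configure_optimizers follows the linear scaling rule from the MoCo reference code: a base LR of 0.03 per 256 samples of total batch size. A quick sanity check of the formula, assuming a hypothetical 2-device run with 128 samples per device:

```python
# lr = 0.03 * batch_size_per_device * world_size / 256
batch_size_per_device, world_size = 128, 2
lr = 0.03 * batch_size_per_device * world_size / 256
assert lr == 0.03  # total batch size 256 recovers the base learning rate
```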
diff --git a/benchmarks/imagenet/resnet50/swav.py b/benchmarks/imagenet/resnet50/swav.py
index 6c6bf9ade..c98f61361 100644
--- a/benchmarks/imagenet/resnet50/swav.py
+++ b/benchmarks/imagenet/resnet50/swav.py
@@ -8,9 +8,9 @@
 from torch.nn import functional as F
 from torchvision.models import resnet50
 
-from lightly.loss.memory_bank import MemoryBankModule
 from lightly.loss.swav_loss import SwaVLoss
 from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.models.utils import get_weight_decay_parameters
 from lightly.transforms import SwaVTransform
 from lightly.utils.benchmarking import OnlineLinearClassifier
@@ -40,7 +40,7 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
         self.queues = ModuleList(
             [
                 MemoryBankModule(
-                    size=self.n_batches_in_queue * self.batch_size_per_device
+                    size=(self.n_batches_in_queue * self.batch_size_per_device, 128)
                 )
                 for _ in range(CROP_COUNTS[0])
             ]
         )
diff --git a/docs/source/getting_started/advanced.rst b/docs/source/getting_started/advanced.rst
index 5be80439c..ede0f0464 100644
--- a/docs/source/getting_started/advanced.rst
+++ b/docs/source/getting_started/advanced.rst
@@ -298,9 +298,9 @@ For more information check the documentation:
 .. code-block:: python
 
     # to create a NTXentLoss with a memory bank (like for MoCo) set the
-    # memory_bank_size parameter to a value > 0
+    # memory_bank_size parameter to a value > 0 and specify the feature dimension
     from lightly.loss import NTXentLoss
-    criterion = NTXentLoss(memory_bank_size=4096)
+    criterion = NTXentLoss(memory_bank_size=(4096, 128))
 
     # the memory bank is used automatically for every forward pass
     y0, y1 = resnet_moco(x0, x1)
     loss = criterion(y0, y1)
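Editorial note: passing a plain integer for memory_bank_size still works for backwards compatibility, but the feature dimension is then inferred lazily from the first batch and a UserWarning is emitted (see the MemoryBankModule changes further below). A short sketch of the two forms:

```python
from lightly.loss import NTXentLoss

# Deprecated: dimension inferred from the first batch, emits a UserWarning.
criterion_old = NTXentLoss(memory_bank_size=4096)

# Preferred: bank shape is known up front, which is safer for distributed training.
criterion_new = NTXentLoss(memory_bank_size=(4096, 128))
```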
diff --git a/docs/source/getting_started/benchmarks.rst b/docs/source/getting_started/benchmarks.rst
index 1947e6e9c..b22e32c8b 100644
--- a/docs/source/getting_started/benchmarks.rst
+++ b/docs/source/getting_started/benchmarks.rst
@@ -31,6 +31,7 @@ Evaluation settings are based on the following papers:
     "BarlowTwins", "Res50", "256", "100", "62.9", "84.3", "72.6", "90.9", "45.6", "73.9", "`link `_", "`link `_"
     "BYOL", "Res50", "256", "100", "62.4", "84.7", "74.0", "91.9", "45.6", "74.8", "`link `_", "`link `_"
     "DINO", "Res50", "128", "100", "68.2", "87.9", "72.5", "90.8", "49.9", "78.7", "`link `_", "`link `_"
+    "MoCoV2", "Res50", "256", "100", "61.5", "84.1", "74.4", "91.8", "41.3", "71.9", "\-", "`link `_"
     "SimCLR*", "Res50", "256", "100", "63.2", "85.2", "73.9", "91.9", "44.8", "73.9", "`link `_", "`link `_"
     "SimCLR* + DCL", "Res50", "256", "100", "65.1", "86.2", "73.5", "91.7", "49.6", "77.5", "`link `_", "`link `_"
     "SimCLR* + DCLW", "Res50", "256", "100", "64.5", "86.0", "73.2", "91.5", "48.5", "76.8", "`link `_", "`link `_"
diff --git a/docs/source/getting_started/benchmarks/cifar10_benchmark.py b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
index 509c008cf..37aeb4ac6 100644
--- a/docs/source/getting_started/benchmarks/cifar10_benchmark.py
+++ b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
@@ -79,10 +79,9 @@
     NegativeCosineSimilarity,
     NTXentLoss,
     SwaVLoss,
-    memory_bank,
 )
 from lightly.models import ResNetGenerator, modules, utils
-from lightly.models.modules import heads
+from lightly.models.modules import heads, memory_bank
 from lightly.transforms import (
     BYOLTransform,
     BYOLView1Transform,
@@ -800,7 +799,7 @@ def __init__(self, dataloader_kNN, num_classes):
         # smog
         self.n_groups = 300
         memory_bank_size = 10000
-        self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+        self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
         # create our loss
         group_features = torch.nn.functional.normalize(
             torch.rand(self.n_groups, 128), dim=1
diff --git a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
index 3dcc1beb2..810b9e0e3 100644
--- a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
@@ -225,7 +225,9 @@ def __init__(self, dataloader_kNN, num_classes):
         utils.deactivate_requires_grad(self.projection_head_momentum)
 
         # create our loss with the optional memory bank
-        self.criterion = NTXentLoss(temperature=0.07, memory_bank_size=memory_bank_size)
+        self.criterion = NTXentLoss(
+            temperature=0.07, memory_bank_size=(memory_bank_size, 128)
+        )
 
     def forward(self, x):
         x = self.backbone(x).flatten(start_dim=1)
@@ -505,7 +507,7 @@ def __init__(self, dataloader_kNN, num_classes):
         self.projection_head = heads.NNCLRProjectionHead(feature_dim, 4096, 256)
         self.criterion = NTXentLoss()
 
-        self.memory_bank = modules.NNMemoryBankModule(size=memory_bank_size)
+        self.memory_bank = modules.NNMemoryBankModule(size=(memory_bank_size, 256))
 
     def forward(self, x):
         y = self.backbone(x).flatten(start_dim=1)
diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
index b5c5e8e71..ee9cc0171 100644
--- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
@@ -85,10 +85,9 @@
     TiCoLoss,
     VICRegLLoss,
     VICRegLoss,
-    memory_bank,
 )
 from lightly.models import modules, utils
-from lightly.models.modules import heads, masked_autoencoder
+from lightly.models.modules import heads, masked_autoencoder, memory_bank
 from lightly.transforms import (
     BYOLTransform,
     BYOLView1Transform,
@@ -330,7 +329,9 @@ def __init__(self, dataloader_kNN, num_classes):
         utils.deactivate_requires_grad(self.projection_head_momentum)
 
         # create our loss with the optional memory bank
-        self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+        self.criterion = NTXentLoss(
+            temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+        )
 
     def forward(self, x):
         x = self.backbone(x).flatten(start_dim=1)
@@ -1015,7 +1016,7 @@ def __init__(self, dataloader_kNN, num_classes):
         # smog
         self.n_groups = 300
         memory_bank_size = 10000
-        self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+        self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
         # create our loss
         group_features = torch.nn.functional.normalize(
             torch.rand(self.n_groups, 128), dim=1
@@ -1320,7 +1321,7 @@ def __init__(self, dataloader_kNN, num_classes):
         self.prototypes = heads.SwaVPrototypes(128, 3000, 1)
         self.start_queue_at_epoch = 15
         self.queues = nn.ModuleList(
-            [memory_bank.MemoryBankModule(size=384) for _ in range(2)]
+            [memory_bank.MemoryBankModule(size=(384, 128)) for _ in range(2)]
         )  # Queue size reduced in order to work with a smaller dataset
         self.criterion = SwaVLoss()
diff --git a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
index 7e572aa1b..fc68889ad 100644
--- a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
+++ b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py
@@ -268,7 +268,7 @@ def __init__(self):
         deactivate_requires_grad(self.projection_head_momentum)
 
         # Create the loss function with memory bank.
-        self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=4096)
+        self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))
 
     def training_step(self, batch, batch_idx):
         (x_q, x_k), _, _ = batch
diff --git a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
index f97992607..399b41956 100644
--- a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
+++ b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py
@@ -237,7 +237,9 @@ def __init__(self):
         deactivate_requires_grad(self.projection_head_momentum)
 
         # create our loss with the optional memory bank
-        self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+        self.criterion = NTXentLoss(
+            temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+        )
 
     def training_step(self, batch, batch_idx):
         (x_q, x_k), _, _ = batch
diff --git a/examples/pytorch/moco.py b/examples/pytorch/moco.py
index 961337d8f..613dd4a20 100644
--- a/examples/pytorch/moco.py
+++ b/examples/pytorch/moco.py
@@ -62,7 +62,7 @@ def forward_momentum(self, x):
     num_workers=8,
 )
 
-criterion = NTXentLoss(memory_bank_size=4096)
+criterion = NTXentLoss(memory_bank_size=(4096, 128))
 optimizer = torch.optim.SGD(model.parameters(), lr=0.06)
 
 epochs = 10
diff --git a/examples/pytorch/nnclr.py b/examples/pytorch/nnclr.py
index c3745235a..530e18153 100644
--- a/examples/pytorch/nnclr.py
+++ b/examples/pytorch/nnclr.py
@@ -38,7 +38,7 @@ def forward(self, x):
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-memory_bank = NNMemoryBankModule(size=4096)
+memory_bank = NNMemoryBankModule(size=(4096, 128))
 memory_bank.to(device)
 
 transform = SimCLRTransform(input_size=32)
diff --git a/examples/pytorch/smog.py b/examples/pytorch/smog.py
index a4f81c8ed..0e59758f7 100644
--- a/examples/pytorch/smog.py
+++ b/examples/pytorch/smog.py
@@ -9,13 +9,13 @@
 from sklearn.cluster import KMeans
 from torch import nn
 
-from lightly.loss.memory_bank import MemoryBankModule
 from lightly.models import utils
 from lightly.models.modules.heads import (
     SMoGPredictionHead,
     SMoGProjectionHead,
     SMoGPrototypes,
 )
+from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.transforms.smog_transform import SMoGTransform
 
 
@@ -80,7 +80,7 @@ def forward_momentum(self, x):
 # memory bank because we reset the group features every 300 iterations
 memory_bank_size = 300 * batch_size
-memory_bank = MemoryBankModule(size=memory_bank_size)
+memory_bank = MemoryBankModule(size=(memory_bank_size, 128))
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
diff --git a/examples/pytorch/swav_queue.py b/examples/pytorch/swav_queue.py
index 232cf3b55..a3a148482 100644
--- a/examples/pytorch/swav_queue.py
+++ b/examples/pytorch/swav_queue.py
@@ -7,8 +7,8 @@
 from torch import nn
 
 from lightly.loss import SwaVLoss
-from lightly.loss.memory_bank import MemoryBankModule
 from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.transforms.swav_transform import SwaVTransform
 
 
@@ -20,7 +20,9 @@ def __init__(self, backbone):
         self.prototypes = SwaVPrototypes(128, 512, 1)
 
         self.start_queue_at_epoch = 2
-        self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+        self.queues = nn.ModuleList(
+            [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+        )
 
     def forward(self, high_resolution, low_resolution, epoch):
         self.prototypes.normalize()
diff --git a/examples/pytorch_lightning/moco.py b/examples/pytorch_lightning/moco.py
index 3e6e1d973..7aa32dbbf 100644
--- a/examples/pytorch_lightning/moco.py
+++ b/examples/pytorch_lightning/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
         deactivate_requires_grad(self.backbone_momentum)
         deactivate_requires_grad(self.projection_head_momentum)
 
-        self.criterion = NTXentLoss(memory_bank_size=4096)
+        self.criterion = NTXentLoss(memory_bank_size=(4096, 128))
 
     def forward(self, x):
         query = self.backbone(x).flatten(start_dim=1)
diff --git a/examples/pytorch_lightning/nnclr.py b/examples/pytorch_lightning/nnclr.py
index 2f859a111..0a3c462cc 100644
--- a/examples/pytorch_lightning/nnclr.py
+++ b/examples/pytorch_lightning/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
         self.backbone = nn.Sequential(*list(resnet.children())[:-1])
         self.projection_head = NNCLRProjectionHead(512, 512, 128)
         self.prediction_head = NNCLRPredictionHead(128, 512, 128)
-        self.memory_bank = NNMemoryBankModule(size=4096)
+        self.memory_bank = NNMemoryBankModule(size=(4096, 128))
 
         self.criterion = NTXentLoss()
diff --git a/examples/pytorch_lightning/swav_queue.py b/examples/pytorch_lightning/swav_queue.py
index 0a2d40ab4..2aad1c44a 100644
--- a/examples/pytorch_lightning/swav_queue.py
+++ b/examples/pytorch_lightning/swav_queue.py
@@ -8,8 +8,8 @@
 from torch import nn
 
 from lightly.loss import SwaVLoss
-from lightly.loss.memory_bank import MemoryBankModule
 from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
+from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.transforms.swav_transform import SwaVTransform
 
 
@@ -21,7 +21,9 @@ def __init__(self):
         self.projection_head = SwaVProjectionHead(512, 512, 128)
         self.prototypes = SwaVPrototypes(128, 512, 1)
         self.start_queue_at_epoch = 2
-        self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+        self.queues = nn.ModuleList(
+            [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+        )
         self.criterion = SwaVLoss()
 
     def training_step(self, batch, batch_idx):
diff --git a/examples/pytorch_lightning_distributed/moco.py b/examples/pytorch_lightning_distributed/moco.py
index 03b9698ac..6a59a096a 100644
--- a/examples/pytorch_lightning_distributed/moco.py
+++ b/examples/pytorch_lightning_distributed/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
         deactivate_requires_grad(self.backbone_momentum)
         deactivate_requires_grad(self.projection_head_momentum)
 
-        self.criterion = NTXentLoss(memory_bank_size=4096)
+        self.criterion = NTXentLoss(memory_bank_size=(4096, 128))
 
     def forward(self, x):
         query = self.backbone(x).flatten(start_dim=1)
diff --git a/examples/pytorch_lightning_distributed/nnclr.py b/examples/pytorch_lightning_distributed/nnclr.py
index 9c02e2834..d925265bd 100644
--- a/examples/pytorch_lightning_distributed/nnclr.py
+++ b/examples/pytorch_lightning_distributed/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
         self.backbone = nn.Sequential(*list(resnet.children())[:-1])
         self.projection_head = NNCLRProjectionHead(512, 512, 128)
         self.prediction_head = NNCLRPredictionHead(128, 512, 128)
-        self.memory_bank = NNMemoryBankModule(size=4096)
+        self.memory_bank = NNMemoryBankModule(size=(4096, 128))
 
         self.criterion = NTXentLoss()
diff --git a/lightly/loss/memory_bank.py b/lightly/loss/memory_bank.py
index 998917715..3c01a7082 100644
--- a/lightly/loss/memory_bank.py
+++ b/lightly/loss/memory_bank.py
@@ -1,131 +1,2 @@
-""" Memory Bank Wrapper """
-
-# Copyright (c) 2020. Lightly AG and its affiliates.
-# All Rights Reserved
-
-from typing import Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-
-
-class MemoryBankModule(torch.nn.Module):
-    """Memory bank implementation
-
-    This is a parent class to all loss functions implemented by the lightly
-    Python package. This way, any loss can be used with a memory bank if
-    desired.
-
-    Attributes:
-        size:
-            Number of keys the memory bank can store. If set to 0,
-            memory bank is not used.
-
-    Examples:
-        >>> class MyLossFunction(MemoryBankModule):
-        >>>
-        >>>     def __init__(self, memory_bank_size: int = 2 ** 16):
-        >>>         super(MyLossFunction, self).__init__(memory_bank_size)
-        >>>
-        >>>     def forward(self, output: Tensor,
-        >>>                 labels: Tensor = None):
-        >>>
-        >>>         output, negatives = super(
-        >>>             MyLossFunction, self).forward(output)
-        >>>
-        >>>         if negatives is not None:
-        >>>             # evaluate loss with negative samples
-        >>>         else:
-        >>>             # evaluate loss without negative samples
-
-    """
-
-    def __init__(self, size: int = 2**16):
-        super(MemoryBankModule, self).__init__()
-
-        if size < 0:
-            msg = f"Illegal memory bank size {size}, must be non-negative."
-            raise ValueError(msg)
-
-        self.size = size
-        self.register_buffer(
-            "bank", tensor=torch.empty(0, dtype=torch.float), persistent=False
-        )
-        self.register_buffer(
-            "bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False
-        )
-
-    @torch.no_grad()
-    def _init_memory_bank(self, dim: int) -> None:
-        """Initialize the memory bank if it's empty
-
-        Args:
-            dim:
-                The dimension of the which are stored in the bank.
-
-        """
-        # create memory bank
-        # we could use register buffers like in the moco repo
-        # https://github.com/facebookresearch/moco but we don't
-        # want to pollute our checkpoints
-        bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
-        self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
-        self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)
-
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, batch: Tensor) -> None:
-        """Dequeue the oldest batch and add the latest one
-
-        Args:
-            batch:
-                The latest batch of keys to add to the memory bank.
-
-        """
-        batch_size = batch.shape[0]
-        ptr = int(self.bank_ptr)
-
-        if ptr + batch_size >= self.size:
-            self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
-            self.bank_ptr[0] = 0
-        else:
-            self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
-            self.bank_ptr[0] = ptr + batch_size
-
-    def forward(
-        self,
-        output: Tensor,
-        labels: Optional[Tensor] = None,
-        update: bool = False,
-    ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]:
-        """Query memory bank for additional negative samples
-
-        Args:
-            output:
-                The output of the model.
-            labels:
-                Should always be None, will be ignored.
-
-        Returns:
-            The output if the memory bank is of size 0, otherwise the output
-            and the entries from the memory bank.
-
-        """
-
-        # no memory bank, return the output
-        if self.size == 0:
-            return output, None
-
-        _, dim = output.shape
-
-        # initialize the memory bank if it is not already done
-        if self.bank.nelement() == 0:
-            self._init_memory_bank(dim)
-
-        # query and update memory bank
-        bank = self.bank.clone().detach()
-
-        # only update memory bank if we later do backward pass (gradient)
-        if update:
-            self._dequeue_and_enqueue(output)
-
-        return output, bank
+# For backwards compatibility as memory_bank module was previously in loss module.
+from lightly.models.modules.memory_bank import MemoryBankModule
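Editorial note: because the old module now only re-exports the class from its new location, both import paths resolve to the same object and existing user code keeps working:

```python
from lightly.loss.memory_bank import MemoryBankModule as OldPath
from lightly.models.modules.memory_bank import MemoryBankModule as NewPath

assert OldPath is NewPath  # the deprecated path is a plain re-export
```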
diff --git a/lightly/loss/ntx_ent_loss.py b/lightly/loss/ntx_ent_loss.py
index b94211ec7..7f0c1d429 100644
--- a/lightly/loss/ntx_ent_loss.py
+++ b/lightly/loss/ntx_ent_loss.py
@@ -3,11 +3,13 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved
 
+from typing import Sequence, Union
+
 import torch
 from torch import distributed as torch_dist
 from torch import nn
 
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.utils import dist
 
@@ -25,11 +27,18 @@ class NTXentLoss(MemoryBankModule):
         temperature:
             Scale logits by the inverse of the temperature.
         memory_bank_size:
-            Number of negative samples to store in the memory bank.
-            Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
+            Size of the memory bank as (num_features, dim) tuple. num_features is the
+            number of negative samples stored in the memory bank. If num_features is 0,
+            the memory bank is disabled. Use 0 for SimCLR. For MoCo we typically use
+            numbers like 4096 or 65536.
+            Deprecated: If only a single integer is passed, it is interpreted as the
+            number of features and the feature dimension is inferred from the first
+            batch stored in the memory bank. Leaving out the feature dimension might
+            lead to errors in distributed training.
         gather_distributed:
             If True then negatives from all gpus are gathered before the
-            loss calculation. This flag has no effect if memory_bank_size > 0.
+            loss calculation. If a memory bank is used and gather_distributed is True,
+            then tensors from all gpus are gathered before the memory bank is updated.
 
     Raises:
         ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
@@ -55,10 +64,10 @@ class NTXentLoss(MemoryBankModule):
     def __init__(
         self,
         temperature: float = 0.5,
-        memory_bank_size: int = 0,
+        memory_bank_size: Union[int, Sequence[int]] = 0,
         gather_distributed: bool = False,
     ):
-        super(NTXentLoss, self).__init__(size=memory_bank_size)
+        super().__init__(size=memory_bank_size, gather_distributed=gather_distributed)
         self.temperature = temperature
         self.gather_distributed = gather_distributed
         self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
diff --git a/lightly/loss/regularizer/co2.py b/lightly/loss/regularizer/co2.py
index 2145cf04a..e61a7a302 100644
--- a/lightly/loss/regularizer/co2.py
+++ b/lightly/loss/regularizer/co2.py
@@ -3,9 +3,11 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved
 
+from typing import Sequence, Union
+
 import torch
 
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
 
 
 class CO2Regularizer(MemoryBankModule):
@@ -19,15 +21,19 @@ class CO2Regularizer(MemoryBankModule):
         t_consistency:
             Temperature used during softmax calculations.
         memory_bank_size:
-            Number of negative samples to store in the memory bank.
-            Use 0 to use the second batch for negative samples.
+            Size of the memory bank as (num_features, dim) tuple. num_features is the
+            number of negatives stored in the bank. If set to 0, the memory bank is
+            disabled. Deprecated: If only a single integer is passed, it is interpreted
+            as the number of features and the feature dimension is inferred from the
+            first batch stored in the memory bank. Leaving out the feature dimension
+            might lead to errors in distributed training.
 
     Examples:
         >>> # initialize loss function for MoCo
-        >>> loss_fn = NTXentLoss(memory_bank_size=4096)
+        >>> loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
         >>>
         >>> # initialize CO2 regularizer
-        >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=4096)
+        >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=(4096, 128))
         >>>
         >>> # generate two random trasnforms of images
         >>> t0 = transforms(images)
@@ -42,7 +48,10 @@ class CO2Regularizer(MemoryBankModule):
 
     """
 
     def __init__(
-        self, alpha: float = 1, t_consistency: float = 0.05, memory_bank_size: int = 0
+        self,
+        alpha: float = 1,
+        t_consistency: float = 0.05,
+        memory_bank_size: Union[int, Sequence[int]] = 0,
     ):
         super(CO2Regularizer, self).__init__(size=memory_bank_size)
         # try-catch the KLDivLoss construction for backwards compatability
diff --git a/lightly/models/modules/memory_bank.py b/lightly/models/modules/memory_bank.py
new file mode 100644
index 000000000..db7139b8c
--- /dev/null
+++ b/lightly/models/modules/memory_bank.py
@@ -0,0 +1,177 @@
+""" Memory Bank Wrapper """
+
+# Copyright (c) 2020. Lightly AG and its affiliates.
+# All Rights Reserved
+
+import warnings
+from typing import Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from lightly.models import utils
+
+
+class MemoryBankModule(Module):
+    """Memory bank implementation
+
+    This is a parent class to all loss functions implemented by the lightly
+    Python package. This way, any loss can be used with a memory bank if
+    desired.
+
+    Attributes:
+        size:
+            Size of the memory bank as (num_features, dim) tuple. If num_features is 0
+            then the memory bank is disabled. Deprecated: If only a single integer is
+            passed, it is interpreted as the number of features and the feature
+            dimension is inferred from the first batch stored in the memory bank.
+            Leaving out the feature dimension might lead to errors in distributed
+            training.
+        gather_distributed:
+            If True then negatives from all gpus are gathered before the memory bank
+            is updated. This results in more frequent updates of the memory bank and
+            keeps the memory bank contents independent of the number of gpus. But it
+            has the drawback that synchronization between processes is required and
+            diversity of the memory bank content is reduced.
+        feature_dim_first:
+            If True, the memory bank returns features with shape (dim, num_features).
+            If False, the memory bank returns features with shape (num_features, dim).
+
+    Examples:
+        >>> class MyLossFunction(MemoryBankModule):
+        >>>
+        >>>     def __init__(self, memory_bank_size: Tuple[int, int] = (2 ** 16, 128)):
+        >>>         super().__init__(memory_bank_size)
+        >>>
+        >>>     def forward(self, output: Tensor, labels: Union[Tensor, None] = None):
+        >>>         output, negatives = super().forward(output)
+        >>>
+        >>>         if negatives is not None:
+        >>>             # evaluate loss with negative samples
+        >>>         else:
+        >>>             # evaluate loss without negative samples
+
+    """
+
+    def __init__(
+        self,
+        size: Union[int, Sequence[int]] = 65536,
+        gather_distributed: bool = False,
+        feature_dim_first: bool = True,
+    ):
+        super().__init__()
+        size_tuple = (size,) if isinstance(size, int) else tuple(size)
+
+        if any(x < 0 for x in size_tuple):
+            raise ValueError(
+                f"Illegal memory bank size {size}, all entries must be non-negative."
+            )
+
+        self.size = size_tuple
+        self.gather_distributed = gather_distributed
+        self.feature_dim_first = feature_dim_first
+        self.bank: Tensor
+        self.register_buffer(
+            "bank",
+            tensor=torch.empty(size=size_tuple, dtype=torch.float),
+            persistent=False,
+        )
+        self.bank_ptr: Tensor
+        self.register_buffer(
+            "bank_ptr",
+            tensor=torch.empty(1, dtype=torch.long),
+            persistent=False,
+        )
+
+        if isinstance(size, int) and size > 0:
+            warnings.warn(
+                (
+                    f"Memory bank size 'size={size}' does not specify feature "
+                    "dimension. It is recommended to set the feature dimension with "
+                    "'size=(n, dim)' when creating the memory bank. Distributed "
+                    "training might fail if the feature dimension is not set."
+                ),
+                UserWarning,
+            )
+        elif len(size_tuple) > 1:
+            self._init_memory_bank(size=size_tuple)
+
+    @torch.no_grad()
+    def _init_memory_bank(self, size: Tuple[int, ...]) -> None:
+        """Initialize the memory bank.
+
+        Args:
+            size:
+                Size of the memory bank as (num_features, dim) tuple.
+
+        """
+        self.bank = torch.randn(size).type_as(self.bank)
+        self.bank = torch.nn.functional.normalize(self.bank, dim=-1)
+        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, batch: Tensor) -> None:
+        """Dequeue the oldest batch and add the latest one
+
+        Args:
+            batch:
+                The latest batch of keys to add to the memory bank.
+
+        """
+        if self.gather_distributed:
+            batch = utils.concat_all_gather(batch)
+
+        batch_size = batch.shape[0]
+        ptr = int(self.bank_ptr)
+        if ptr + batch_size >= self.size[0]:
+            self.bank[ptr:] = batch[: self.size[0] - ptr].detach()
+            self.bank_ptr.zero_()
+        else:
+            self.bank[ptr : ptr + batch_size] = batch.detach()
+            self.bank_ptr[0] = ptr + batch_size
+
+    def forward(
+        self,
+        output: Tensor,
+        labels: Union[Tensor, None] = None,
+        update: bool = False,
+    ) -> Tuple[Tensor, Union[Tensor, None]]:
+        """Query memory bank for additional negative samples
+
+        Args:
+            output:
+                The output of the model.
+            labels:
+                Should always be None, will be ignored.
+            update:
+                If True, the memory bank will be updated with the current output.
+
+        Returns:
+            The output and None if the memory bank is disabled, otherwise the output
+            and the entries from the memory bank. Entries from the memory bank have
+            shape (dim, num_features) if feature_dim_first is True and
+            (num_features, dim) otherwise.
+
+        """
+
+        # no memory bank, return the output
+        if self.size[0] == 0:
+            return output, None
+
+        # Initialize the memory bank if it is not already done.
+        if self.bank.ndim == 1:
+            dim = output.shape[1:]
+            self._init_memory_bank(size=(*self.size, *dim))
+
+        # query and update memory bank
+        bank = self.bank.clone().detach()
+        if self.feature_dim_first:
+            # swap bank size and feature dimension for backwards compatibility
+            bank = bank.transpose(0, -1)
+
+        # only update memory bank if we later do backward pass (gradient)
+        if update:
+            self._dequeue_and_enqueue(output)
+
+        return output, bank
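Editorial note: the feature_dim_first flag only changes what forward returns; the bank buffer itself is always stored as (num_features, dim). A small sketch of the resulting shapes:

```python
import torch
from lightly.models.modules.memory_bank import MemoryBankModule

bank_module = MemoryBankModule(size=(4096, 128))  # 4096 negatives of dim 128
out, negatives = bank_module(torch.randn(32, 128), update=True)
print(out.shape)               # torch.Size([32, 128]): input passes through
print(negatives.shape)         # torch.Size([128, 4096]): dim first by default
print(bank_module.bank.shape)  # torch.Size([4096, 128]): storage layout
```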
diff --git a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py
index 59abb65a5..8e5f1cd31 100644
--- a/lightly/models/modules/nn_memory_bank.py
+++ b/lightly/models/modules/nn_memory_bank.py
@@ -3,12 +3,12 @@
 # Copyright (c) 2021. Lightly AG and its affiliates.
 # All Rights Reserved
 
-from typing import Optional
+from typing import Sequence, Union
 
 import torch
 from torch import Tensor
 
-from lightly.loss.memory_bank import MemoryBankModule
+from lightly.models.modules.memory_bank import MemoryBankModule
 
 
 class NNMemoryBankModule(MemoryBankModule):
@@ -22,13 +22,18 @@ class NNMemoryBankModule(MemoryBankModule):
 
     Attributes:
         size:
-            Number of keys the memory bank can store.
+            Size of the memory bank as (num_features, dim) tuple. If num_features is 0
+            then the memory bank is disabled. Deprecated: If only a single integer is
+            passed, it is interpreted as the number of features and the feature
+            dimension is inferred from the first batch stored in the memory bank.
+            Leaving out the feature dimension might lead to errors in distributed
+            training.
 
     Examples:
         >>> model = NNCLR(backbone)
         >>> criterion = NTXentLoss(temperature=0.1)
         >>>
-        >>> nn_replacer = NNmemoryBankModule(size=2 ** 16)
+        >>> nn_replacer = NNMemoryBankModule(size=(2 ** 16, 128))
         >>>
         >>> # forward pass
         >>> (z0, p0), (z1, p1) = model(x0, x1)
@@ -39,9 +44,7 @@ class NNMemoryBankModule(MemoryBankModule):
 
     """
 
-    def __init__(self, size: int = 2**16):
-        if size <= 0:
-            raise ValueError(f"Memory bank size must be positive, got {size}.")
+    def __init__(self, size: Union[int, Sequence[int]] = 2**16):
         super(NNMemoryBankModule, self).__init__(size)
 
     def forward(  # type: ignore[override]  # TODO(Philipp, 11/23): Fix signature to match parent class.
diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index bd7c4e26d..ee96c067c 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -95,7 +95,7 @@ def concat_all_gather(x: torch.Tensor) -> torch.Tensor:
 
 @torch.no_grad()
 def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Shuffles batch over multiple gpus.
+    """Shuffles batch over multiple devices.
 
     This code was taken and adapted from here:
     https://github.com/facebookresearch/moco.
@@ -109,26 +109,25 @@ def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.
         input batch and shuffle is an index to restore the original order.
 
     """
-    # gather from all gpus
+    # gather from all devices
     batch_size_this = batch.shape[0]
    batch_gather = concat_all_gather(batch)
     batch_size_all = batch_gather.shape[0]
 
-    num_gpus = batch_size_all // batch_size_this
+    num_devices = batch_size_all // batch_size_this
 
     # random shuffle index
-    idx_shuffle = torch.randperm(batch_size_all).cuda()
+    idx_shuffle = torch.randperm(batch_size_all, device=batch.device)
 
-    # broadcast to all gpus
+    # broadcast to all devices
     dist.broadcast(idx_shuffle, src=0)
 
     # index for restoring
     shuffle = torch.argsort(idx_shuffle)
 
-    # shuffled index for this gpu
-    gpu_idx = dist.get_rank()
-    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
-
+    # shuffled index for this device
+    rank = dist.get_rank()
+    idx_this = idx_shuffle.view(num_devices, -1)[rank]
     return batch_gather[idx_this], shuffle
 
 
@@ -136,7 +135,7 @@ def batch_shuffle_distributed(batch: torch.Tensor) -> Tuple[torch.Tensor, torch.
 def batch_unshuffle_distributed(
     batch: torch.Tensor, shuffle: torch.Tensor
 ) -> torch.Tensor:
-    """Undo batch shuffle over multiple gpus.
+    """Undo batch shuffle over multiple devices.
 
     This code was taken and adapted from here:
     https://github.com/facebookresearch/moco.
@@ -151,17 +150,16 @@ def batch_unshuffle_distributed(
         The unshuffled tensor.
 
     """
-    # gather from all gpus
+    # gather from all devices
     batch_size_this = batch.shape[0]
     batch_gather = concat_all_gather(batch)
     batch_size_all = batch_gather.shape[0]
 
-    num_gpus = batch_size_all // batch_size_this
+    num_devices = batch_size_all // batch_size_this
 
     # restored index for this gpu
-    gpu_idx = dist.get_rank()
-    idx_this = shuffle.view(num_gpus, -1)[gpu_idx]
-
+    rank = dist.get_rank()
+    idx_this = shuffle.view(num_devices, -1)[rank]
     return batch_gather[idx_this]
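Editorial note: the single-process wrappers used by mocov2.py above rely on the same index bookkeeping: unshuffling with the returned index is the exact inverse of shuffling. A quick check with distributed=False:

```python
import torch
from lightly.models.utils import batch_shuffle, batch_unshuffle

x = torch.randn(8, 3)
shuffled, shuffle = batch_shuffle(batch=x, distributed=False)
restored = batch_unshuffle(batch=shuffled, shuffle=shuffle, distributed=False)
assert torch.equal(restored, x)  # round trip restores the original order
```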
""" - # gather from all gpus + # gather from all devices batch_size_this = batch.shape[0] batch_gather = concat_all_gather(batch) batch_size_all = batch_gather.shape[0] - num_gpus = batch_size_all // batch_size_this + num_devices = batch_size_all // batch_size_this # restored index for this gpu - gpu_idx = dist.get_rank() - idx_this = shuffle.view(num_gpus, -1)[gpu_idx] - + rank = dist.get_rank() + idx_this = shuffle.view(num_devices, -1)[rank] return batch_gather[idx_this] diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py index b635b2fbb..ffc9747c6 100644 --- a/lightly/utils/benchmarking/metric_callback.py +++ b/lightly/utils/benchmarking/metric_callback.py @@ -60,7 +60,7 @@ def _append_metrics( self, metrics_dict: Dict[str, List[float]], trainer: Trainer ) -> None: for name, value in trainer.callback_metrics.items(): - if isinstance(value, float) or ( - isinstance(value, Tensor) and value.numel() == 1 - ): - metrics_dict.setdefault(name, []).append(float(value)) + if isinstance(value, Tensor) and value.numel() != 1: + # Skip non-scalar tensors. + continue + metrics_dict.setdefault(name, []).append(float(value)) diff --git a/tests/loss/test_CO2Regularizer.py b/tests/loss/test_CO2Regularizer.py index 7b56efcb1..28918eedd 100644 --- a/tests/loss/test_CO2Regularizer.py +++ b/tests/loss/test_CO2Regularizer.py @@ -18,7 +18,7 @@ def test_forward_pass_no_memory_bank(self): self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_memory_bank(self): - reg = CO2Regularizer(memory_bank_size=4096) + reg = CO2Regularizer(memory_bank_size=(4096, 32)) for bsz in range(1, 20): batch_1 = torch.randn((bsz, 32)) batch_2 = torch.randn((bsz, 32)) @@ -44,7 +44,7 @@ def test_forward_pass_cuda_memory_bank(self): if not torch.cuda.is_available(): return - reg = CO2Regularizer(memory_bank_size=4096) + reg = CO2Regularizer(memory_bank_size=(4096, 32)) for bsz in range(1, 20): batch_1 = torch.randn((bsz, 32)).cuda() batch_2 = torch.randn((bsz, 32)).cuda() diff --git a/tests/loss/test_MemoryBank.py b/tests/loss/test_MemoryBank.py deleted file mode 100644 index dc1df7ca0..000000000 --- a/tests/loss/test_MemoryBank.py +++ /dev/null @@ -1,63 +0,0 @@ -import unittest - -import torch - -from lightly.loss.memory_bank import MemoryBankModule - - -class TestNTXentLoss(unittest.TestCase): - def test_init__negative_size(self): - with self.assertRaises(ValueError): - MemoryBankModule(size=-1) - - def test_forward_easy(self): - bsz = 3 - dim, size = 2, 9 - n = 33 * bsz - memory_bank = MemoryBankModule(size=size) - - ptr = 0 - for i in range(0, n, bsz): - output = torch.randn(2 * bsz, dim) - output.requires_grad = True - out0, out1 = output[:bsz], output[bsz:] - - _, curr_memory_bank = memory_bank(out1, update=True) - next_memory_bank = memory_bank.bank - - curr_diff = out0.T - curr_memory_bank[:, ptr : ptr + bsz] - next_diff = out1.T - next_memory_bank[:, ptr : ptr + bsz] - - # the current memory bank should not hold the batch yet - self.assertGreater(curr_diff.norm(), 1e-5) - # the "next" memory bank should hold the batch - self.assertGreater(1e-5, next_diff.norm()) - - ptr = (ptr + bsz) % size - - def test_forward(self): - bsz = 3 - dim, size = 2, 10 - n = 33 * bsz - memory_bank = MemoryBankModule(size=size) - - for i in range(0, n, bsz): - # see if there are any problems when the bank size - # is no multiple of the batch size - output = torch.randn(bsz, dim) - _, _ = memory_bank(output) - - @unittest.skipUnless(torch.cuda.is_available(), "cuda not 
available") - def test_forward__cuda(self): - bsz = 3 - dim, size = 2, 10 - n = 33 * bsz - memory_bank = MemoryBankModule(size=size) - device = torch.device("cuda") - memory_bank.to(device=device) - - for i in range(0, n, bsz): - # see if there are any problems when the bank size - # is no multiple of the batch size - output = torch.randn(bsz, dim, device=device) - _, _ = memory_bank(output) diff --git a/tests/models/modules/test_memory_bank.py b/tests/models/modules/test_memory_bank.py new file mode 100644 index 000000000..0f6ccae53 --- /dev/null +++ b/tests/models/modules/test_memory_bank.py @@ -0,0 +1,161 @@ +import re +import unittest + +import pytest +import torch + +from lightly.models.modules.memory_bank import MemoryBankModule + + +class TestNTXentLoss(unittest.TestCase): + def test_init__negative_size(self) -> None: + with self.assertRaises(ValueError): + MemoryBankModule(size=-1) + + def test_forward_easy(self) -> None: + bsz = 3 + dim, size = 2, 9 + n = 33 * bsz + memory_bank = MemoryBankModule(size=size) + + ptr = 0 + for i in range(0, n, bsz): + output = torch.randn(2 * bsz, dim) + output.requires_grad = True + out0, out1 = output[:bsz], output[bsz:] + + _, curr_memory_bank = memory_bank(out1, update=True) + next_memory_bank = memory_bank.bank.transpose(0, -1) + + curr_diff = out0.T - curr_memory_bank[:, ptr : ptr + bsz] + next_diff = out1.T - next_memory_bank[:, ptr : ptr + bsz] + + # the current memory bank should not hold the batch yet + self.assertGreater(curr_diff.norm(), 1e-5) + # the "next" memory bank should hold the batch + self.assertGreater(1e-5, next_diff.norm()) + + ptr = (ptr + bsz) % size + + def test_forward(self) -> None: + bsz = 3 + dim, size = 2, 10 + n = 33 * bsz + memory_bank = MemoryBankModule(size=size) + + for i in range(0, n, bsz): + # see if there are any problems when the bank size + # is no multiple of the batch size + output = torch.randn(bsz, dim) + _, _ = memory_bank(output) + + @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + def test_forward__cuda(self) -> None: + bsz = 3 + dim, size = 2, 10 + n = 33 * bsz + memory_bank = MemoryBankModule(size=size) + device = torch.device("cuda") + memory_bank.to(device=device) + + for i in range(0, n, bsz): + # see if there are any problems when the bank size + # is no multiple of the batch size + output = torch.randn(bsz, dim, device=device) + _, _ = memory_bank(output) + + +class TestMemoryBank: + def test_init__negative_size(self) -> None: + with pytest.raises( + ValueError, + match="Illegal memory bank size -1, all entries must be non-negative.", + ): + MemoryBankModule(size=-1) + + with pytest.raises( + ValueError, + match=re.escape( + "Illegal memory bank size (10, -1), all entries must be non-negative." + ), + ): + MemoryBankModule(size=(10, -1)) + + def test_init__no_dim_warning(self) -> None: + with pytest.warns( + UserWarning, + match=re.escape( + "Memory bank size 'size=10' does not specify feature " + "dimension. It is recommended to set the feature dimension with " + "'size=(n, dim)' when creating the memory bank. Distributed " + "training might fail if the feature dimension is not set." + ), + ): + MemoryBankModule(size=10) + + def test_forward(self) -> None: + torch.manual_seed(0) + memory_bank = MemoryBankModule(size=(5, 2), feature_dim_first=False) + x0 = torch.randn(3, 2) + out0, bank0 = memory_bank(x0, update=True) + # Verify that output is same as input. + assert out0.tolist() == x0.tolist() + # Verify that memory bank was initialized and has correct shape. 
+ assert bank0.shape == (5, 2) + assert memory_bank.bank.shape == (5, 2) + # Verify that output bank does not contain features from x0. + assert bank0[:3].tolist() != x0.tolist() + # Verify that memory bank was updated. + assert memory_bank.bank[:3].tolist() == x0.tolist() + + x1 = torch.randn(3, 2) + out1, bank1 = memory_bank(x1, update=True) + # Verify that output is same as input. + assert out1.tolist() == x1.tolist() + # Verify that output bank contains features from x0. + assert bank1[:3].tolist() == x0.tolist() + # Verify that output bank does not contain features from x1. + assert bank1[3:].tolist() != x1[:2].tolist() + # Verify that memory bank was updated. + assert memory_bank.bank[:3].tolist() == x0.tolist() + assert memory_bank.bank[3:].tolist() == x1[:2].tolist() + + # At this point the memory bank is full. + # Adding more features will start overwriting the bank from the beginning. + + x2 = torch.randn(3, 2) + out2, bank2 = memory_bank(x2, update=True) + # Verify that output is same as input. + assert out2.tolist() == x2.tolist() + # Verify that output bank contains features from x0 and x1. + assert bank2[:3].tolist() == x0.tolist() + assert bank2[3:].tolist() == x1[:2].tolist() + # Verify that memory bank is overwritten. + assert memory_bank.bank[:3].tolist() == x2.tolist() + + def test_forward__no_dim(self) -> None: + torch.manual_seed(0) + # Only specify size but not feature dimension. + memory_bank = MemoryBankModule(size=5, feature_dim_first=False) + x0 = torch.randn(3, 2) + out0, bank0 = memory_bank(x0, update=True) + # Verify that output is same as input. + assert out0.tolist() == x0.tolist() + # Verify that memory bank was initialized and has correct shape. + assert bank0.shape == (5, 2) + assert memory_bank.bank.shape == (5, 2) + # Verify that output bank does not contain features from x0. + assert bank0[:3].tolist() != x0.tolist() + # Verify that memory bank was updated. + assert memory_bank.bank[:3].tolist() == x0.tolist() + + def test_forward__dim_first(self) -> None: + torch.manual_seed(0) + memory_bank = MemoryBankModule(size=(5, 2), feature_dim_first=True) + x0 = torch.randn(3, 2) + out0, bank0 = memory_bank(x0, update=True) + assert bank0.shape == (2, 5) + x1 = torch.randn(3, 2) + out1, bank1 = memory_bank(x1, update=True) + assert bank1.shape == (2, 5) + assert bank1[:, :3].tolist() == x0.T.tolist() diff --git a/tests/models/test_ModelsNNCLR.py b/tests/models/test_ModelsNNCLR.py index 3760faef3..ac8eb05e8 100644 --- a/tests/models/test_ModelsNNCLR.py +++ b/tests/models/test_ModelsNNCLR.py @@ -115,7 +115,7 @@ def test_memory_bank(self): model = NNCLR(get_backbone(resnet), **config).to(device) for nn_size in [2**3, 2**8]: - nn_replacer = NNMemoryBankModule(size=nn_size) + nn_replacer = NNMemoryBankModule(size=(nn_size, config["out_dim"])) with torch.no_grad(): for i in range(10):