diff --git a/lightly/utils/benchmarking/knn_classifier.py b/lightly/utils/benchmarking/knn_classifier.py
index be2e6ada5..466c08647 100644
--- a/lightly/utils/benchmarking/knn_classifier.py
+++ b/lightly/utils/benchmarking/knn_classifier.py
@@ -151,6 +151,12 @@ def on_train_epoch_start(self) -> None:
         # Set model to eval mode to disable norm layer updates.
         self.model.eval()
 
+        # Reset features and targets.
+        self._train_features = []
+        self._train_targets = []
+        self._train_features_tensor = None
+        self._train_targets_tensor = None
+
     def configure_optimizers(self) -> None:
         # configure_optimizers must be implemented for PyTorch Lightning. Returning None
         # means that no optimization is performed.
diff --git a/tests/utils/benchmarking/test_knn_classifier.py b/tests/utils/benchmarking/test_knn_classifier.py
index 1b776ce7c..119a3ae97 100644
--- a/tests/utils/benchmarking/test_knn_classifier.py
+++ b/tests/utils/benchmarking/test_knn_classifier.py
@@ -2,11 +2,13 @@
 
 import pytest
 import torch
+from pytest_mock import MockerFixture
 from pytorch_lightning import Trainer
 from torch import Tensor, nn
 from torch.utils.data import DataLoader, Dataset
 
 from lightly.utils.benchmarking import KNNClassifier
+from lightly.utils.benchmarking.knn_classifier import F
 
 
 class TestKNNClassifier:
@@ -131,6 +133,83 @@ def test__features_dtype(self) -> None:
         assert classifier._train_features_tensor is not None
         assert classifier._train_features_tensor.dtype == torch.int
 
+    def test__normalize(self, mocker: MockerFixture) -> None:
+        train_features = torch.randn(4, 3)
+        train_targets = torch.randint(0, 10, (4,))
+        train_dataset = _FeaturesDataset(features=train_features, targets=train_targets)
+        val_features = torch.randn(4, 3)
+        val_targets = torch.randint(0, 10, (4,))
+        val_dataset = _FeaturesDataset(features=val_features, targets=val_targets)
+        train_dataloader = DataLoader(train_dataset)
+        val_dataloader = DataLoader(val_dataset)
+        trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
+
+        spy_normalize = mocker.spy(F, "normalize")
+
+        # Test that normalize is called when normalize=True.
+        classifier = KNNClassifier(
+            nn.Identity(), num_classes=10, knn_k=3, normalize=True
+        )
+        trainer.fit(
+            model=classifier,
+            train_dataloaders=train_dataloader,
+        )
+        spy_normalize.assert_called()
+        spy_normalize.reset_mock()
+
+        trainer.validate(model=classifier, dataloaders=val_dataloader)
+        spy_normalize.assert_called()
+        spy_normalize.reset_mock()
+
+        # Test that normalize is not called when normalize=False.
+        classifier = KNNClassifier(
+            nn.Identity(), num_classes=10, knn_k=3, normalize=False
+        )
+        trainer.fit(
+            model=classifier,
+            train_dataloaders=train_dataloader,
+        )
+        spy_normalize.assert_not_called()
+        spy_normalize.reset_mock()
+
+        trainer.validate(model=classifier, dataloaders=val_dataloader)
+        spy_normalize.assert_not_called()
+        spy_normalize.reset_mock()
+
+    def test__reset_features_and_targets(self) -> None:
+        train_features = torch.randn(4, 3)
+        train_targets = torch.randint(0, 10, (4,))
+        train_dataset = _FeaturesDataset(features=train_features, targets=train_targets)
+        val_features = torch.randn(4, 3)
+        val_targets = torch.randint(0, 10, (4,))
+        val_dataset = _FeaturesDataset(features=val_features, targets=val_targets)
+        train_dataloader = DataLoader(train_dataset)
+        val_dataloader = DataLoader(val_dataset)
+        classifier = KNNClassifier(nn.Identity(), num_classes=10, knn_k=3)
+
+        trainer = Trainer(max_epochs=2, accelerator="cpu", devices=1)
+        trainer.fit(
+            model=classifier,
+            train_dataloaders=train_dataloader,
+            val_dataloaders=val_dataloader,
+        )
+
+        # Check that train features and targets are reset after validation.
+        assert classifier._train_features == []
+        assert classifier._train_targets == []
+        assert classifier._train_features_tensor is not None
+        assert classifier._train_targets_tensor is not None
+        # Check that train features and targets are not accumulated over multiple
+        # validation epochs.
+        assert classifier._train_features_tensor.shape == (3, 4)
+        assert classifier._train_targets_tensor.shape == (4,)
+
+        # Check that train features and targets are not accumulated over multiple
+        # training epochs.
+        trainer = Trainer(max_epochs=2, accelerator="cpu", devices=1)
+        trainer.fit(model=classifier, train_dataloaders=train_dataloader)
+        assert len(classifier._train_features) == 4
+        assert len(classifier._train_targets) == 4
+
 
 class _FeaturesDataset(Dataset):
     def __init__(self, features: Tensor, targets) -> None: