Skip to content

Commit

Permalink
Reset features after training epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Dec 15, 2023
1 parent 1d42012 commit 453f9bf
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lightly/utils/benchmarking/knn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def on_train_epoch_start(self) -> None:
# Set model to eval mode to disable norm layer updates.
self.model.eval()

# Reset features and targets.
self._train_features = []
self._train_targets = []
self._train_features_tensor = None
self._train_targets_tensor = None

def configure_optimizers(self) -> None:
# configure_optimizers must be implemented for PyTorch Lightning. Returning None
# means that no optimization is performed.
Expand Down
79 changes: 79 additions & 0 deletions tests/utils/benchmarking/test_knn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import pytest
import torch
from pytest_mock import MockerFixture
from pytorch_lightning import Trainer
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from lightly.utils.benchmarking import KNNClassifier
from lightly.utils.benchmarking.knn_classifier import F


class TestKNNClassifier:
Expand Down Expand Up @@ -131,6 +133,83 @@ def test__features_dtype(self) -> None:
assert classifier._train_features_tensor is not None
assert classifier._train_features_tensor.dtype == torch.int

def test__normalize(self, mocker: MockerFixture) -> None:
    """Check that ``F.normalize`` is used only when ``normalize=True``.

    ``F`` is the name imported from the ``knn_classifier`` module; its
    ``normalize`` attribute is wrapped in a spy. For each setting of the
    ``normalize`` flag a classifier is run through one training epoch and
    one validation pass, and the spy is inspected after each.

    Args:
        mocker: pytest-mock fixture used to create the spy.
    """
    train_features = torch.randn(4, 3)
    train_targets = torch.randint(0, 10, (4,))
    train_dataset = _FeaturesDataset(features=train_features, targets=train_targets)
    val_features = torch.randn(4, 3)
    val_targets = torch.randint(0, 10, (4,))
    val_dataset = _FeaturesDataset(features=val_features, targets=val_targets)
    train_dataloader = DataLoader(train_dataset)
    val_dataloader = DataLoader(val_dataset)
    trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
    spy_normalize = mocker.spy(F, "normalize")

    # Test that normalize is called when normalize=True.
    classifier = KNNClassifier(
        nn.Identity(), num_classes=10, knn_k=3, normalize=True
    )
    trainer.fit(
        model=classifier,
        train_dataloaders=train_dataloader,
    )
    spy_normalize.assert_called()
    spy_normalize.reset_mock()

    trainer.validate(model=classifier, dataloaders=val_dataloader)
    spy_normalize.assert_called()
    spy_normalize.reset_mock()

    # Test that normalize is not called when normalize=False.
    # A fresh Trainer is required here: the previous one has already reached
    # max_epochs, so fitting with it again would not run any training step
    # and the assertions below would pass without exercising the classifier.
    trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
    classifier = KNNClassifier(
        nn.Identity(), num_classes=10, knn_k=3, normalize=False
    )
    trainer.fit(
        model=classifier,
        train_dataloaders=train_dataloader,
    )
    spy_normalize.assert_not_called()
    spy_normalize.reset_mock()

    trainer.validate(model=classifier, dataloaders=val_dataloader)
    spy_normalize.assert_not_called()
    spy_normalize.reset_mock()

def test__reset_features_and_targets(self) -> None:
    """Check that stored train features/targets do not pile up over epochs.

    Runs two epochs with validation, then two epochs without, and verifies
    that the classifier only ever holds one epoch worth of training data.
    """
    num_samples, feature_dim = 4, 3
    trainer_kwargs = dict(max_epochs=2, accelerator="cpu", devices=1)

    train_loader = DataLoader(
        _FeaturesDataset(
            features=torch.randn(num_samples, feature_dim),
            targets=torch.randint(0, 10, (num_samples,)),
        )
    )
    val_loader = DataLoader(
        _FeaturesDataset(
            features=torch.randn(num_samples, feature_dim),
            targets=torch.randint(0, 10, (num_samples,)),
        )
    )
    knn = KNNClassifier(nn.Identity(), num_classes=10, knn_k=3)

    # Two epochs of training, each followed by a validation epoch.
    Trainer(**trainer_kwargs).fit(
        model=knn,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )

    # The per-batch lists are expected to be empty once validation has run,
    # while the tensor versions must be populated.
    assert knn._train_features == []
    assert knn._train_targets == []
    features_tensor = knn._train_features_tensor
    targets_tensor = knn._train_targets_tensor
    assert features_tensor is not None
    assert targets_tensor is not None
    # Shapes correspond to a single epoch, not to both validation epochs.
    assert features_tensor.shape == (feature_dim, num_samples)
    assert targets_tensor.shape == (num_samples,)

    # Without validation, two training epochs must still leave only the
    # batches of one epoch in the lists.
    Trainer(**trainer_kwargs).fit(model=knn, train_dataloaders=train_loader)
    assert len(knn._train_features) == num_samples
    assert len(knn._train_targets) == num_samples


class _FeaturesDataset(Dataset):
def __init__(self, features: Tensor, targets) -> None:
Expand Down

0 comments on commit 453f9bf

Please sign in to comment.