Make KNN feature normalization optional #1457

Merged · 3 commits · Dec 18, 2023
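This PR adds a `normalize` flag to `KNNClassifier` so that the L2 normalization of features before the KNN search can be switched off (the default `True` keeps the previous behavior). It also resets the cached training features at the start of every training epoch so they no longer accumulate across epochs. A minimal usage sketch of the new flag; the identity backbone and random data below are placeholders for illustration, not part of the PR:

import torch
from pytorch_lightning import Trainer
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightly.utils.benchmarking import KNNClassifier

# Placeholder backbone and random data, for illustration only.
backbone = nn.Identity()
train_dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 10, (8,)))
val_dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 10, (8,)))

# normalize=False skips the F.normalize step, e.g. when the backbone already
# outputs unit-norm embeddings.
classifier = KNNClassifier(backbone, num_classes=10, knn_k=3, normalize=False)
trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
trainer.fit(
    model=classifier,
    train_dataloaders=DataLoader(train_dataset),
    val_dataloaders=DataLoader(val_dataset),
)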
18 changes: 16 additions & 2 deletions lightly/utils/benchmarking/knn_classifier.py
@@ -19,6 +19,7 @@ def __init__(
        knn_t: float = 0.1,
        topk: Tuple[int, ...] = (1, 5),
        feature_dtype: torch.dtype = torch.float32,
        normalize: bool = True,
    ):
        """KNN classifier for benchmarking.

@@ -42,6 +43,8 @@ def __init__(
            feature_dtype:
                Torch data type of the features used for KNN search. Reduce to float16
                for memory-efficient KNN search.
            normalize:
                Whether to normalize the features for KNN search.

        Examples:
            >>> from pytorch_lightning import Trainer
@@ -90,6 +93,7 @@ def __init__(
        self.knn_t = knn_t
        self.topk = topk
        self.feature_dtype = feature_dtype
        self.normalize = normalize

        self._train_features = []
        self._train_targets = []
@@ -100,7 +104,9 @@ def __init__(
    def training_step(self, batch, batch_idx) -> None:
        images, targets = batch[0], batch[1]
        features = self.model.forward(images).flatten(start_dim=1)
-        features = F.normalize(features, dim=1).to(self.feature_dtype)
+        if self.normalize:
+            features = F.normalize(features, dim=1)
+        features = features.to(self.feature_dtype)
        self._train_features.append(features.cpu())
        self._train_targets.append(targets.cpu())

@@ -110,7 +116,9 @@ def validation_step(self, batch, batch_idx) -> None:

        images, targets = batch[0], batch[1]
        features = self.model.forward(images).flatten(start_dim=1)
-        features = F.normalize(features, dim=1).to(self.feature_dtype)
+        if self.normalize:
+            features = F.normalize(features, dim=1)
+        features = features.to(self.feature_dtype)
        predicted_classes = knn_predict(
            feature=features,
            feature_bank=self._train_features_tensor,
@@ -145,6 +153,12 @@ def on_train_epoch_start(self) -> None:
        # Set model to eval mode to disable norm layer updates.
        self.model.eval()

        # Reset features and targets.
        self._train_features = []
        self._train_targets = []
        self._train_features_tensor = None
        self._train_targets_tensor = None

    def configure_optimizers(self) -> None:
        # configure_optimizers must be implemented for PyTorch Lightning. Returning None
        # means that no optimization is performed.
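For context on why the flag is useful: the KNN search scores neighbors with a matrix product between the query features and the stored feature bank, so with `normalize=True` the scores are cosine similarities, while with `normalize=False` they are raw inner products (appropriate when the backbone already emits normalized embeddings, or when unnormalized similarity is wanted). A rough sketch of the scoring idea, assuming the bank layout used by `KNNClassifier`; this is not the library's actual `knn_predict` implementation:

import torch
import torch.nn.functional as F
from torch import Tensor


def knn_scores(features: Tensor, feature_bank: Tensor, normalize: bool = True) -> Tensor:
    # feature_bank has shape (dim, num_train_samples), matching the transposed
    # tensor that KNNClassifier stores (see the (3, 4) shape assertion in the
    # tests below).
    if normalize:
        features = F.normalize(features, dim=1)
    # With unit-norm query rows and unit-norm bank columns, this product is the
    # cosine similarity; otherwise it is a plain inner product.
    return torch.mm(features, feature_bank)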
79 changes: 79 additions & 0 deletions tests/utils/benchmarking/test_knn_classifier.py
@@ -2,11 +2,13 @@

import pytest
import torch
from pytest_mock import MockerFixture
from pytorch_lightning import Trainer
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from lightly.utils.benchmarking import KNNClassifier
from lightly.utils.benchmarking.knn_classifier import F


class TestKNNClassifier:
@@ -131,6 +133,83 @@ def test__features_dtype(self) -> None:
        assert classifier._train_features_tensor is not None
        assert classifier._train_features_tensor.dtype == torch.int

    def test__normalize(self, mocker: MockerFixture) -> None:
        train_features = torch.randn(4, 3)
        train_targets = torch.randint(0, 10, (4,))
        train_dataset = _FeaturesDataset(features=train_features, targets=train_targets)
        val_features = torch.randn(4, 3)
        val_targets = torch.randint(0, 10, (4,))
        val_dataset = _FeaturesDataset(features=val_features, targets=val_targets)
        train_dataloader = DataLoader(train_dataset)
        val_dataloader = DataLoader(val_dataset)
        trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
        spy_normalize = mocker.spy(F, "normalize")

        # Test that normalize is called when normalize=True.
        classifier = KNNClassifier(
            nn.Identity(), num_classes=10, knn_k=3, normalize=True
        )
        trainer.fit(
            model=classifier,
            train_dataloaders=train_dataloader,
        )
        spy_normalize.assert_called()
        spy_normalize.reset_mock()

        trainer.validate(model=classifier, dataloaders=val_dataloader)
        spy_normalize.assert_called()
        spy_normalize.reset_mock()

        # Test that normalize is not called when normalize=False.
        classifier = KNNClassifier(
            nn.Identity(), num_classes=10, knn_k=3, normalize=False
        )
        trainer.fit(
            model=classifier,
            train_dataloaders=train_dataloader,
        )
        spy_normalize.assert_not_called()
        spy_normalize.reset_mock()

        trainer.validate(model=classifier, dataloaders=val_dataloader)
        spy_normalize.assert_not_called()
        spy_normalize.reset_mock()
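The test above relies on pytest-mock's `mocker.spy`, which wraps `F.normalize` so calls pass through to the real function while being recorded. The same pattern with plain `unittest.mock`, as a standalone sketch:

from unittest.mock import patch

import torch
import torch.nn.functional as F

# Patch the module attribute with a wrapper that delegates to the real
# function but records every call.
with patch("torch.nn.functional.normalize", wraps=F.normalize) as spy:
    _ = F.normalize(torch.randn(2, 3), dim=1)
    spy.assert_called_once()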

    def test__reset_features_and_targets(self) -> None:
        train_features = torch.randn(4, 3)
        train_targets = torch.randint(0, 10, (4,))
        train_dataset = _FeaturesDataset(features=train_features, targets=train_targets)
        val_features = torch.randn(4, 3)
        val_targets = torch.randint(0, 10, (4,))
        val_dataset = _FeaturesDataset(features=val_features, targets=val_targets)
        train_dataloader = DataLoader(train_dataset)
        val_dataloader = DataLoader(val_dataset)
        classifier = KNNClassifier(nn.Identity(), num_classes=10, knn_k=3)

        trainer = Trainer(max_epochs=2, accelerator="cpu", devices=1)
        trainer.fit(
            model=classifier,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )

        # Check that train features and targets are reset after validation.
        assert classifier._train_features == []
        assert classifier._train_targets == []
        assert classifier._train_features_tensor is not None
        assert classifier._train_targets_tensor is not None
        # Check that train features and targets are not accumulated over multiple
        # validation epochs.
        assert classifier._train_features_tensor.shape == (3, 4)
        assert classifier._train_targets_tensor.shape == (4,)

        # Check that train features and targets are not accumulated over multiple
        # training epochs.
        trainer = Trainer(max_epochs=2, accelerator="cpu", devices=1)
        trainer.fit(model=classifier, train_dataloaders=train_dataloader)
        assert len(classifier._train_features) == 4
        assert len(classifier._train_targets) == 4


class _FeaturesDataset(Dataset):
    def __init__(self, features: Tensor, targets) -> None:
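        # The rest of _FeaturesDataset is collapsed in the diff view. The lines
        # below are an assumed completion, inferred from how the dataset is
        # used in the tests above; they are not the file's actual code.
        self.features = features
        self.targets = targets

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, index):
        # DataLoader collates these pairs into [features_batch, targets_batch],
        # which the classifier unpacks as batch[0], batch[1].
        return self.features[index], self.targets[index]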