Skip to content

Commit

Permalink
Make KNN feature normalization optional
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Dec 15, 2023
1 parent 610f73e commit 1d42012
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions lightly/utils/benchmarking/knn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
knn_t: float = 0.1,
topk: Tuple[int, ...] = (1, 5),
feature_dtype: torch.dtype = torch.float32,
normalize: bool = True,
):
"""KNN classifier for benchmarking.
Expand All @@ -42,6 +43,8 @@ def __init__(
feature_dtype:
Torch data type of the features used for KNN search. Reduce to float16
for memory-efficient KNN search.
normalize:
Whether to normalize the features for KNN search.
Examples:
>>> from pytorch_lightning import Trainer
Expand Down Expand Up @@ -90,6 +93,7 @@ def __init__(
self.knn_t = knn_t
self.topk = topk
self.feature_dtype = feature_dtype
self.normalize = normalize

self._train_features = []
self._train_targets = []
Expand All @@ -100,7 +104,8 @@ def __init__(
def training_step(self, batch, batch_idx) -> None:
    """Extracts features for one training batch and banks them for KNN search.

    Args:
        batch:
            Sequence whose first two entries are the images tensor and the
            targets tensor.
        batch_idx:
            Index of the batch; unused, required by the Lightning API.

    Side effects:
        Appends the (optionally L2-normalized) features, cast to
        ``self.feature_dtype``, and the targets to the in-memory CPU feature
        bank lists.
    """
    images, targets = batch[0], batch[1]
    features = self.model.forward(images).flatten(start_dim=1)
    if self.normalize:
        features = F.normalize(features, dim=1)
    # Cast unconditionally: the previous version guarded the `.to()` call
    # behind `self.normalize`, silently skipping the feature_dtype conversion
    # (e.g. the float16 memory saving) whenever normalization was disabled.
    features = features.to(self.feature_dtype)
    self._train_features.append(features.cpu())
    self._train_targets.append(targets.cpu())

Expand All @@ -110,7 +115,8 @@ def validation_step(self, batch, batch_idx) -> None:

images, targets = batch[0], batch[1]
features = self.model.forward(images).flatten(start_dim=1)
features = F.normalize(features, dim=1).to(self.feature_dtype)
if self.normalize:
features = F.normalize(features, dim=1).to(self.feature_dtype)
predicted_classes = knn_predict(
feature=features,
feature_bank=self._train_features_tensor,
Expand Down

0 comments on commit 1d42012

Please sign in to comment.