Skip to content

Commit

Permalink
Always convert features to correct dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Dec 15, 2023
1 parent 453f9bf commit c0d09ee
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lightly/utils/benchmarking/knn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def training_step(self, batch, batch_idx) -> None:
images, targets = batch[0], batch[1]
features = self.model.forward(images).flatten(start_dim=1)
if self.normalize:
features = F.normalize(features, dim=1).to(self.feature_dtype)
features = F.normalize(features, dim=1)
features = features.to(self.feature_dtype)
self._train_features.append(features.cpu())
self._train_targets.append(targets.cpu())

Expand All @@ -116,7 +117,8 @@ def validation_step(self, batch, batch_idx) -> None:
images, targets = batch[0], batch[1]
features = self.model.forward(images).flatten(start_dim=1)
if self.normalize:
features = F.normalize(features, dim=1).to(self.feature_dtype)
features = F.normalize(features, dim=1)
features = features.to(self.feature_dtype)
predicted_classes = knn_predict(
feature=features,
feature_bank=self._train_features_tensor,
Expand Down

0 comments on commit c0d09ee

Please sign in to comment.