diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py index d19ef42fb..255965428 100644 --- a/lightly/utils/benchmarking/metric_callback.py +++ b/lightly/utils/benchmarking/metric_callback.py @@ -60,4 +60,7 @@ def _append_metrics( self, metrics_dict: Dict[str, List[float]], trainer: Trainer ) -> None: for name, value in trainer.callback_metrics.items(): - metrics_dict.setdefault(name, []).append(float(value)) + if isinstance(value, float) or ( # type: ignore # We can't rely on PyTorchLightning's type annotations. + isinstance(value, Tensor) and value.numel() == 1 + ): + metrics_dict.setdefault(name, []).append(float(value)) diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py index 425dff817..046f2d874 100644 --- a/lightly/utils/embeddings_2d.py +++ b/lightly/utils/embeddings_2d.py @@ -66,8 +66,8 @@ def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]: raise ValueError("PCA not fitted yet. Call fit() before transform().") X = X.astype(np.float32) X = X - self.mean + self.eps - transformed = X.dot(self.w)[:, : self.n_components] - return np.asarray(transformed, dtype=np.float32) + transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components] + return np.asarray(transformed) def fit_pca(