diff --git a/train_imagenet.py b/train_imagenet.py index 8b48115..2a8b18f 100644 --- a/train_imagenet.py +++ b/train_imagenet.py @@ -416,9 +416,9 @@ def val_loop(self, lr_tta): @param('logging.folder') def initialize_logger(self, folder): self.val_meters = { - 'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=1000, compute_on_step=False).to(self.gpu), - 'top_5': torchmetrics.Accuracy(task='multiclass', num_classes=1000, compute_on_step=False, top_k=5).to(self.gpu), - 'loss': MeanScalarMetric(compute_on_step=False).to(self.gpu) + 'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=1000).to(self.gpu), + 'top_5': torchmetrics.Accuracy(task='multiclass', num_classes=1000, top_k=5).to(self.gpu), + 'loss': MeanScalarMetric().to(self.gpu) } if self.gpu == 0: