From 736cecee620cb07c62f5cf637953c2818e8e405c Mon Sep 17 00:00:00 2001 From: h2th3k Date: Fri, 10 May 2024 13:21:43 -0500 Subject: [PATCH] resolve deprecated compute_on_step --- train_imagenet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_imagenet.py b/train_imagenet.py index 8b48115..2a8b18f 100644 --- a/train_imagenet.py +++ b/train_imagenet.py @@ -416,9 +416,9 @@ def val_loop(self, lr_tta): @param('logging.folder') def initialize_logger(self, folder): self.val_meters = { - 'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=1000, compute_on_step=False).to(self.gpu), - 'top_5': torchmetrics.Accuracy(task='multiclass', num_classes=1000, compute_on_step=False, top_k=5).to(self.gpu), - 'loss': MeanScalarMetric(compute_on_step=False).to(self.gpu) + 'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=1000).to(self.gpu), + 'top_5': torchmetrics.Accuracy(task='multiclass', num_classes=1000, top_k=5).to(self.gpu), + 'loss': MeanScalarMetric().to(self.gpu) } if self.gpu == 0: