diff --git a/evaluate.py b/evaluate.py index 9a4e3ba2b5..42eb064f9b 100644 --- a/evaluate.py +++ b/evaluate.py @@ -27,7 +27,7 @@ def evaluate(net, dataloader, device, amp): assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]' mask_pred = (F.sigmoid(mask_pred) > 0.5).float() # compute the Dice score - dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False) + dice_score += dice_coeff(mask_pred.squeeze(1), mask_true, reduce_batch_first=False) else: assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes[' # convert to one-hot format