Skip to content

Commit f5d2424

Browse files
林旻佑林旻佑
authored andcommitted
Fix: robust channel-last detection for y using num_classes (refs #8366)
1 parent 59786d2 commit f5d2424

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

monai/metrics/meandice.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
310310
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
311311
"""
312312
y_pred, _ = ensure_channel_first(y_pred)
313-
if y.ndim == y_pred.ndim and (y.shape[-1] == y_pred.shape[1] or y.shape[-1] == 1):
313+
314+
n_ch = self.num_classes or y_pred.shape[1]
315+
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
314316
y, _ = ensure_channel_first(y)
317+
315318

316319
_apply_argmax, _threshold = self.apply_argmax, self.threshold
317320
if self.num_classes is None:

0 commit comments

Comments
 (0)