Skip to content

Commit fce8287

Browse files
林旻佑林旻佑
authored and committed
Fix: robust channel-last detection for y using num_classes (refs #8366)
Signed-off-by: 林旻佑 <[email protected]>
1 parent f5d2424 commit fce8287

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

monai/inferers/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,39 @@
3838

3939
__all__ = ["sliding_window_inference"]
4040

41+
def ensure_channel_first(
    x: torch.Tensor,
    spatial_ndim: int | None = None,
    *,
    max_channels: int = 512,
) -> tuple[torch.Tensor, int]:
    """
    Normalize a tensor to channel-first layout ``[N, C, *spatial]``.

    The layout is inferred with a size heuristic: a dimension is accepted as
    the channel dimension only if its length is at most ``max_channels``.
    Channel-first is preferred whenever dim 1 passes that test, so a
    channel-last tensor whose dim 1 is also small is returned unchanged.

    Supported cases:

    - ``[N, C, *spatial]`` -> returned as-is.
    - ``[N, *spatial, C]`` -> the last dimension is moved to dim 1.

    Anything else raises ``ValueError`` instead of silently producing a
    wrongly permuted tensor.

    Args:
        x: input tensor with at least 3 dimensions.
        spatial_ndim: number of spatial dimensions. When ``None`` it is
            estimated as ``min(3, x.ndim - 2)``. If given explicitly and
            ``x.ndim < 2 + spatial_ndim``, the channel-first interpretation
            is rejected and only channel-last is considered.
        max_channels: largest dimension length still accepted as a channel
            count. Defaults to 512.

    Returns:
        A tuple ``(tensor, original_channel_dim)``. ``original_channel_dim``
        is ``1`` when the input was taken as channel-first (``tensor`` is
        then ``x`` itself) and ``-1`` when the channel was found in the last
        dimension (``tensor`` is then ``x.movedim(-1, 1)``).

    Raises:
        TypeError: if ``x`` is not a ``torch.Tensor``.
        ValueError: if ``x`` has fewer than 3 dimensions, or if neither dim 1
            nor the last dim qualifies as a channel dimension.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"expect torch.Tensor, got {type(x)}")
    if x.ndim < 3:
        raise ValueError(f"expect >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")

    if spatial_ndim is None:
        # x.ndim >= 3 is guaranteed above, so this estimate is in [1, 3].
        # It must not be forced up to 2: doing so makes a 1-D-spatial
        # [N, C, L] tensor fail the channel-first rank check below and get
        # permuted as if it were channel-last.
        spatial_ndim = min(3, x.ndim - 2)

    # Simple heuristic: the channel count is usually not very large.
    channel_first_plausible = x.shape[1] <= max_channels
    channel_last_plausible = x.shape[-1] <= max_channels

    # Prefer keeping the tensor channel-first when that reading is plausible.
    if channel_first_plausible and x.ndim >= 2 + spatial_ndim:
        return x, 1
    if channel_last_plausible:
        return x.movedim(-1, 1), -1

    raise ValueError(
        f"cannot infer channel dim for shape={tuple(x.shape)}; "
        f"expected [N,C,spatial...] or [N,spatial...,C] (spatial_ndim≈{spatial_ndim})"
    )
4274
def sliding_window_inference(
4375
inputs: torch.Tensor | MetaTensor,
4476
roi_size: Sequence[int] | int,

0 commit comments

Comments
 (0)