|
38 | 38 |
|
39 | 39 | __all__ = ["sliding_window_inference"] |
40 | 40 |
|
| 41 | +def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]: |
| 42 | + """ |
| 43 | + 將張量標準化為 channel-first(N,C,spatial...)。 |
| 44 | + 回傳 (可能已轉換的張量, 原本 channel 維度:1 表示本來就在 dim=1;-1 表示本來在最後一維)。 |
41 | 45 |
|
| 46 | + 支援常見情況: |
| 47 | + - [N, C, *spatial] -> 原樣返回 |
| 48 | + - [N, *spatial, C] -> 移動最後一維到 dim=1 |
| 49 | + 其他模糊情況則丟出 ValueError,避免悄悄算錯。 |
| 50 | + """ |
| 51 | + if not isinstance(x, torch.Tensor): |
| 52 | + raise TypeError(f"expect torch.Tensor, got {type(x)}") |
| 53 | + if x.ndim < 3: |
| 54 | + raise ValueError(f"expect >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}") |
| 55 | + |
| 56 | + # 若未指定,估個常見的 2D/3D 空間維度數,僅用於錯誤訊息與判斷參考 |
| 57 | + if spatial_ndim is None: |
| 58 | + spatial_ndim = max(2, min(3, x.ndim - 2)) |
| 59 | + |
| 60 | + # 簡單啟發式:C 通常不會太大(<=512) |
| 61 | + c_first_ok = x.shape[1] <= 512 |
| 62 | + c_last_ok = x.shape[-1] <= 512 |
| 63 | + |
| 64 | + # 優先保留 channel-first |
| 65 | + if c_first_ok and x.ndim >= 2 + spatial_ndim: |
| 66 | + return x, 1 |
| 67 | + if c_last_ok: |
| 68 | + return x.movedim(-1, 1), -1 |
| 69 | + |
| 70 | + raise ValueError( |
| 71 | + f"cannot infer channel dim for shape={tuple(x.shape)}; " |
| 72 | + f"expected [N,C,spatial...] or [N,spatial...,C] (spatial_ndim≈{spatial_ndim})" |
| 73 | + ) |
42 | 74 | def sliding_window_inference( |
43 | 75 | inputs: torch.Tensor | MetaTensor, |
44 | 76 | roi_size: Sequence[int] | int, |
|
0 commit comments