Skip to content

Commit f8b322a

Browse files
xenovaRohanjames1997
authored andcommitted
[webgpu] Fix BatchNormalization ShapeInferenceError for 2D inputs (microsoft#26659)
### Description Test model (happens with any 2D inputs): [2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip](https://github.com/user-attachments/files/23758390/2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip) Command: ``` python -c "import onnxruntime as ort; ort.InferenceSession('2191__visual_projection_visual_projection.1_BatchNormalization.onnx', providers=['WebGpuExecutionProvider'])" ``` Before (failure): ``` Op (BatchNormalization) [ShapeInferenceError] Tensor must have at least 3 dimensions to convert between channels first and channels last. ``` After (success): ``` (nothing, meaning success) ``` ### Motivation and Context This fixes BatchNormalization on WebGPU, matching CPU version. cc @guschmue
1 parent cabf13b commit f8b322a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext {
8383
const int rank = nchw_shape.dim_size();
8484
// N and C dims are required. Some operators like AveragePool allow 1D input
8585
if (rank < 3) {
86-
fail_shape_inference("Output tensor must have at least 3 dimensions");
86+
*nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape;
87+
return;
8788
}
8889

8990
// Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C.
@@ -105,8 +106,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext {
105106
const int rank = nhwc_shape.dim_size();
106107
// N and C dims are required. Some operators like AveragePool allow 1D input.
107108
if (rank < 3) {
108-
fail_shape_inference(
109-
"Tensor must have at least 3 dimensions to convert between channels first and channels last.");
109+
*nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape;
110+
return;
110111
}
111112

112113
// Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}.

0 commit comments

Comments
 (0)