From 8fa20a2b7b592761502d13d7f5ab0be6d5ea80c3 Mon Sep 17 00:00:00 2001 From: Dwayne Duane Date: Fri, 16 Oct 2020 10:06:16 -0400 Subject: [PATCH] Input shape bug in DepthwiseConv2D in NCHW format --- coremltools/converters/mil/frontend/tensorflow/ops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/coremltools/converters/mil/frontend/tensorflow/ops.py b/coremltools/converters/mil/frontend/tensorflow/ops.py index 1d61b0463..570b74613 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/ops.py @@ -729,9 +729,16 @@ def DepthwiseConv2dNative(context, node): pad_type = pad_type.lower() x = context[node.inputs[0]] - C_in = x.shape[-1] + if data_format == "NHWC": x = _transpose_NHWC_to_NCHW(x) + C_in = x.shape[-1] + elif data_format == "NCHW": + C_in = x.shape[1] + + if not isinstance(C_in, int): + raise ValueError("Channel number of input node must be an integer, instead got: {}".format(C_in)) + # Only the last op should have the same name as node.name conv_name = node.name + "x" if data_format == "NHWC" else node.name x = mb.conv(