From f0abe93efa06534325cbe7d66b0dde7c78e28892 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 7 Oct 2025 11:20:03 -0700 Subject: [PATCH] Make determinism of channels_last more conservative Summary: When in doubt because both is_contiguous and is_contiguous(channels_last) return true, assume it is channels first Differential Revision: D83998877 --- .../channels_last_tagged_reshape_pass.py | 47 ++++++++++++++----- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 85e9889ca36..c1bc3a54f7c 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -110,7 +110,9 @@ def is_nhwc_node(node: torch.fx.Node) -> bool: if len(quantize_node.all_input_nodes) > 0: actual_node = quantize_node.args[0] if actual_node.op == "placeholder": - return not actual_node.meta["val"][0].is_contiguous() + return ChannelsLastTaggedReshapePass._is_nhwc_tensor( + actual_node.meta["val"][0] + ) else: return actual_node.meta.get( ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False @@ -125,7 +127,9 @@ def is_nchw_node(node: torch.fx.Node) -> bool: if len(quantize_node.all_input_nodes) > 0: actual_node = quantize_node.args[0] if actual_node.op == "placeholder": - return actual_node.meta["val"][0].is_contiguous() + return not ChannelsLastTaggedReshapePass._is_nhwc_tensor( + actual_node.meta["val"][0] + ) else: return not actual_node.meta.get( ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False @@ -133,6 +137,26 @@ def is_nchw_node(node: torch.fx.Node) -> bool: return not ChannelsLastTaggedReshapePass.is_nhwc_node(node) + @staticmethod + def _is_nhwc_tensor(tensor: torch.Tensor) -> bool: + nhwc = tensor.is_contiguous(memory_format=torch.channels_last) + nchw = tensor.is_contiguous() + # if both are true -> false + # if both nchw and nhwc are true + # then we want to treat 
this as nchw, hence return false + # if either of nchw or nhwc is false, then just rely on nhwc + # if both are false, maybe channels_last_3d, then return nhwc + # however this should not happen here + # return (not (nchw and nhwc)) and nhwc + # Readable version + if nchw and nhwc: + return False + else: + return nhwc + + def _is_nhwc(self, tensor: torch.Tensor) -> bool: + return ChannelsLastTaggedReshapePass._is_nhwc_tensor(tensor) + def requires_nhwc_input(self, node: torch.fx.Node) -> bool: return node.target in self.memory_sensitive_ops_nhwc @@ -315,11 +339,8 @@ def input_dim_order( self, input_node: torch.fx.Node, input_order: InputDimOrder ) -> bool: if input_node.op == "placeholder": - return ( - input_node.meta["val"].is_contiguous() - if input_order == InputDimOrder.NCHW - else not input_node.meta["val"].is_contiguous() - ) + is_nhwc = self._is_nhwc(input_node.meta["val"]) + return not is_nhwc if input_order == InputDimOrder.NCHW else is_nhwc else: return ( ChannelsLastTaggedReshapePass.is_nchw_node(input_node) @@ -348,7 +369,7 @@ def input_to_nhwc( self.mark_as_nhwc_node(input_node) if input_node.op == "placeholder": - if not input_node.meta["val"][0].is_contiguous(): + if self._is_nhwc(input_node.meta["val"][0]): return elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node): return @@ -420,7 +441,7 @@ def input_to_nchw( self.mark_as_nchw_node(input_node) if input_node.op == "placeholder": - if input_node.meta["val"].is_contiguous(): + if not self._is_nhwc(input_node.meta["val"]): return elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node): return @@ -462,17 +483,17 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 and isinstance(node.meta["val"], torch.Tensor) and len(node.meta["val"].shape) == 4 ): - if node.meta["val"].is_contiguous(): - self.mark_as_nchw_node(node) - else: + if self._is_nhwc(node.meta["val"]): self.mark_as_nhwc_node(node) + else: + self.mark_as_nchw_node(node) continue # Need special case for output node 
because it can have multiple output dim orders as we can output a tuple multiple nodes if node.op == "output": out_tuple = node.args[0] for out_node in out_tuple: - if out_node.meta["val"].is_contiguous(): + if not self._is_nhwc(out_node.meta["val"]): self.input_to_nchw(graph_module, out_node, node) else: self.input_to_nhwc(graph_module, out_node, node)