backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py (47 changes: 34 additions & 13 deletions)
@@ -110,7 +110,9 @@ def is_nhwc_node(node: torch.fx.Node) -> bool:
         if len(quantize_node.all_input_nodes) > 0:
             actual_node = quantize_node.args[0]
             if actual_node.op == "placeholder":
-                return not actual_node.meta["val"][0].is_contiguous()
+                return ChannelsLastTaggedReshapePass._is_nhwc_tensor(
+                    actual_node.meta["val"][0]
+                )
             else:
                 return actual_node.meta.get(
                     ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
@@ -125,14 +127,36 @@ def is_nchw_node(node: torch.fx.Node) -> bool:
         if len(quantize_node.all_input_nodes) > 0:
             actual_node = quantize_node.args[0]
             if actual_node.op == "placeholder":
-                return actual_node.meta["val"][0].is_contiguous()
+                return not ChannelsLastTaggedReshapePass._is_nhwc_tensor(
+                    actual_node.meta["val"][0]
+                )
             else:
                 return not actual_node.meta.get(
                     ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
                 )

         return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)

+    @staticmethod
+    def _is_nhwc_tensor(tensor: torch.Tensor) -> bool:
+        nhwc = tensor.is_contiguous(memory_format=torch.channels_last)
+        nchw = tensor.is_contiguous()
+        # If both are true, return False (full rationale below).

Review comment (Contributor):
Should we just use Tensor.dim_order(ambiguity_check=True)?

Reply (Contributor Author):
It raises a RuntimeError when ambiguity_check=True. That's why I had to do this.
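
For context, a minimal sketch (not part of the PR) of the ambiguity in question, assuming a PyTorch build whose Tensor.dim_order supports the ambiguity_check flag: a contiguous 4D tensor with a size-1 channel dimension satisfies both contiguity checks, so there is no single correct dim order to report.

    import torch

    t = torch.randn(2, 1, 4, 4)  # size-1 channel dim: strides satisfy both layouts
    print(t.is_contiguous())                                   # True
    print(t.is_contiguous(memory_format=torch.channels_last))  # True

    # Per the reply above, asking PyTorch to disambiguate raises
    # rather than guessing:
    try:
        t.dim_order(ambiguity_check=True)
    except RuntimeError as err:
        print("ambiguous dim order:", err)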

+        # If both nchw and nhwc are true, the tensor is contiguous in both
+        # formats; we want to treat it as NCHW, hence return False.
+        # If either of nchw or nhwc is false, just rely on nhwc.
+        # If both are false (maybe channels_last_3d), return nhwc; however,
+        # that should not happen here.
+        # One-liner: return (not (nchw and nhwc)) and nhwc
+        # Readable version:
+        if nchw and nhwc:
+            return False
+        else:
+            return nhwc

+    def _is_nhwc(self, tensor: torch.Tensor) -> bool:
+        return ChannelsLastTaggedReshapePass._is_nhwc_tensor(tensor)

     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc

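For readers skimming the hunk above, a hedged sketch of the helper's tie-breaking rule; is_nhwc_tensor below is a hypothetical standalone copy of the _is_nhwc_tensor logic for illustration, not code from this PR.

    import torch

    def is_nhwc_tensor(tensor: torch.Tensor) -> bool:
        # Standalone copy of the logic above, for illustration only.
        nhwc = tensor.is_contiguous(memory_format=torch.channels_last)
        nchw = tensor.is_contiguous()
        if nchw and nhwc:  # contiguous in both formats: prefer NCHW
            return False
        return nhwc

    nchw_t = torch.randn(2, 3, 4, 4)
    nhwc_t = nchw_t.to(memory_format=torch.channels_last)
    ambiguous_t = torch.randn(2, 1, 4, 4)

    print(is_nhwc_tensor(nchw_t))       # False: plain contiguous NCHW
    print(is_nhwc_tensor(nhwc_t))       # True: channels-last only
    print(is_nhwc_tensor(ambiguous_t))  # False: ambiguous, resolved to NCHW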
@@ -315,11 +339,8 @@ def input_dim_order(
         self, input_node: torch.fx.Node, input_order: InputDimOrder
     ) -> bool:
         if input_node.op == "placeholder":
-            return (
-                input_node.meta["val"].is_contiguous()
-                if input_order == InputDimOrder.NCHW
-                else not input_node.meta["val"].is_contiguous()
-            )
+            is_nhwc = self._is_nhwc(input_node.meta["val"])
+            return not is_nhwc if input_order == InputDimOrder.NCHW else is_nhwc
         else:
             return (
                 ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
@@ -348,7 +369,7 @@ def input_to_nhwc(
         self.mark_as_nhwc_node(input_node)

         if input_node.op == "placeholder":
-            if not input_node.meta["val"][0].is_contiguous():
+            if self._is_nhwc(input_node.meta["val"][0]):
                 return
         elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
             return
@@ -420,7 +441,7 @@ def input_to_nchw(
         self.mark_as_nchw_node(input_node)

         if input_node.op == "placeholder":
-            if input_node.meta["val"].is_contiguous():
+            if not self._is_nhwc(input_node.meta["val"]):
                 return
         elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
             return
@@ -462,17 +483,17 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 and isinstance(node.meta["val"], torch.Tensor)
                 and len(node.meta["val"].shape) == 4
             ):
-                if node.meta["val"].is_contiguous():
-                    self.mark_as_nchw_node(node)
-                else:
+                if self._is_nhwc(node.meta["val"]):
                     self.mark_as_nhwc_node(node)
+                else:
+                    self.mark_as_nchw_node(node)
                 continue

             # Need a special case for the output node: it can have multiple
             # output dim orders, since the graph can output a tuple of nodes
             if node.op == "output":
                 out_tuple = node.args[0]
                 for out_node in out_tuple:
-                    if out_node.meta["val"].is_contiguous():
+                    if not self._is_nhwc(out_node.meta["val"]):
                         self.input_to_nchw(graph_module, out_node, node)
                     else:
                         self.input_to_nhwc(graph_module, out_node, node)