Commit 51de606

AdrianLundell, Sebastian-Larsson, and digantdesai authored
Arm backend: Make TOSA backend NCHW-compatible (#12994)
- Moves the needed input/output transposes into the delegated graph so they run on the Ethos-U, rather than requiring the EthosUBackend to implement transposes on the CPU.
- Renames annotate_channels_last_dim_order_pass to to_tosa_memory_format_pass to be more descriptive and future-proof.

This change additionally enables running multiple batches, since the Ethos-U transpose supports that natively, whereas the CPU implementation did not.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

---------

Signed-off-by: Adrian Lundell <[email protected]>
Co-authored-by: Sebastian Larsson <[email protected]>
Co-authored-by: Digant Desai <[email protected]>
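A minimal sketch of the idea in plain PyTorch (the function names here are hypothetical, not the backend's API): the delegated region now performs the NCHW <-> NHWC conversions itself, so the surrounding runtime only ever handles NCHW tensors, and any batch size works because the transpose is batch-agnostic.

import torch

def delegated_region(x_nhwc: torch.Tensor) -> torch.Tensor:
    # Stand-in for the TOSA subgraph, which computes in NHWC.
    return x_nhwc

def run_delegate(x_nchw: torch.Tensor) -> torch.Tensor:
    x_nhwc = x_nchw.permute(0, 2, 3, 1)  # input transpose, now inside the delegate
    y_nhwc = delegated_region(x_nhwc)
    return y_nhwc.permute(0, 3, 1, 2)    # output transpose back to NCHW

batch = torch.randn(8, 3, 32, 32)  # batch size > 1 is now supported
assert run_delegate(batch).shape == (8, 3, 32, 32)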
1 parent 71206cf commit 51de606

37 files changed: +417 −445 lines

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,6 @@
 from . import arm_pass_utils  # noqa
 from .arm_pass import ArmPass  # noqa # usort: skip
 from .add_bias_pass import AddBiasPass  # noqa
-from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
 from .broadcast_args_pass import BroadcastArgsPass  # noqa
 from .cast_bool_to_int8_pass import CastBoolToInt8Pass  # noqa
@@ -85,6 +84,7 @@
 )
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
 from .size_adjust_input_pass import SizeAdjustInputPass  # noqa
+from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
 from .replace_inf_values_pass import ReplaceInfValues  # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,6 @@
 import executorch.backends.arm.tosa.dialect  # noqa: unused
 from executorch.backends.arm._passes import (
     AddBiasPass,
-    AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
     BroadcastArgsPass,
     CastBoolToInt8Pass,
@@ -84,6 +83,7 @@
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
     SizeAdjustInputPass,
+    ToTosaMemoryFormatPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -162,7 +162,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
-        self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(InsertRescalePass())

         return self._transform(exported_program.graph_module)
@@ -241,7 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(AddBiasPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
-        self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(InsertRescalePass())

         return self._transform(exported_program.graph_module)
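Note that the replacement pass now takes the exported program, since it must distinguish user inputs from parameters. A minimal sketch of the add_pass/transform pattern used above (MiniPassManager is hypothetical, not the real ArmPassManager):

from typing import Callable, List

class MiniPassManager:
    def __init__(self) -> None:
        self._passes: List[Callable] = []

    def add_pass(self, p: Callable) -> None:
        # Passes run in insertion order, as in the pipelines above.
        self._passes.append(p)

    def transform(self, graph_module):
        for p in self._passes:
            result = p(graph_module)
            # ExportPass-style passes return a PassResult holding the
            # (possibly rewritten) graph module.
            graph_module = getattr(result, "graph_module", graph_module)
        return graph_module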

backends/arm/_passes/decompose_select.py

Lines changed: 8 additions & 4 deletions
@@ -7,7 +7,10 @@
 # pyre-unsafe

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -34,8 +37,9 @@ def call(self, graph_module: torch.fx.GraphModule):

             input_node, dim, index = node.args

-            rank = len(input_node.meta["val"].size())
-            shape = input_node.meta["val"].shape
+            input_tensor = get_first_fake_tensor(input_node)
+            rank = len(input_tensor.size())
+            shape = input_tensor.shape
             dim = dim % rank if dim < 0 else dim
             index = index % shape[dim] if index < 0 else index

@@ -44,7 +48,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 graph_module.graph, slice_op, (input_node, dim, index, index + 1)
             )
             squeeze_node = create_node(
-                graph_module.graph, squeeze_op, (slice_node, [dim])
+                graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node
             )

             node.replace_all_uses_with(squeeze_node)
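The decomposition this pass implements can be checked in eager PyTorch: select(dim, index) is slice(dim, index, index + 1) followed by squeeze(dim). A quick sanity check, using torch.narrow as the eager analogue of slice_copy:

import torch

x = torch.randn(2, 3, 4)
dim, index = 1, -2
# Normalize negative dim/index exactly as the pass does.
dim = dim % x.dim() if dim < 0 else dim
index = index % x.shape[dim] if index < 0 else index

selected = torch.select(x, dim, index)
sliced_then_squeezed = torch.narrow(x, dim, index, 1).squeeze(dim)
assert torch.equal(selected, sliced_then_squeezed)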

backends/arm/_passes/annotate_channels_last_dim_order_pass.py renamed to backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 66 additions & 23 deletions
@@ -10,13 +10,22 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    is_param_node,
 )
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult


-class AnnotateChannelsLastDimOrder(ExportPass):
+def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    """
+    Returns True if the node is an input node, i.e. a placeholder or a parameter.
+    """
+    return node.op == "placeholder" and not is_param_node(exported_program, node)
+
+
+class ToTosaMemoryFormatPass(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
     that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
@@ -30,6 +39,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     NNHWC_order = (0, 1, 3, 4, 2)
     NNHWC_inverse_order = (0, 1, 4, 2, 3)

+    def __init__(self, exported_program: ExportedProgram) -> None:
+        self.exported_program = exported_program
+        super().__init__()
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
@@ -92,25 +105,30 @@ def is_channel_reshape(input_shape, output_shape):

     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
+        if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default:
+            pre_permute_node = input_node.all_input_nodes[0]
+            node.replace_input_with(input_node, pre_permute_node)
+            return
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        ToTosaMemoryFormatPass.NNHWC_inverse_order
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                        else ToTosaMemoryFormatPass.NHWC_inverse_order
                     ),
                 ),
+                from_node=node,
             )
             node.replace_input_with(input_node, permute_node)

             permute_node.meta["tosa_dim_order"] = tuple(
                 range(len(input_node.meta["val"].size()))
             )
-            permute_node.meta["val"] = input_node.meta["val"]

     @staticmethod
     def insert_output_transpose(node, graph_module):
@@ -121,25 +139,23 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        ToTosaMemoryFormatPass.NNHWC_order
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_order
+                        else ToTosaMemoryFormatPass.NHWC_order
                     ),
                 ),
+                from_node=node,
             )
+
             permute_node.meta["tosa_dim_order"] = (
-                AnnotateChannelsLastDimOrder.NNHWC_order
+                ToTosaMemoryFormatPass.NNHWC_order
                 if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
-            )
-            permute_node.meta["val"] = get_first_fake_tensor(node).permute(
-                AnnotateChannelsLastDimOrder.NNHWC_order
-                if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
+                else ToTosaMemoryFormatPass.NHWC_order
             )
             node.meta["tosa_dim_order"] = tuple(
                 range(len(get_first_fake_tensor(node).size()))
             )
+
             users = [user for user in node.users if user != permute_node]
             for user in users:
                 user.replace_input_with(node, permute_node)
@@ -150,20 +166,23 @@ def _insert_view_transpose(
     ):
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
         nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
-        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+        channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
            output_shape, input_shape
        )

        if (
            channel_reshape or nhwc_to_nchw
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
-            AnnotateChannelsLastDimOrder.insert_input_transpose(
+        ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):
+
+            ToTosaMemoryFormatPass.insert_input_transpose(
                node, input_node, graph_module
            )
+
        if (
            channel_reshape or nchw_to_nhwc
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
-            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+        ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):
+
+            ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        """
@@ -181,9 +200,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            # call_function and placeholder allowed due to
            # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder"]:
+            if node.op not in ["call_function", "placeholder", "output"]:
                continue

+            # Transpose views
            elif node.target in (
                exir_ops.edge.aten.view_copy.default,
                exir_ops.edge.aten.index.Tensor,
@@ -194,25 +214,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                input_node = node.args[0]
                input_shape = input_node.meta["val"].shape
                output_shape = node.meta["val"].shape
-
                self._insert_view_transpose(
-                    input_shape, output_shape, node, input_node, graph_module
+                    input_shape,
+                    output_shape,
+                    node,
+                    input_node,
+                    graph_module,
                )

+            # Transpose inputs
+            elif _is_input(node, self.exported_program):
+                input_shape = get_first_fake_tensor(node).size()
+                if len(input_shape) in (4, 5):
+                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+
+            # Transpose outputs
+            elif node.op == "output":
+                output_shape = get_first_fake_tensor(node).size()
+
+                if len(output_shape) in (4, 5):
+                    for input_node in node.all_input_nodes:
+                        ToTosaMemoryFormatPass.insert_input_transpose(
+                            node, input_node, graph_module
+                        )
+
    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            node_data = get_first_fake_tensor(node).data

-            if node_data.dim() == 4:
+            # Inputs and outputs are always in (N)NCHW format
+            if _is_input(node, self.exported_program) or node.op == "output":
+                dim_order = tuple(range(node_data.dim()))
+            elif node_data.dim() == 4:
                dim_order = self.NHWC_order
                if self.is_weight_node_for_depthwise_conv2d(node):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                    dim_order = self.HWCM_order
            elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order  # type: ignore[assignment]
+                dim_order = self.NNHWC_order
            else:
                dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
+
            node.meta["tosa_dim_order"] = dim_order
            # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
            # See insert_tosa_transposes for insertion conditions.
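The NHWC/NNHWC orders and their inverses used throughout this pass are plain permutations, so their consistency is easy to verify in isolation (a standalone check, assuming only PyTorch):

import torch

def inverse_permutation(order):
    # inverse[order[i]] = i, so permute(order) followed by
    # permute(inverse) is the identity.
    inverse = [0] * len(order)
    for i, dim in enumerate(order):
        inverse[dim] = i
    return tuple(inverse)

NHWC_order = (0, 2, 3, 1)
NNHWC_order = (0, 1, 3, 4, 2)
assert inverse_permutation(NHWC_order) == (0, 3, 1, 2)      # NHWC_inverse_order
assert inverse_permutation(NNHWC_order) == (0, 1, 4, 2, 3)  # NNHWC_inverse_order

x = torch.randn(1, 2, 3, 4)  # NCHW
round_trip = x.permute(NHWC_order).permute(inverse_permutation(NHWC_order))
assert torch.equal(round_trip, x)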

backends/arm/operators/op_transpose.py

Lines changed: 8 additions & 1 deletion
@@ -47,7 +47,14 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+                ts.DType.FP16,
+            ],
             output.tosa_spec,
         )
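For context, a minimal sketch of the kind of check validate_valid_dtype performs (the helper below is hypothetical, not the actual serializer API): the op is rejected unless every listed tensor uses one of the allowed TOSA dtypes, now including INT16, BOOL, and FP16 for TRANSPOSE.

# Hypothetical stand-in for validate_valid_dtype; dtype names mirror ts.DType.
SUPPORTED_TRANSPOSE_DTYPES = {"INT8", "INT16", "INT32", "FP16", "FP32", "BOOL"}

def check_transpose_dtypes(target: str, dtypes: list) -> None:
    for dtype in dtypes:
        if dtype not in SUPPORTED_TRANSPOSE_DTYPES:
            raise ValueError(f"{target}: unsupported dtype {dtype}")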
