@@ -10,13 +10,22 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    is_param_node,
 )
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class AnnotateChannelsLastDimOrder(ExportPass):
+def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    """
+    Returns True if the node is a user input, i.e. a placeholder that is not a parameter.
+    """
+    return node.op == "placeholder" and not is_param_node(exported_program, node)
+
+
+class ToTosaMemoryFormatPass(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
     that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
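Reviewer note: the new `_is_input` helper is needed because `torch.export` lifts parameters into placeholders, so `node.op == "placeholder"` alone does not identify user inputs. A minimal, self-contained sketch of that behavior (the module and names are illustrative, not from this PR):

```python
import torch
from torch.export import export


class Demo(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


ep = export(Demo(), (torch.randn(1, 3, 8, 8),))
# The lifted conv weight/bias show up as placeholders next to the real
# input, which is why _is_input() must also consult is_param_node().
print([n.name for n in ep.graph.nodes if n.op == "placeholder"])
```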
@@ -30,6 +39,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     NNHWC_order = (0, 1, 3, 4, 2)
     NNHWC_inverse_order = (0, 1, 4, 2, 3)
 
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        self.exported_program = exported_program
+        super().__init__()
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
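For intuition about the order/inverse-order pairs above: permuting with an order and then with its inverse restores the original layout, which is what lets the pass cancel transpose pairs later. A standalone check in plain PyTorch, using the 4D orders (consistent with the 5D `NNHWC_*` constants shown above):

```python
import torch

NHWC_order = (0, 2, 3, 1)          # NCHW -> NHWC
NHWC_inverse_order = (0, 3, 1, 2)  # NHWC -> NCHW

x = torch.randn(2, 3, 8, 8)  # NCHW
nhwc = x.permute(NHWC_order)
assert tuple(nhwc.shape) == (2, 8, 8, 3)
# The inverse order undoes the permutation exactly.
assert torch.equal(nhwc.permute(NHWC_inverse_order), x)
```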
@@ -92,25 +105,30 @@ def is_channel_reshape(input_shape, output_shape):
 
     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
+        if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default:
+            pre_permute_node = input_node.all_input_nodes[0]
+            node.replace_input_with(input_node, pre_permute_node)
+            return
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        ToTosaMemoryFormatPass.NNHWC_inverse_order
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                        else ToTosaMemoryFormatPass.NHWC_inverse_order
                     ),
                 ),
+                from_node=node,
             )
             node.replace_input_with(input_node, permute_node)
 
             permute_node.meta["tosa_dim_order"] = tuple(
                 range(len(input_node.meta["val"].size()))
             )
-            permute_node.meta["val"] = input_node.meta["val"]
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
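The new early return in `insert_input_transpose` folds a transpose that feeds another transpose: since this pass only ever pairs an order with its inverse, reconnecting to the producer is equivalent to composing the two permutations into the identity. A small sketch of that invariant (the `compose` helper is local to this example):

```python
def compose(p: tuple, q: tuple) -> tuple:
    # Permutation equivalent to x.permute(q).permute(p).
    return tuple(q[i] for i in p)


NHWC_order = (0, 2, 3, 1)
NHWC_inverse_order = (0, 3, 1, 2)

# A transpose followed by its inverse is a no-op, so the pair can be
# dropped from the graph instead of materializing both TRANSPOSE ops.
assert compose(NHWC_order, NHWC_inverse_order) == (0, 1, 2, 3)
assert compose(NHWC_inverse_order, NHWC_order) == (0, 1, 2, 3)
```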
@@ -121,25 +139,23 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        ToTosaMemoryFormatPass.NNHWC_order
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_order
+                        else ToTosaMemoryFormatPass.NHWC_order
                     ),
                 ),
+                from_node=node,
             )
+
             permute_node.meta["tosa_dim_order"] = (
-                AnnotateChannelsLastDimOrder.NNHWC_order
+                ToTosaMemoryFormatPass.NNHWC_order
                 if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
-            )
-            permute_node.meta["val"] = get_first_fake_tensor(node).permute(
-                AnnotateChannelsLastDimOrder.NNHWC_order
-                if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
+                else ToTosaMemoryFormatPass.NHWC_order
             )
             node.meta["tosa_dim_order"] = tuple(
                 range(len(get_first_fake_tensor(node).size()))
             )
+
             users = [user for user in node.users if user != permute_node]
             for user in users:
                 user.replace_input_with(node, permute_node)
@@ -150,20 +166,23 @@ def _insert_view_transpose(
     ):
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
         nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
-        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+        channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
             output_shape, input_shape
         )
 
         if (
             channel_reshape or nhwc_to_nchw
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
-            AnnotateChannelsLastDimOrder.insert_input_transpose(
+        ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):
+
+            ToTosaMemoryFormatPass.insert_input_transpose(
                 node, input_node, graph_module
             )
+
         if (
             channel_reshape or nchw_to_nhwc
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
-            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+        ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):
+
+            ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
 
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
@@ -181,9 +200,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder"]:
+            if node.op not in ["call_function", "placeholder", "output"]:
                 continue
 
+            # Transpose views
             elif node.target in (
                 exir_ops.edge.aten.view_copy.default,
                 exir_ops.edge.aten.index.Tensor,
@@ -194,25 +214,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
                 output_shape = node.meta["val"].shape
-
                 self._insert_view_transpose(
-                    input_shape, output_shape, node, input_node, graph_module
+                    input_shape,
+                    output_shape,
+                    node,
+                    input_node,
+                    graph_module,
                 )
 
+            # Transpose inputs
+            elif _is_input(node, self.exported_program):
+                input_shape = get_first_fake_tensor(node).size()
+                if len(input_shape) in (4, 5):
+                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+
+            # Transpose outputs
+            elif node.op == "output":
+                output_shape = get_first_fake_tensor(node).size()
+
+                if len(output_shape) in (4, 5):
+                    for input_node in node.all_input_nodes:
+                        ToTosaMemoryFormatPass.insert_input_transpose(
+                            node, input_node, graph_module
+                        )
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             node_data = get_first_fake_tensor(node).data
 
-            if node_data.dim() == 4:
+            # Inputs and outputs are always in (N)NCHW format
+            if _is_input(node, self.exported_program) or node.op == "output":
+                dim_order = tuple(range(node_data.dim()))
+            elif node_data.dim() == 4:
                 dim_order = self.NHWC_order
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                     dim_order = self.HWCM_order
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order  # type: ignore[assignment]
+                dim_order = self.NNHWC_order
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
+
             node.meta["tosa_dim_order"] = dim_order
             # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
             # See insert_tosa_transposes for insertion conditions.
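Since `__init__` now takes an `ExportedProgram`, every call site has to thread one through when constructing the pass. A hedged sketch of the updated usage; the import path and helper are assumptions for illustration, not code from this PR (the real registration lives in the Arm pass manager):

```python
# Hypothetical call-site sketch (not part of this PR).
from executorch.backends.arm._passes import ToTosaMemoryFormatPass  # assumed export path


def build_memory_format_passes(exported_program):
    # Old: AnnotateChannelsLastDimOrder() took no arguments.
    # New: the ExportedProgram being lowered must be passed in.
    return [ToTosaMemoryFormatPass(exported_program)]
```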