Commit e1de9d0

cccclai authored and facebook-github-bot committed
support argmax without dim kwargs (pytorch#14710)
Summary: As the title says. In PyTorch, when `dim` is not set, `argmax` flattens the input and computes the index over the flattened tensor (i.e., `dim=0` on the 1-D view). Add a pass that reshapes the input to 1-D when `dim` is not set, and consolidate the test cases.

Differential Revision: D83606497
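For reference, the `dim=None` behavior the pass mirrors can be checked in plain PyTorch (a quick sanity snippet, not part of this diff):

import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
# With dim unset, argmax returns the index into the flattened input,
# which equals argmax over a 1-D reshape with dim=0.
assert torch.argmax(x) == torch.argmax(x.reshape(-1), dim=0)  # both tensor(1)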
1 parent 65100f6 commit e1de9d0

File tree: 9 files changed, +186 −14 lines

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -31,6 +31,7 @@
 from .fuse_consecutive_transpose import FuseConsecutiveTranspose
 from .i64_to_i32 import I64toI32
 from .insert_io_qdq import InsertIOQDQ
+from .insert_reshape_for_argmax import InsertReshapeForArgmax
 from .insert_requantize import InsertRequantize
 from .layout_transform import LayoutTransform
 from .lift_constant_scalar_operands import LiftConstantScalarOperands
@@ -44,7 +45,6 @@
 from .seq_mse import SeqMSE
 from .tag_quant_io import TagQuantIO

-
 __all__ = [
     AnnotateAdaptiveAvgPool1D,
     AnnotateQuantAttrs,
@@ -73,6 +73,7 @@
     FuseConsecutiveTranspose,
     I64toI32,
     InsertIOQDQ,
+    InsertReshapeForArgmax,
     InsertRequantize,
     LayoutTransform,
     LiftConstantScalarOperands,
backends/qualcomm/_passes/insert_reshape_for_argmax.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+class InsertReshapeForArgmax(ExportPass):
+    """
+    Rewrite `aten.argmax.default` with `dim=None` into
+    a reshape-to-1D followed by argmax(dim=0).
+
+    PyTorch semantics:
+        torch.argmax(x, dim=None) -> flatten(x) then argmax along axis=0
+
+    QNN requires an explicit axis, so we insert the reshape.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op_map = {torch.ops.aten.argmax.default}
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        modified = False
+
+        for n in list(graph.nodes):
+            if n.target in self.op_map:
+                dim_arg = None if len(n.args) == 1 else n.args[1]
+
+                if dim_arg is None:
+                    inp = n.args[0]
+
+                    # Insert reshape before argmax
+                    with graph.inserting_before(n):
+                        reshape_node = graph.create_node(
+                            "call_function",
+                            torch.ops.aten.reshape.default,
+                            (inp, [-1]),
+                            {},
+                        )
+                    reshape_node.meta = dict(inp.meta)
+                    if "val" in inp.meta:
+                        reshape_node.meta["val"] = inp.meta["val"].reshape(-1)
+
+                    # Rewrite argmax: take reshape_node as input, set dim=0
+                    n.args = (reshape_node, 0, *n.args[2:])
+
+                    modified = True
+
+        if modified:
+            graph_module.recompile()
+            dead_code_elimination_pass(graph_module)
+
+        return PassResult(graph_module, modified)
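A minimal sketch of exercising the new pass standalone, mirroring the unit test added below (the module name M is illustrative):

import torch
from executorch.backends.qualcomm._passes import InsertReshapeForArgmax

class M(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x)  # dim defaults to None

ep = torch.export.export(M(), (torch.randn(4, 4),))
InsertReshapeForArgmax()(ep.graph_module)
# The graph now reshapes the input to 1-D and calls argmax with dim=0
print(ep.graph_module.code)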

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions

@@ -37,6 +37,7 @@
     I64toI32,
     InsertIOQDQ,
     InsertRequantize,
+    InsertReshapeForArgmax,
     LayoutTransform,
     LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
@@ -207,6 +208,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(ReplaceInfValues())
         self.add_pass(LiftConstantScalarOperands())
+        self.add_pass(InsertReshapeForArgmax())
         return self._transform(graph_module)

     def transform_for_export_pipeline(
@@ -226,6 +228,7 @@ def transform_for_export_pipeline(
         self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
         self.add_pass(LiftConstantScalarOperands())
+        self.add_pass(InsertReshapeForArgmax())
         self._transform(exported_program.graph_module)
         ep = lift_constant_tensor_pass(exported_program)
         return ep

backends/qualcomm/builders/op_argmax.py

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def define_node(
             nodes_to_wrappers,
         )
         argmax_output_tensors = [argmax_out_tensor_wrapper]
-
+        print("[ARGMAX]: node: ", node, " node.args: ", node.args)
         dim = cast(int, node.args[1])
         if dim < 0:
             dim = dim % len(input_tensor.shape)

backends/qualcomm/quantizer/quantizer.py

Lines changed: 4 additions & 1 deletion

@@ -387,7 +387,10 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         Returns:
             GraphModule: The transformed model.
         """
-        return QnnPassManager().transform_for_annotation_pipeline(model)
+        ret = QnnPassManager().transform_for_annotation_pipeline(model)
+        print("after transform_for_annotation")
+        ret.print_readable()
+        return ret

     def validate(self, model: GraphModule) -> None:
         pass

backends/qualcomm/tests/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -47,3 +47,17 @@ runtime.python_library(
         ":test_qnn_delegate"
     ]
 )
+
+runtime.python_test(
+    name = "test_passes",
+    srcs = [
+        "test_passes.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/expecttest:expecttest",  # @manual
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/backends/qualcomm/_passes:passes",
+        "//executorch/backends/qualcomm/builders:builders",
+    ],
+)

backends/qualcomm/tests/models.py

Lines changed: 5 additions & 4 deletions

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-
+from typing import Optional

 # module with related operator only

@@ -170,12 +170,13 @@ def forward(self, y):


 class Argmax(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
         super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim

     def forward(self, x):
-        x = torch.argmax(x, dim=0, keepdim=True)
-        return x
+        return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)


 class Argmin(torch.nn.Module):
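With the module now parametrized, one class covers every configuration exercised by the tests below; a quick behavioral check (plain PyTorch, not part of the diff):

m = Argmax()  # dim=None: index into the flattened input
m(torch.tensor([[1.0, 5.0], [3.0, 2.0]]))  # tensor(1)

m = Argmax(dim=1, keepdim=False)
m(torch.randn(8, 5)).shape  # torch.Size([8])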
backends/qualcomm/tests/test_passes.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import unittest
+import torch
+from executorch.backends.qualcomm._passes import InsertReshapeForArgmax
+
+class TestPasses(unittest.TestCase):
+    def test_insert_reshape_for_argmax(self):
+        class ArgmaxModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=None)
+
+        mod = ArgmaxModule()
+
+        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
+        ep = torch.export.export(mod, (x,))
+        # Run original module for reference
+        ref = mod(x)
+
+        reshape_nodes = [n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default]
+        argmax_nodes = [n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default]
+        self.assertTrue(len(reshape_nodes) == 0, "Unexpected reshape node before the pass")
+        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
+
+        InsertReshapeForArgmax()(ep.graph_module)
+
+        # Check graph structure: argmax should take a reshape as input
+        reshape_nodes = [n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default]
+        argmax_nodes = [n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default]
+        self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted")
+        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
+
+        argmax_node = argmax_nodes[0]
+        self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0")
+
+        # Execute the rewritten graph and compare with the reference
+        out = ep.graph_module(x)
+        self.assertTrue(torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}")
+
+if __name__ == "__main__":
+    unittest.main()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 56 additions & 7 deletions

@@ -173,9 +173,33 @@ def test_qnn_backend_arange(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_argmax(self):
-        module = Argmax()  # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                "module": Argmax(),
+                "sample_input": (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                "module": Argmax(dim=0, keepdim=True),
+                "sample_input": (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                "module": Argmax(dim=1, keepdim=False),
+                "sample_input": (torch.randn(8, 5),),
+            },
+            {
+                "module": Argmax(dim=None, keepdim=False),
+                "sample_input": (torch.tensor([5.0]),),
+            },
+            {
+                "module": Argmax(dim=2, keepdim=True),
+                "sample_input": (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(case["module"], case["sample_input"])

     def test_qnn_backend_argmin(self):
         module = Argmin()  # noqa: F405
@@ -1709,11 +1733,36 @@ def test_qnn_backend_arange(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+
     def test_qnn_backend_argmax(self):
-        module = Argmax()  # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                "module": Argmax(),
+                "sample_input": (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                "module": Argmax(dim=0, keepdim=True),
+                "sample_input": (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                "module": Argmax(dim=1, keepdim=False),
+                "sample_input": (torch.randn(8, 5),),
+            },
+            {
+                "module": Argmax(dim=None, keepdim=False),
+                "sample_input": (torch.tensor([5.0]),),
+            },
+            {
+                "module": Argmax(dim=2, keepdim=True),
+                "sample_input": (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(case["module"], case["sample_input"])
+                self.lower_module_and_test_output(module, case["sample_input"])

     def test_qnn_backend_argmin(self):
         module = Argmin()  # noqa: F405
