Skip to content

Commit 5077d00

Browse files
cccclaifacebook-github-bot
authored andcommitted
support argmax/argmin without dim kwargs (pytorch#14710)
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
1 parent b021fd0 commit 5077d00

File tree

7 files changed

+246
-21
lines changed

7 files changed

+246
-21
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .i64_to_i32 import I64toI32
3333
from .insert_io_qdq import InsertIOQDQ
3434
from .insert_requantize import InsertRequantize
35+
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
3536
from .layout_transform import LayoutTransform
3637
from .lift_constant_scalar_operands import LiftConstantScalarOperands
3738
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
@@ -44,7 +45,6 @@
4445
from .seq_mse import SeqMSE
4546
from .tag_quant_io import TagQuantIO
4647

47-
4848
__all__ = [
4949
AnnotateAdaptiveAvgPool1D,
5050
AnnotateQuantAttrs,
@@ -73,6 +73,7 @@
7373
FuseConsecutiveTranspose,
7474
I64toI32,
7575
InsertIOQDQ,
76+
InsertReshapeForReduceOps,
7677
InsertRequantize,
7778
LayoutTransform,
7879
LiftConstantScalarOperands,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
10+
11+
12+
class InsertReshapeForReduceOps(ExportPass):
13+
"""
14+
Rewrite `aten.argmax.default` with `dim=None` into
15+
a reshape-to-1D followed by argmax(dim=0).
16+
17+
PyTorch semantics:
18+
torch.argmax(x, dim=None) -> flatten(x) then argmax along axis=0
19+
20+
QNN requires an explicit axis, so we insert the reshape.
21+
"""
22+
23+
def __init__(self):
24+
super().__init__()
25+
self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default}
26+
27+
def call(self, graph_module: torch.fx.GraphModule):
28+
graph = graph_module.graph
29+
modified = False
30+
31+
for n in graph.nodes:
32+
if n.target in self.op_map:
33+
dim_arg = None if len(n.args) == 1 else n.args[1]
34+
35+
if dim_arg is None:
36+
inp = n.args[0]
37+
38+
# Insert reshape before argmax
39+
with graph.inserting_before(n):
40+
reshape_node = graph.create_node(
41+
"call_function",
42+
torch.ops.aten.reshape.default,
43+
(inp, [-1]),
44+
{},
45+
)
46+
reshape_node.meta = dict(inp.meta)
47+
if "val" in inp.meta:
48+
reshape_node.meta["val"] = inp.meta["val"].reshape(-1)
49+
50+
# Rewrite argmax: take reshape_node as input, set dim=0
51+
n.args = (reshape_node, 0, *n.args[2:])
52+
53+
modified = True
54+
55+
if modified:
56+
graph_module.recompile()
57+
dead_code_elimination_pass(graph_module)
58+
59+
return PassResult(graph_module, modified)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
I64toI32,
3838
InsertIOQDQ,
3939
InsertRequantize,
40+
InsertReshapeForReduceOps,
4041
LayoutTransform,
4142
LiftConstantScalarOperands,
4243
RecomposePixelUnshuffle,
@@ -207,6 +208,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
207208
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
208209
self.add_pass(ReplaceInfValues())
209210
self.add_pass(LiftConstantScalarOperands())
211+
self.add_pass(InsertReshapeForReduceOps())
210212
return self._transform(graph_module)
211213

212214
def transform_for_export_pipeline(
@@ -226,6 +228,7 @@ def transform_for_export_pipeline(
226228
self.add_pass(ConvertLinearToConv2d(exported_program))
227229
self.add_pass(ConvertSquareToPow())
228230
self.add_pass(LiftConstantScalarOperands())
231+
self.add_pass(InsertReshapeForReduceOps())
229232
self._transform(exported_program.graph_module)
230233
ep = lift_constant_tensor_pass(exported_program)
231234
return ep

backends/qualcomm/tests/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,17 @@ runtime.python_library(
4747
":test_qnn_delegate"
4848
]
4949
)
50+
51+
runtime.python_test(
52+
name = "test_passes",
53+
srcs = [
54+
"test_passes.py",
55+
],
56+
deps = [
57+
"fbsource//third-party/pypi/expecttest:expecttest", # @manual
58+
"//caffe2:torch",
59+
"//executorch/exir:lib",
60+
"//executorch/backends/qualcomm/_passes:passes",
61+
"//executorch/backends/qualcomm/builders:builders",
62+
],
63+
)

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,23 @@ def forward(self, y):
171171

172172

173173
class Argmax(torch.nn.Module):
174-
def __init__(self):
174+
def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
175175
super().__init__()
176+
self.dim = dim
177+
self.keepdim = keepdim
176178

177179
def forward(self, x):
178-
x = torch.argmax(x, dim=0, keepdim=True)
179-
return x
180+
return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)
180181

181182

182183
class Argmin(torch.nn.Module):
183-
def __init__(self):
184+
def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
184185
super().__init__()
186+
self.dim = dim
187+
self.keepdim = keepdim
185188

186189
def forward(self, x):
187-
x = torch.argmin(x, dim=0, keepdim=True)
188-
return x
190+
return torch.argmin(x, dim=self.dim, keepdim=self.keepdim)
189191

190192

191193
class ArgminViewSqueezeConv2D(torch.nn.Module):
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps
5+
6+
7+
class TestPasses(unittest.TestCase):
8+
def test_insert_reshape_for_argmax(self):
9+
class ArgmaxModule(torch.nn.Module):
10+
def forward(self, x):
11+
return torch.argmax(x, dim=None)
12+
13+
mod = ArgmaxModule()
14+
15+
x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
16+
ep = torch.export.export(mod, (x,))
17+
# Run original module for reference
18+
ref = mod(x)
19+
20+
reshape_nodes = [
21+
n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
22+
]
23+
argmax_nodes = [
24+
n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
25+
]
26+
self.assertTrue(len(reshape_nodes) == 0, "Reshape node not inserted")
27+
self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
28+
29+
InsertReshapeForReduceOps()(ep.graph_module)
30+
31+
out = ep.graph_module(x)
32+
33+
# Check graph structure: argmax should take a reshape as input
34+
reshape_nodes = [
35+
n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
36+
]
37+
argmax_nodes = [
38+
n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
39+
]
40+
self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted")
41+
self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
42+
43+
argmax_node = argmax_nodes[0]
44+
self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0")
45+
46+
# Execute new graph and compare with reference
47+
out = ep.graph_module(x)
48+
self.assertTrue(
49+
torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
50+
)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,60 @@ def test_qnn_backend_arange(self):
173173
self.lower_module_and_test_output(module, sample_input)
174174

175175
def test_qnn_backend_argmax(self):
176-
module = Argmax() # noqa: F405
177-
sample_input = (torch.randn(16, 3, 4, 4),)
178-
self.lower_module_and_test_output(module, sample_input)
176+
test_cases = [
177+
{
178+
"module": Argmax(), # noqa: F405
179+
"sample_input": (torch.randn(16, 3, 4, 4),),
180+
},
181+
{
182+
"module": Argmax(dim=0, keepdim=True), # noqa: F405
183+
"sample_input": (torch.randn(16, 3, 4, 4),),
184+
},
185+
{
186+
"module": Argmax(dim=1, keepdim=False), # noqa: F405
187+
"sample_input": (torch.randn(8, 5),),
188+
},
189+
{
190+
"module": Argmax(dim=None, keepdim=False), # noqa: F405
191+
"sample_input": (torch.tensor([5.0]),),
192+
},
193+
{
194+
"module": Argmax(dim=2, keepdim=True), # noqa: F405
195+
"sample_input": (torch.randn(2, 3, 4),),
196+
},
197+
]
198+
199+
for i, case in enumerate(test_cases):
200+
with self.subTest(i=i):
201+
self.lower_module_and_test_output(case["module"], case["sample_input"])
179202

180203
def test_qnn_backend_argmin(self):
181-
module = Argmin() # noqa: F405
182-
sample_input = (torch.rand(3, 4),)
183-
self.lower_module_and_test_output(module, sample_input)
204+
test_cases = [
205+
{
206+
"module": Argmin(), # noqa: F405
207+
"sample_input": (torch.randn(16, 3, 4, 4),),
208+
},
209+
{
210+
"module": Argmin(dim=0, keepdim=True), # noqa: F405
211+
"sample_input": (torch.randn(16, 3, 4, 4),),
212+
},
213+
{
214+
"module": Argmin(dim=1, keepdim=False), # noqa: F405
215+
"sample_input": (torch.randn(8, 5),),
216+
},
217+
{
218+
"module": Argmin(dim=None, keepdim=False), # noqa: F405
219+
"sample_input": (torch.tensor([5.0]),),
220+
},
221+
{
222+
"module": Argmin(dim=2, keepdim=True), # noqa: F405
223+
"sample_input": (torch.randn(2, 3, 4),),
224+
},
225+
]
226+
227+
for i, case in enumerate(test_cases):
228+
with self.subTest(i=i):
229+
self.lower_module_and_test_output(case["module"], case["sample_input"])
184230

185231
@unittest.expectedFailure
186232
def test_qnn_backend_asin(self):
@@ -1757,16 +1803,62 @@ def test_qnn_backend_arange(self):
17571803
self.lower_module_and_test_output(module, sample_input)
17581804

17591805
def test_qnn_backend_argmax(self):
1760-
module = Argmax() # noqa: F405
1761-
sample_input = (torch.randn(16, 3, 4, 4),)
1762-
module = self.get_qdq_module(module, sample_input)
1763-
self.lower_module_and_test_output(module, sample_input)
1806+
test_cases = [
1807+
{
1808+
"module": Argmax(), # noqa: F405
1809+
"sample_input": (torch.randn(16, 3, 4, 4),),
1810+
},
1811+
{
1812+
"module": Argmax(dim=0, keepdim=True), # noqa: F405
1813+
"sample_input": (torch.randn(16, 3, 4, 4),),
1814+
},
1815+
{
1816+
"module": Argmax(dim=1, keepdim=False), # noqa: F405
1817+
"sample_input": (torch.randn(8, 5),),
1818+
},
1819+
{
1820+
"module": Argmax(dim=None, keepdim=False), # noqa: F405
1821+
"sample_input": (torch.tensor([5.0]),),
1822+
},
1823+
{
1824+
"module": Argmax(dim=2, keepdim=True), # noqa: F405
1825+
"sample_input": (torch.randn(2, 3, 4),),
1826+
},
1827+
]
1828+
1829+
for i, case in enumerate(test_cases):
1830+
with self.subTest(i=i):
1831+
module = self.get_qdq_module(case["module"], case["sample_input"])
1832+
self.lower_module_and_test_output(module, case["sample_input"])
17641833

17651834
def test_qnn_backend_argmin(self):
1766-
module = Argmin() # noqa: F405
1767-
sample_input = (torch.randn(16, 3, 4, 4),)
1768-
module = self.get_qdq_module(module, sample_input)
1769-
self.lower_module_and_test_output(module, sample_input)
1835+
test_cases = [
1836+
{
1837+
"module": Argmin(), # noqa: F405
1838+
"sample_input": (torch.randn(16, 3, 4, 4),),
1839+
},
1840+
{
1841+
"module": Argmin(dim=0, keepdim=True), # noqa: F405
1842+
"sample_input": (torch.randn(16, 3, 4, 4),),
1843+
},
1844+
{
1845+
"module": Argmin(dim=1, keepdim=False), # noqa: F405
1846+
"sample_input": (torch.randn(8, 5),),
1847+
},
1848+
{
1849+
"module": Argmin(dim=None, keepdim=False), # noqa: F405
1850+
"sample_input": (torch.tensor([5.0]),),
1851+
},
1852+
{
1853+
"module": Argmin(dim=2, keepdim=True), # noqa: F405
1854+
"sample_input": (torch.randn(2, 3, 4),),
1855+
},
1856+
]
1857+
1858+
for i, case in enumerate(test_cases):
1859+
with self.subTest(i=i):
1860+
module = self.get_qdq_module(case["module"], case["sample_input"])
1861+
self.lower_module_and_test_output(module, case["sample_input"])
17701862

17711863
def test_qnn_backend_asin(self):
17721864
module = Asin() # noqa: F405

0 commit comments

Comments
 (0)