@@ -173,9 +173,32 @@ def test_qnn_backend_arange(self):
173
173
self .lower_module_and_test_output (module , sample_input )
174
174
175
175
def test_qnn_backend_argmax (self ):
176
- module = Argmax () # noqa: F405
177
- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178
- self .lower_module_and_test_output (module , sample_input )
176
+ test_cases = [
177
+ {
178
+ "module" : Argmax (), # noqa: F405
179
+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
180
+ },
181
+ {
182
+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183
+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
184
+ },
185
+ {
186
+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187
+ "sample_input" : (torch .randn (8 , 5 ),),
188
+ },
189
+ {
190
+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
191
+ "sample_input" : (torch .tensor ([5.0 ]),),
192
+ },
193
+ {
194
+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195
+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
196
+ },
197
+ ]
198
+
199
+ for i , case in enumerate (test_cases ):
200
+ with self .subTest (i = i ):
201
+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
179
202
180
203
def test_qnn_backend_argmin (self ):
181
204
module = Argmin () # noqa: F405
@@ -1757,10 +1780,33 @@ def test_qnn_backend_arange(self):
1757
1780
self .lower_module_and_test_output (module , sample_input )
1758
1781
1759
1782
def test_qnn_backend_argmax (self ):
1760
- module = Argmax () # noqa: F405
1761
- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1762
- module = self .get_qdq_module (module , sample_input )
1763
- self .lower_module_and_test_output (module , sample_input )
1783
+ test_cases = [
1784
+ {
1785
+ "module" : Argmax (), # noqa: F405
1786
+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1787
+ },
1788
+ {
1789
+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1790
+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1791
+ },
1792
+ {
1793
+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1794
+ "sample_input" : (torch .randn (8 , 5 ),),
1795
+ },
1796
+ {
1797
+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
1798
+ "sample_input" : (torch .tensor ([5.0 ]),),
1799
+ },
1800
+ {
1801
+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1802
+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1803
+ },
1804
+ ]
1805
+
1806
+ for i , case in enumerate (test_cases ):
1807
+ with self .subTest (i = i ):
1808
+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1809
+ self .lower_module_and_test_output (module , case ["sample_input" ])
1764
1810
1765
1811
def test_qnn_backend_argmin (self ):
1766
1812
module = Argmin () # noqa: F405
0 commit comments