@@ -173,14 +173,60 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ "module" : Argmax (), # noqa: F405
179+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ "sample_input" : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ "sample_input" : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
179202
180203 def test_qnn_backend_argmin (self ):
181- module = Argmin () # noqa: F405
182- sample_input = (torch .rand (3 , 4 ),)
183- self .lower_module_and_test_output (module , sample_input )
204+ test_cases = [
205+ {
206+ "module" : Argmin (), # noqa: F405
207+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
208+ },
209+ {
210+ "module" : Argmin (dim = 0 , keepdim = True ), # noqa: F405
211+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
212+ },
213+ {
214+ "module" : Argmin (dim = 1 , keepdim = False ), # noqa: F405
215+ "sample_input" : (torch .randn (8 , 5 ),),
216+ },
217+ {
218+ "module" : Argmin (dim = None , keepdim = False ), # noqa: F405
219+ "sample_input" : (torch .tensor ([5.0 ]),),
220+ },
221+ {
222+ "module" : Argmin (dim = 2 , keepdim = True ), # noqa: F405
223+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
224+ },
225+ ]
226+
227+ for i , case in enumerate (test_cases ):
228+ with self .subTest (i = i ):
229+ self .lower_module_and_test_output (case ["module" ], case ["sample_input" ])
184230
185231 @unittest .expectedFailure
186232 def test_qnn_backend_asin (self ):
@@ -1757,16 +1803,62 @@ def test_qnn_backend_arange(self):
17571803 self .lower_module_and_test_output (module , sample_input )
17581804
17591805 def test_qnn_backend_argmax (self ):
1760- module = Argmax () # noqa: F405
1761- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1762- module = self .get_qdq_module (module , sample_input )
1763- self .lower_module_and_test_output (module , sample_input )
1806+ test_cases = [
1807+ {
1808+ "module" : Argmax (), # noqa: F405
1809+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1810+ },
1811+ {
1812+ "module" : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1813+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1814+ },
1815+ {
1816+ "module" : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1817+ "sample_input" : (torch .randn (8 , 5 ),),
1818+ },
1819+ {
1820+ "module" : Argmax (dim = None , keepdim = False ), # noqa: F405
1821+ "sample_input" : (torch .tensor ([5.0 ]),),
1822+ },
1823+ {
1824+ "module" : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1825+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1826+ },
1827+ ]
1828+
1829+ for i , case in enumerate (test_cases ):
1830+ with self .subTest (i = i ):
1831+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1832+ self .lower_module_and_test_output (module , case ["sample_input" ])
17641833
17651834 def test_qnn_backend_argmin (self ):
1766- module = Argmin () # noqa: F405
1767- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1768- module = self .get_qdq_module (module , sample_input )
1769- self .lower_module_and_test_output (module , sample_input )
1835+ test_cases = [
1836+ {
1837+ "module" : Argmin (), # noqa: F405
1838+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1839+ },
1840+ {
1841+ "module" : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1842+ "sample_input" : (torch .randn (16 , 3 , 4 , 4 ),),
1843+ },
1844+ {
1845+ "module" : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1846+ "sample_input" : (torch .randn (8 , 5 ),),
1847+ },
1848+ {
1849+ "module" : Argmin (dim = None , keepdim = False ), # noqa: F405
1850+ "sample_input" : (torch .tensor ([5.0 ]),),
1851+ },
1852+ {
1853+ "module" : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1854+ "sample_input" : (torch .randn (2 , 3 , 4 ),),
1855+ },
1856+ ]
1857+
1858+ for i , case in enumerate (test_cases ):
1859+ with self .subTest (i = i ):
1860+ module = self .get_qdq_module (case ["module" ], case ["sample_input" ])
1861+ self .lower_module_and_test_output (module , case ["sample_input" ])
17701862
17711863 def test_qnn_backend_asin (self ):
17721864 module = Asin () # noqa: F405
0 commit comments