@@ -1593,15 +1593,17 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         undo_recommended_configs()
-
+
         # Check if we're running on a supported device
         is_cuda_available = torch.cuda.is_available()
         is_rocm_available = torch.version.hip is not None
-        is_supported_device = device == "cuda" and (is_cuda_available or is_rocm_available)
-
+        is_supported_device = device == "cuda" and (
+            is_cuda_available or is_rocm_available
+        )
+
         if not is_supported_device:
             self.skipTest(f"autoquant currently does not support {device}")
-
+
         # Check CUDA-specific requirements if running on CUDA
         if is_cuda_available and not is_rocm_available:
             device_capability = torch.cuda.get_device_capability()
@@ -1610,7 +1612,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
                     self.skipTest("bfloat16 requires sm80+")
                 if m1 == 1 or m2 == 1:
                     self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
-
+
         # Skip certain shapes on older PyTorch versions
         if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4")
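For reference, a minimal standalone sketch of the device-support check introduced above (an illustration, not part of the patch), assuming only a stock PyTorch install. On ROCm builds torch.version.hip is a version string rather than None, and GPUs are still exposed under the "cuda" device type, so a single check covers both backends:

import torch

def is_supported_device(device: str) -> bool:
    # Hypothetical helper mirroring the patched test: accept "cuda" when
    # either a CUDA or a ROCm (HIP) runtime is present. torch.version.hip
    # is None on CUDA builds and a version string on ROCm builds.
    is_cuda_available = torch.cuda.is_available()
    is_rocm_available = torch.version.hip is not None
    return device == "cuda" and (is_cuda_available or is_rocm_available)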