@@ -1593,15 +1593,17 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         undo_recommended_configs()
-
+
         # Check if we're running on a supported device
         is_cuda_available = torch.cuda.is_available()
         is_rocm_available = torch.version.hip is not None
-        is_supported_device = device == "cuda" and (is_cuda_available or is_rocm_available)
-
+        is_supported_device = device == "cuda" and (
+            is_cuda_available or is_rocm_available
+        )
+
         if not is_supported_device:
             self.skipTest(f"autoquant currently does not support {device}")
-
+
         # Check CUDA-specific requirements if running on CUDA
         if is_cuda_available and not is_rocm_available:
             device_capability = torch.cuda.get_device_capability()
@@ -1610,7 +1612,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
                 self.skipTest("bfloat16 requires sm80+")
             if m1 == 1 or m2 == 1:
                 self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
-
+
             # Skip certain shapes on older PyTorch versions
             if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5:
                 self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4")