 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -222,6 +223,8 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

     if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")

@@ -308,6 +311,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """
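Taken together, the change relaxes the module-level gate to PyTorch 2.7 while keeping the torch.compile paths gated on 2.8. A minimal standalone sketch of that two-level skip pattern, assuming torchao is installed; the test body is a placeholder, not one of the real MX/float8 tests:

# Sketch of the two-level version gate used in the diff above; assumes the
# TORCH_VERSION_AT_LEAST_* flags from torchao.utils are available.
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, TORCH_VERSION_AT_LEAST_2_8

# Module-level gate: the whole file needs at least PyTorch 2.7 to run.
if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
def test_compile_path():
    # Placeholder body: only the compiled path keeps the stricter 2.8 gate.
    f = torch.compile(lambda x: x * 2)
    assert torch.equal(f(torch.ones(2)), torch.full((2,), 2.0))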