diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 9b8e173f38..2a711413f0 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -43,7 +43,7 @@
     to_nf4,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 bnb_available = False
@@ -123,7 +123,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -150,7 +150,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index afa6cfff99..39cfc1873d 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -76,13 +76,13 @@
 )
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     benchmark_model,
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    torch_version_at_least,
     unwrap_tensor_subclass,
 )
 
@@ -1883,7 +1883,7 @@ def forward(self, x):
         model(x)
         api(model)
 
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model
@@ -1942,7 +1942,7 @@ def forward(self, x):
         model(x)
         api(model)
 
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model
diff --git a/test/integration/test_vllm.py b/test/integration/test_vllm.py
index 4fc863f34f..f798a9cd6a 100644
--- a/test/integration/test_vllm.py
+++ b/test/integration/test_vllm.py
@@ -17,9 +17,9 @@
 import torch
 from packaging import version
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Requires PyTorch 2.8 or higher", allow_module_level=True)
 
diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py
index ceb9e840c1..37c7c6994b 100644
--- a/test/prototype/inductor/test_int8_sdpa_fusion.py
+++ b/test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -15,7 +15,7 @@
     _int8_sdpa_init,
     custom_pass,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class SelfAttnLikeModule(torch.nn.Module):
@@ -149,7 +149,8 @@ def _check_common(
 
     @skipIfRocm
     @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
+        not torch_version_at_least("2.7.0"),
+        reason="int8 sdpa requires torch 2.7 or later",
     )
     @unittest.skipIf(
         "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index 4b76b29a27..9b340a900f 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -8,12 +8,12 @@
 import torch
 from torch.nn import functional as F
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds and torch < 2.5
 if not (
-    TORCH_VERSION_AT_LEAST_2_7
+    torch_version_at_least("2.7.0")
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 9
 ):
diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 53441c297a..988a879b5b 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -22,14 +22,14 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
@@ -45,7 +45,7 @@ def run_around_tests():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
@@ -96,7 +96,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 0957bf0fb9..024586419a 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -44,14 +44,14 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(0)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py
index 1fd7f13337..9dc850a872 100644
--- a/test/prototype/mx_formats/test_mx_dtensor.py
+++ b/test/prototype/mx_formats/test_mx_dtensor.py
@@ -15,9 +15,9 @@
 import pytest
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
+from torchao.utils import is_sm_at_least_100, torch_version_at_least
 
-if not TORCH_VERSION_AT_LEAST_2_7:
+if not torch_version_at_least("2.7.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 from torch.distributed._tensor import DTensor, Shard, distribute_tensor
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index b74878a0af..c858657af6 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -26,14 +26,14 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
@@ -57,7 +57,7 @@ def run_around_tests():
         # only test one type of mixed-dtype overrides, to save testing time
         (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
     ]
-    if TORCH_VERSION_AT_LEAST_2_8
+    if torch_version_at_least("2.8.0")
     else [
         # test each dtype
         (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
@@ -276,7 +276,7 @@ def test_linear_compile(
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
     if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
-        if not TORCH_VERSION_AT_LEAST_2_8:
+        if not torch_version_at_least("2.8.0"):
             pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")
diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
index 84bf14f415..7cc876de6b 100644
--- a/test/prototype/mx_formats/test_mx_mm.py
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -13,11 +13,11 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
@@ -79,7 +79,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
 @pytest.mark.parametrize(
-    "format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"]
+    "format", ["fp8", "fp4"] if torch_version_at_least("2.8.0") else ["fp8"]
 )
 def test_matrix_multiplication(size, format):
     M, K, N = size
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index ea1b7c6459..870a31e978 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -25,14 +25,14 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
@@ -605,7 +605,7 @@ def to_f8(x):
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
     from torchao.prototype.mx_formats.nvfp4_tensor import (
@@ -674,7 +674,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
 )
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
@@ -707,7 +707,7 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     ],
 )
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
@@ -746,7 +746,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     """
@@ -841,7 +841,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error):
     """
@@ -862,7 +862,7 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_view_semantics():
     """
@@ -889,7 +889,7 @@ def test_nvfp4_swizzled_scales_view_semantics():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_serialization():
     """
@@ -931,7 +931,7 @@ def test_nvfp4_swizzled_scales_serialization():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_get_scales_method():
     """
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 3712d8929b..3256063deb 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,13 +21,13 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
@@ -42,7 +42,7 @@
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
     from torchao.prototype.mx_formats.nvfp4_tensor import (
@@ -107,7 +107,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     ],
 )
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
@@ -146,7 +146,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     """
@@ -241,7 +241,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error):
     """
@@ -262,7 +262,7 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_view_semantics():
     """
@@ -289,7 +289,7 @@ def test_nvfp4_swizzled_scales_view_semantics():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_serialization():
     """
@@ -331,7 +331,7 @@ def test_nvfp4_swizzled_scales_serialization():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_swizzled_scales_get_scales_method():
     """
@@ -425,7 +425,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("use_gelu", [True, False])
 @pytest.mark.parametrize(
diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py
index 42a826e43a..f74b6620db 100644
--- a/test/quantization/pt2e/test_arm_inductor_quantizer.py
+++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py
@@ -36,7 +36,7 @@
 from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
     QUANT_ANNOTATION_KEY,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 def skipIfNoArm(fn):
@@ -348,7 +348,7 @@ def _test_quantizer(
 
 
 @skipIfNoInductorSupport
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EArmInductor(ArmInductorQuantTestCase):
     @skipIfNoArm
     def test_conv2d(self):
diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py
index 3b5a43726e..90050c4c9f 100644
--- a/test/quantization/pt2e/test_duplicate_dq.py
+++ b/test/quantization/pt2e/test_duplicate_dq.py
@@ -33,7 +33,7 @@
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class TestHelperModules:
@@ -97,7 +97,7 @@ def forward(self, x):
 
 
 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestDuplicateDQPass(QuantizationTestCase):
     def _test_duplicate_dq(
         self,
diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py
index cb54eba66d..eee33e3b13 100644
--- a/test/quantization/pt2e/test_metadata_porting.py
+++ b/test/quantization/pt2e/test_metadata_porting.py
@@ -20,7 +20,7 @@
     get_symmetric_quantization_config,
 )
 from torchao.testing.pt2e._xnnpack_quantizer_utils import OP_TO_ANNOTATOR
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class TestHelperModules:
@@ -64,7 +64,7 @@ def _tag_partitions(
 
 # TODO: rename to TestPortMetadataPass to align with the util name?
 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestMetaDataPorting(QuantizationTestCase):
     def _test_quant_tag_preservation_through_decomp(
         self, model, example_inputs, from_node_to_tags
diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py
index a050f476ef..75e9688806 100644
--- a/test/quantization/pt2e/test_numeric_debugger.py
+++ b/test/quantization/pt2e/test_numeric_debugger.py
@@ -18,16 +18,17 @@
     prepare_for_propagation_comparison,
 )
 from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
 # Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests
 # that use torch.export.export, which causes many dynamo recompilations
-if TORCH_VERSION_AT_LEAST_2_8:
+if torch_version_at_least("2.8.0"):
     torch._dynamo.config.cache_size_limit = 128
 
 
 @unittest.skipIf(
-    not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8 and above, including nightly"
+    not torch_version_at_least("2.8.0"),
+    "Requires torch 2.8 and above, including nightly",
 )
 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
 class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 4f480a069a..fcf2ac3a47 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -67,11 +67,11 @@
     QuantizationConfig,
 )
 from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else [])
 
-if TORCH_VERSION_AT_LEAST_2_7:
+if torch_version_at_least("2.7.0"):
     from torch.testing._internal.common_utils import (
         TEST_HPU,
     )
@@ -80,7 +80,7 @@
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2E(PT2EQuantizationTestCase):
     def test_simple_quantizer(self):
         # TODO: use OP_TO_ANNOTATOR
@@ -1218,7 +1218,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
     @parametrize("dtype", (torch.float32, torch.bfloat16))
     @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
     def test_quantization_dtype(self, dtype, quant_dtype):
-        if TORCH_VERSION_AT_LEAST_2_7 and TEST_HPU:
+        if torch_version_at_least("2.7.0") and TEST_HPU:
             unittest.SkipTest("test doesn't currently work with HPU")
 
         class DtypeActQuantizer(Quantizer):
@@ -2015,7 +2015,7 @@ def test_disallow_eval_train(self):
         m.train()
 
     def test_allow_exported_model_train_eval(self):
-        if TORCH_VERSION_AT_LEAST_2_7 and TEST_HPU:
+        if torch_version_at_least("2.7.0") and TEST_HPU:
             unittest.SkipTest("test doesn't currently work with HPU")
 
         class M(torch.nn.Module):
@@ -2945,7 +2945,7 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
     def test_channel_group_quantization(self):
         from torchao.quantization.pt2e._affine_quantization import (
diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py
index 57988e028c..fb1b17ce9f 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_qat.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -51,7 +51,7 @@
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class PT2EQATTestCase(QuantizationTestCase):
@@ -423,7 +423,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper(
 )
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     """
     Base TestCase to be used for all conv-bn[-relu] fusion patterns.
@@ -866,7 +866,7 @@ def test_fold_bn_erases_bn_node(self):
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 1
     example_inputs = (torch.randn(1, 3, 5),)
@@ -876,7 +876,7 @@ class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 2
     example_inputs = (torch.randn(1, 3, 5, 5),)
@@ -1045,7 +1045,7 @@ def validate(self, model: torch.fx.GraphModule):
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EQATModels(PT2EQATTestCase):
     @skip_if_no_torchvision
     @skipIfNoQNNPACK
@@ -1068,7 +1068,7 @@ def test_qat_mobilenet_v2(self):
         self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
     class TwoLinear(torch.nn.Module):
         def __init__(self) -> None:
diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py
index f79b11213f..cd431c4ccb 100644
--- a/test/quantization/pt2e/test_representation.py
+++ b/test/quantization/pt2e/test_representation.py
@@ -27,11 +27,11 @@
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 @skipIfNoQNNPACK
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestPT2ERepresentation(QuantizationTestCase):
     def _test_representation(
         self,
diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index 099b77e0db..6e3772c76a 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -45,7 +45,7 @@
     X86InductorQuantizer,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
 # The dict value is match_nodes(computation_op+unary_op)
 unary_list = {
@@ -269,7 +269,7 @@ def _test_code_common(
         torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 class TestPatternMatcher(TestPatternMatcherBase):
     def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
         class M(torch.nn.Module):
@@ -2426,7 +2426,7 @@ def matcher_check_fn():
         "specialize_float": True,
     }
 )
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 class TestDynamicPatternMatcher(TestPatternMatcherBase):
     def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
         r"""
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index c0ec05350e..0d46771a68 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -35,7 +35,7 @@
     QUANT_ANNOTATION_KEY,
     X86InductorQuantizer,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class NodePosType(Enum):
@@ -703,7 +703,7 @@ def _test_quantizer(
 
 
 @skipIfNoInductorSupport
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
     @skipIfNoX86
     def test_conv2d(self):
diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 82275c4587..4263969b2b 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -26,10 +26,10 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    torch_version_at_least,
 )
 
 # Needed since changing args to function causes recompiles
@@ -49,7 +49,7 @@ def forward(self, x):
 
 
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
 class TestFloat8Tensor(TorchAOIntegrationTestCase):
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
index de7cd35feb..cc8f10faba 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
@@ -22,9 +22,7 @@
 from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-)
+from torchao.utils import torch_version_at_least
 
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
@@ -33,7 +31,7 @@
 )
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 class TestInt4MarlinSparseTensor(TestCase):
     def setUp(self):
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index 67f8416050..01ef99ae96 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -22,9 +22,9 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
     is_sm_at_least_90,
+    torch_version_at_least,
 )
 
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
@@ -39,7 +39,7 @@
 )
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 @unittest.skipIf(
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
index 4a817c2d3c..c8493f5491 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -16,10 +16,10 @@
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
+from torchao.utils import is_sm_at_least_90, torch_version_at_least
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 class TestInt4Tensor(TorchAOIntegrationTestCase):
diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
index 3a9480f675..b7eed222af 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
@@ -18,12 +18,10 @@
 )
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-)
+from torchao.utils import torch_version_at_least
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 class TestIntxUnpackedTensor(TestCase):
     def setUp(self):
         self.config = IntxWeightOnlyConfig(
diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py
index 84f0946841..d4f68c4333 100644
--- a/test/quantization/test_da8w4_cpu.py
+++ b/test/quantization/test_da8w4_cpu.py
@@ -23,10 +23,7 @@
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
-    TORCH_VERSION_AT_LEAST_2_8,
-)
+from torchao.utils import torch_version_at_least
 
 
 class ToyLinearModel(torch.nn.Module):
@@ -53,14 +50,14 @@ class TestDa8w4Cpu(TestCase):
         "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
         reason="cpp kernels not built",
     )
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
+    @unittest.skipIf(not torch_version_at_least("2.7.0"), "Test only enabled for 2.7+")
     @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     @common_utils.parametrize("bs", [1, 160])
     @common_utils.parametrize("sym_quant_a", [True, False])
     def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
-        if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
+        if sym_quant_a and not torch_version_at_least("2.8.0"):
             # not supported until PT 2.8
             return
         device = "cpu"
@@ -119,7 +116,7 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
         "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
         reason="cpp kernels not built",
     )
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Test only enabled for 2.8+")
+    @unittest.skipIf(not torch_version_at_least("2.8.0"), "Test only enabled for 2.8+")
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index f979c9a588..67d1255b5e 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -66,9 +66,9 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    torch_version_at_least,
     unwrap_tensor_subclass,
 )
 
@@ -221,7 +221,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
     @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "only works for torch 2.8+")
+    @unittest.skipIf(not torch_version_at_least("2.8.0"), "only works for torch 2.8+")
     def test_int4_wo_quant_save_load(self):
         m = ToyLinearModel().eval().cpu()
 
diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py
index 64df37ac88..b0edfc7fc5 100644
--- a/test/test_low_bit_optim.py
+++ b/test/test_low_bit_optim.py
@@ -41,8 +41,8 @@
 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
+    torch_version_at_least,
 )
 
 try:
@@ -242,7 +242,7 @@ def test_subclass_slice(self, subclass, shape, device):
     )
     @skip_if_rocm("ROCm enablement in progress")
     @pytest.mark.skipif(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
diff --git a/test/test_ops.py b/test/test_ops.py
index bc9fe0e4f9..89512b673d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -28,8 +28,8 @@
 )
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     compute_max_diff,
+    torch_version_at_least,
 )
 
 IS_CUDA = torch.cuda.is_available() and torch.version.cuda
@@ -155,7 +155,8 @@ def _scaled_dot_product_int8_op_ref(
         return out.to(torch.uint8)
 
     @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
+        not torch_version_at_least("2.7.0"),
+        reason="int8 sdpa requires torch 2.7 or later",
     )
     @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
     @pytest.mark.skipif(
diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
index 2f696b1131..8d0cfaddeb 100644
--- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
+++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
@@ -16,10 +16,7 @@
     register_layout,
 )
 from torchao.dtypes.utils import Layout, PlainLayout, is_device
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
-    TORCH_VERSION_AT_LEAST_2_8,
-)
+from torchao.utils import torch_version_at_least
 
 from .int4_cpu_layout import (
     Int4CPUAQTTensorImpl,
@@ -246,7 +243,7 @@ def _aqt_is_uint4(aqt):
 
 def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias):
     return (
-        TORCH_VERSION_AT_LEAST_2_7
+        torch_version_at_least("2.7.0")
        and is_device(input_tensor.device.type, "cpu")
        and is_device(weight_tensor.device.type, "cpu")
        and (bias is None or is_device(bias.device.type, "cpu"))
@@ -262,11 +259,11 @@ def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
-    assert TORCH_VERSION_AT_LEAST_2_7, (
+    assert torch_version_at_least("2.7.0"), (
         f"Requires PyTorch version at least 2.7, but got: {torch.__version__}"
     )
     if _aqt_is_int8(input_tensor):
-        assert TORCH_VERSION_AT_LEAST_2_8, (
+        assert torch_version_at_least("2.8.0"), (
             f"Requires PyTorch version at least 2.8, but got: {torch.__version__}"
         )
     assert is_device(input_tensor.device.type, "cpu"), (
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
index 955a7a8610..a01fad31c2 100644
--- a/torchao/dtypes/uintx/int4_xpu_layout.py
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -20,8 +20,8 @@
 from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     fill_defaults,
+    torch_version_at_least,
 )
 
 aten = torch.ops.aten
@@ -248,7 +248,7 @@ def from_plain(
     ):
         assert isinstance(_layout, Int4XPULayout)
 
-        if TORCH_VERSION_AT_LEAST_2_8:
+        if torch_version_at_least("2.8.0"):
             assert int_data.dtype == torch.int32, (
                 "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
             )
diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
index 5e032f01c2..0cea1c2c70 100644
--- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
+++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
@@ -15,10 +15,10 @@
     register_lowering_pattern,
 )
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
-if TORCH_VERSION_AT_LEAST_2_7:
-    # TORCH_VERSION_AT_LEAST_2_7 is needed for functions in int8 sdpa lowering
+if torch_version_at_least("2.7.0"):
+    # PyTorch 2.7+ is needed for functions in int8 sdpa lowering
     from ..int8_sdpa_lowering import register_int8_sdpa  # noqa: F401
 else:
     make_fallback(torch.ops.torchao.qscaled_dot_product.default)
@@ -370,8 +370,8 @@ def _register_int8_sdpa_lowerings(custom_pass_dict):
 
 
 custom_pass = None
-if TORCH_VERSION_AT_LEAST_2_7:
-    # TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass
+if torch_version_at_least("2.7.0"):
+    # PyTorch 2.7+ is needed for custom graph pass
     from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
 
     # define the custom pass
@@ -390,7 +390,7 @@ def uuid(self) -> bytes:
 @functools.lru_cache(None)
 def _int8_sdpa_init():
-    if TORCH_VERSION_AT_LEAST_2_7:
+    if torch_version_at_least("2.7.0"):
         _register_int8_sdpa_lowerings(config.post_grad_custom_pre_pass)
     else:
         pass
diff --git a/torchao/prototype/mx_formats/constants.py b/torchao/prototype/mx_formats/constants.py
index ffac3b1d5f..3111bc771b 100644
--- a/torchao/prototype/mx_formats/constants.py
+++ b/torchao/prototype/mx_formats/constants.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
 # This is conceptually an enum of non-core dtypes
 # TODO(future PR): change to a cleaner way to represent this without
@@ -23,7 +23,7 @@
 ]
 SUPPORTED_ELEM_DTYPES = (
     SUPPORTED_ELEM_DTYPES + [torch.float4_e2m1fn_x2]
-    if TORCH_VERSION_AT_LEAST_2_8
+    if torch_version_at_least("2.8.0")
     else SUPPORTED_ELEM_DTYPES
 )
 
@@ -33,7 +33,7 @@
     DTYPE_FP6_E2M3: "f6e2m3",
     DTYPE_FP6_E3M2: "f6e3m2",
 }
-if TORCH_VERSION_AT_LEAST_2_8:
+if torch_version_at_least("2.8.0"):
     DTYPE_TO_SHORT_STR[torch.float4_e2m1fn_x2] = "f4e2m1"
 
 F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 241ce295bd..34cf9e9506 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -27,8 +27,8 @@
     register_quantize_module_handler,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 
@@ -148,7 +148,7 @@ class NVFP4InferenceConfig(AOBaseConfig):
 
     def __post_init__(self):
         # Validate PyTorch version
-        if not TORCH_VERSION_AT_LEAST_2_8:
+        if not torch_version_at_least("2.8.0"):
             raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
 
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index cd605917af..be23057ac7 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -17,8 +17,8 @@
     _floatx_unpacked_to_f32,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
@@ -821,7 +821,7 @@ def _(uint8_data):
     return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8)
 
 
-if TORCH_VERSION_AT_LEAST_2_7 and has_triton():
+if torch_version_at_least("2.7.0") and has_triton():
     import triton
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py
index 88f0eb490c..8a7314359b 100644
--- a/torchao/quantization/pt2e/quantize_pt2e.py
+++ b/torchao/quantization/pt2e/quantize_pt2e.py
@@ -6,9 +6,9 @@
 
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
-if TORCH_VERSION_AT_LEAST_2_7:
+if torch_version_at_least("2.7.0"):
     from .constant_fold import constant_fold
 
 from typing import Union
@@ -320,7 +320,7 @@ def convert_pt2e(
         pm = PassManager([PortNodeMetaForQDQ()])
         model = pm(model).graph_module
 
-    if fold_quantize and TORCH_VERSION_AT_LEAST_2_7:
+    if fold_quantize and torch_version_at_least("2.7.0"):
         constant_fold(model, _quant_node_constraint)
 
     if use_reference_representation:
diff --git a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
index 84a66447c1..656f4fbbeb 100644
--- a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
+++ b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
@@ -1634,8 +1634,8 @@ def validate(self, model: torch.fx.GraphModule) -> None:
     _register_quantization_weight_pack_pass,
     quant_lift_up,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
-if TORCH_VERSION_AT_LEAST_2_8:
+if torch_version_at_least("2.8.0"):
     torch._inductor.config.pre_grad_custom_pass = quant_lift_up
     _register_quantization_weight_pack_pass()
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
index bd6d08b998..03d653a442 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
@@ -18,7 +18,6 @@
     quantize_affine,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
     fill_defaults,
 )
@@ -274,6 +273,5 @@ def _(func, types, args, kwargs):
 
 IntxUnpackedTensor.__module__ = "torchao.quantization"
 
-if TORCH_VERSION_AT_LEAST_2_5:
-    # Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True`
-    torch.serialization.add_safe_globals([IntxUnpackedTensor])
+# Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([IntxUnpackedTensor])
diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py
index 41bd5f0310..f031386012 100644
--- a/torchao/testing/pt2e/utils.py
+++ b/torchao/testing/pt2e/utils.py
@@ -29,7 +29,7 @@
     prepare_pt2e,
     prepare_qat_pt2e,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class PT2EQuantizationTestCase(QuantizationTestCase):
@@ -132,7 +132,7 @@ def _test_quantizer(
 
         return m
 
 
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class PT2ENumericDebuggerTestCase(TestCase):
     """
     Base test case class for PT2E numeric debugger tests containing common utility functions
diff --git a/torchao/utils.py b/torchao/utils.py
index 68d17ededf..bef5c038a0 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -29,22 +29,22 @@
     "get_model_size_in_bytes",
     "unwrap_tensor_subclass",
     "TorchAOBaseTensor",
+    "is_MI300",
+    "is_sm_at_least_89",
+    "is_sm_at_least_90",
+    "is_package_at_least",
+    "DummyModule",
+    # Deprecated
     "TORCH_VERSION_AT_LEAST_2_2",
     "TORCH_VERSION_AT_LEAST_2_3",
     "TORCH_VERSION_AT_LEAST_2_4",
     "TORCH_VERSION_AT_LEAST_2_5",
    "TORCH_VERSION_AT_LEAST_2_6",
    "TORCH_VERSION_AT_LEAST_2_7",
-    # Needs to be deprecated in the future
    "TORCH_VERSION_AFTER_2_2",
    "TORCH_VERSION_AFTER_2_3",
    "TORCH_VERSION_AFTER_2_4",
    "TORCH_VERSION_AFTER_2_5",
-    "is_MI300",
-    "is_sm_at_least_89",
-    "is_sm_at_least_90",
-    "is_package_at_least",
-    "DummyModule",
 ]