6 changes: 3 additions & 3 deletions test/dtypes/test_nf4.py
@@ -43,7 +43,7 @@
to_nf4,
)
from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least

bnb_available = False

@@ -123,7 +123,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
@unittest.skipIf(not bnb_available, "Need bnb available")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
-TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+torch_version_at_least("2.7.0"), reason="Failing in CI"
) # TODO: fix this
@skip_if_rocm("ROCm enablement in progress")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -150,7 +150,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
@unittest.skipIf(
-TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+torch_version_at_least("2.7.0"), reason="Failing in CI"
) # TODO: fix this
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_nf4_bnb_linear(self, dtype: torch.dtype):
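Aside: every hunk in this PR swaps a per-version boolean constant for one parameterized helper. A minimal sketch of what torch_version_at_least is assumed to do (the real helper lives in torchao.utils and may differ; this is an illustration, not the actual implementation):

```python
# Assumed sketch of torchao.utils.torch_version_at_least, for illustration only.
import torch
from packaging import version


def torch_version_at_least(min_version: str) -> bool:
    # Drop local build suffixes such as "+cu126" so that
    # "2.7.0+cu126" still compares as >= "2.7.0".
    installed = version.parse(torch.__version__.split("+")[0])
    return installed >= version.parse(min_version)
```

One parameterized function scales better than exporting a new TORCH_VERSION_AT_LEAST_2_X constant for every release.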
6 changes: 3 additions & 3 deletions test/integration/test_integration.py
@@ -76,13 +76,13 @@
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_7,
benchmark_model,
check_cpu_version,
check_xpu_version,
is_fbcode,
is_sm_at_least_89,
is_sm_at_least_90,
+torch_version_at_least,
unwrap_tensor_subclass,
)

@@ -1883,7 +1883,7 @@ def forward(self, x):
model(x)

api(model)
-if not TORCH_VERSION_AT_LEAST_2_7:
+if not torch_version_at_least("2.7.0"):
unwrap_tensor_subclass(model)

# running model
@@ -1942,7 +1942,7 @@ def forward(self, x):
model(x)

api(model)
-if not TORCH_VERSION_AT_LEAST_2_7:
+if not torch_version_at_least("2.7.0"):
unwrap_tensor_subclass(model)

# running model
4 changes: 2 additions & 2 deletions test/integration/test_vllm.py
@@ -17,9 +17,9 @@
import torch

from packaging import version
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Requires PyTorch 2.8 or higher", allow_module_level=True)


5 changes: 3 additions & 2 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -15,7 +15,7 @@
_int8_sdpa_init,
custom_pass,
)
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least


class SelfAttnLikeModule(torch.nn.Module):
@@ -149,7 +149,8 @@ def _check_common(

@skipIfRocm
@unittest.skipIf(
-not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
+not torch_version_at_least("2.7.0"),
+reason="int8 sdpa requires torch 2.7 or later",
)
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
4 changes: 2 additions & 2 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -8,12 +8,12 @@
import torch
from torch.nn import functional as F

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least

# We need to skip before doing any imports which would use triton, since
# triton won't be available on CPU builds and torch < 2.5
if not (
-TORCH_VERSION_AT_LEAST_2_7
+torch_version_at_least("2.7.0")
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
):
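The comment in this hunk explains the ordering constraint: the module-level skip must run before any import that transitively pulls in triton. A sketch of the full pattern, assuming the body elided by the diff uses pytest.skip(..., allow_module_level=True) as in test_vllm.py above:

```python
# Module-level guard: runs at import time, before any triton-dependent
# imports, so test collection does not fail on CPU-only builds or older torch.
import pytest
import torch

from torchao.utils import torch_version_at_least

if not (
    torch_version_at_least("2.7.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9
):
    pytest.skip(
        "requires torch >= 2.7, CUDA, and compute capability >= 9.0",
        allow_module_level=True,
    )

# Only below this point is it safe to import triton-backed modules.
```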
8 changes: 4 additions & 4 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -22,14 +22,14 @@
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
+torch_version_at_least,
)

torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -45,7 +45,7 @@ def run_around_tests():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@@ -96,7 +96,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_kernels.py
@@ -44,14 +44,14 @@
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
+torch_version_at_least,
)

torch.manual_seed(0)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -15,9 +15,9 @@
import pytest
import torch

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
+from torchao.utils import is_sm_at_least_100, torch_version_at_least

-if not TORCH_VERSION_AT_LEAST_2_7:
+if not torch_version_at_least("2.7.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
8 changes: 4 additions & 4 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -26,14 +26,14 @@
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
+torch_version_at_least,
)

torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -57,7 +57,7 @@ def run_around_tests():
# only test one type of mixed-dtype overrides, to save testing time
(torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
]
-if TORCH_VERSION_AT_LEAST_2_8
+if torch_version_at_least("2.8.0")
else [
# test each dtype
(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
@@ -276,7 +276,7 @@ def test_linear_compile(
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("torch.compile requires PyTorch 2.8+")
if not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for MX gemms")
6 changes: 3 additions & 3 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -13,11 +13,11 @@
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_100,
+torch_version_at_least,
)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -79,7 +79,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
@pytest.mark.parametrize(
"format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"]
"format", ["fp8", "fp4"] if torch_version_at_least("2.8.0") else ["fp8"]
)
def test_matrix_multiplication(size, format):
M, K, N = size
20 changes: 10 additions & 10 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -25,14 +25,14 @@
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
-TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
+torch_version_at_least,
)

torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -605,7 +605,7 @@ def to_f8(x):
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
from torchao.prototype.mx_formats.nvfp4_tensor import (
@@ -674,7 +674,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
"use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
)
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
@@ -707,7 +707,7 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
],
)
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
@@ -746,7 +746,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
"""
@@ -841,7 +841,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error):
"""
@@ -862,7 +862,7 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_swizzled_scales_view_semantics():
"""
Expand All @@ -889,7 +889,7 @@ def test_nvfp4_swizzled_scales_view_semantics():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_swizzled_scales_serialization():
"""
@@ -931,7 +931,7 @@ def test_nvfp4_swizzled_scales_serialization():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_swizzled_scales_get_scales_method():
"""
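In summary, the migration is mechanical: each TORCH_VERSION_AT_LEAST_2_X constant becomes a torch_version_at_least("2.X.0") call with identical truthiness. A before/after sketch (that the old constants were computed once at torchao.utils import time is an assumption about the old code, not stated in this diff):

```python
from torchao.utils import torch_version_at_least

# Before (assumed): per-version module-level constant, fixed at import time.
#   from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
#   needs_unwrap = not TORCH_VERSION_AT_LEAST_2_8

# After: one helper, parameterized by a version string, evaluated at the
# call site.
needs_unwrap = not torch_version_at_least("2.8.0")
print(f"unwrap tensor subclass: {needs_unwrap}")
```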