Commit fe9fde5

Remove TORCH_VERSION_AT_LEAST* warnings when importing torch
**Summary:** We recently deprecated these variables, but we're still using them in torchao. We should replace all of them with `torch_version_at_least` so users don't see deprecation warnings when they're just importing torchao.

**Test Plan:** Manual import
1 parent a9ffa50 commit fe9fde5
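
For context, here is a hedged sketch of the migration pattern this commit applies across the test suite. The diffs below only establish that `torch_version_at_least` takes a version string and returns a bool; the stand-in below is an assumption about its behavior, not torchao's actual implementation (which may, for example, treat dev/nightly version strings differently):

```python
import torch
from packaging import version


def torch_version_at_least(min_version: str) -> bool:
    # Hypothetical stand-in for torchao.utils.torch_version_at_least:
    # returns True when the installed torch is at least `min_version`.
    return version.parse(torch.__version__) >= version.parse(min_version)


# Old pattern: importing the deprecated constant triggers its deprecation
# warning, so torchao's own internal imports surfaced the warning to anyone
# who merely imported torchao.
#   from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
#   if not TORCH_VERSION_AT_LEAST_2_7:
#       ...
#
# New pattern: the check is an ordinary function call, so no deprecated
# name is touched at import time and no warning is emitted.
#   from torchao.utils import torch_version_at_least
#   if not torch_version_at_least("2.7.0"):
#       ...
```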

41 files changed (+132, -141 lines)

test/dtypes/test_nf4.py

Lines changed: 3 additions & 3 deletions

@@ -43,7 +43,7 @@
     to_nf4,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 bnb_available = False
 
@@ -123,7 +123,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -150,7 +150,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):

test/integration/test_integration.py

Lines changed: 3 additions & 3 deletions

@@ -76,13 +76,13 @@
 )
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     benchmark_model,
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    torch_version_at_least,
     unwrap_tensor_subclass,
 )
 
@@ -1883,7 +1883,7 @@ def forward(self, x):
         model(x)
 
         api(model)
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model
@@ -1942,7 +1942,7 @@ def forward(self, x):
         model(x)
 
         api(model)
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model

test/integration/test_vllm.py

Lines changed: 2 additions & 2 deletions

@@ -17,9 +17,9 @@
 import torch
 
 from packaging import version
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Requires PyTorch 2.8 or higher", allow_module_level=True)
 
 
test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 3 additions & 2 deletions

@@ -15,7 +15,7 @@
     _int8_sdpa_init,
     custom_pass,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 
 class SelfAttnLikeModule(torch.nn.Module):
@@ -149,7 +149,8 @@ def _check_common(
 
     @skipIfRocm
     @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
+        not torch_version_at_least("2.7.0"),
+        reason="int8 sdpa requires torch 2.7 or later",
     )
     @unittest.skipIf(
         "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 2 additions & 2 deletions

@@ -8,12 +8,12 @@
 import torch
 from torch.nn import functional as F
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds and torch < 2.5
 if not (
-    TORCH_VERSION_AT_LEAST_2_7
+    torch_version_at_least("2.7.0")
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 9
 ):

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 4 additions & 4 deletions

@@ -22,14 +22,14 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
@@ -45,7 +45,7 @@ def run_around_tests():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
@@ -96,7 +96,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])

test/prototype/mx_formats/test_kernels.py

Lines changed: 2 additions & 2 deletions

@@ -44,14 +44,14 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(0)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 2 additions & 2 deletions

@@ -15,9 +15,9 @@
 import pytest
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
+from torchao.utils import is_sm_at_least_100, torch_version_at_least
 
-if not TORCH_VERSION_AT_LEAST_2_7:
+if not torch_version_at_least("2.7.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 from torch.distributed._tensor import DTensor, Shard, distribute_tensor

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 4 additions & 4 deletions

@@ -26,14 +26,14 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
 torch.manual_seed(2)
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
@@ -57,7 +57,7 @@ def run_around_tests():
         # only test one type of mixed-dtype overrides, to save testing time
         (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
     ]
-    if TORCH_VERSION_AT_LEAST_2_8
+    if torch_version_at_least("2.8.0")
    else [
         # test each dtype
         (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
@@ -276,7 +276,7 @@ def test_linear_compile(
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
     if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]:
-        if not TORCH_VERSION_AT_LEAST_2_8:
+        if not torch_version_at_least("2.8.0"):
             pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 3 additions & 3 deletions

@@ -13,11 +13,11 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_100,
+    torch_version_at_least,
 )
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
@@ -79,7 +79,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
 @pytest.mark.parametrize(
-    "format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"]
+    "format", ["fp8", "fp4"] if torch_version_at_least("2.8.0") else ["fp8"]
 )
 def test_matrix_multiplication(size, format):
     M, K, N = size
