Skip to content

Commit 07fbc89

Browse files
authored
fix incorrect torch version test (#2786)
* fix torch version detector * add pre-release parser for torch_version_at_least and remove compare_versions - Co-authored-by: andrewor14 <[email protected]> * add pre-release parser for torch_version_at_least and remove compare_versions - Co-authored-by: andrewor14 <[email protected]> * remove local test code * update PyTorch pre-release version indicator * update pre-release patterns
1 parent a9ffa50 commit 07fbc89

File tree

2 files changed

+27
-21
lines changed

2 files changed

+27
-21
lines changed

test/test_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
class TestTorchVersion(unittest.TestCase):
1717
def test_torch_version_at_least(self):
1818
test_cases = [
19-
("2.5.0a0+git9f17037", "2.5.0", True),
20-
("2.5.0a0+git9f17037", "2.4.0", True),
21-
("2.5.0.dev20240708+cu121", "2.5.0", True),
22-
("2.5.0.dev20240708+cu121", "2.4.0", True),
23-
("2.5.0", "2.4.0", True),
24-
("2.5.0", "2.5.0", True),
25-
("2.4.0", "2.4.0", True),
26-
("2.4.0", "2.5.0", False),
19+
("2.5.0a0+git9f17037", "2.5.0", False), # [2, 5, -1] < [2, 5, 0]
20+
("2.5.0a0+git9f17037", "2.4.0", True), # [2, 5, -1] > [2, 4, 0]
21+
("2.5.0.dev20240708+cu121", "2.5.0", False), # [2, 5, -1] < [2, 5, 0]
22+
("2.5.0.dev20240708+cu121", "2.4.0", True), # [2, 5, -1] > [2, 4, 0]
23+
("2.5.0", "2.4.0", True), # [2, 5, 0] > [2, 4, 0]
24+
("2.5.0", "2.5.0", True), # [2, 5, 0] >= [2, 5, 0]
25+
("2.4.0", "2.4.0", True), # [2, 4, 0] >= [2, 4, 0]
26+
("2.4.0", "2.5.0", False), # [2, 4, 0] < [2, 5, 0]
2727
]
2828

2929
for torch_version, compare_version, expected_result in test_cases:

torchao/utils.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -348,27 +348,33 @@ def _is_float8_type(dtype: torch.dtype) -> bool:
348348

349349

350350
def parse_version(version_string):
351-
# Extract just the X.Y.Z part from the version string
352-
match = re.match(r"(\d+\.\d+\.\d+)", version_string)
351+
"""
352+
Parse version string representing pre-release with -1
353+
354+
Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
355+
"""
356+
# Check for pre-release indicators
357+
is_prerelease = bool(re.search(r"(git|dev)", version_string))
358+
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
353359
if match:
354-
version = match.group(1)
355-
return [int(x) for x in version.split(".")]
360+
major, minor, patch = map(int, match.groups())
361+
if is_prerelease:
362+
patch = -1
363+
return [major, minor, patch]
356364
else:
357365
raise ValueError(f"Invalid version string format: {version_string}")
358366

359367

360-
def compare_versions(v1, v2):
361-
v1_parts = parse_version(v1)
362-
v2_parts = parse_version(v2)
363-
return (v1_parts > v2_parts) - (v1_parts < v2_parts)
364-
365-
366368
def is_fbcode():
367369
return not hasattr(torch.version, "git_version")
368370

369371

370372
def torch_version_at_least(min_version):
371-
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0
373+
if is_fbcode():
374+
return True
375+
376+
# Parser for local identifiers
377+
return parse_version(torch.__version__) >= parse_version(min_version)
372378

373379

374380
def _deprecated_torch_version_at_least(version_str: str) -> str:
@@ -1085,13 +1091,13 @@ def is_sm_at_least_100():
10851091
def check_cpu_version(device, version="2.6.0"):
10861092
if isinstance(device, torch.device):
10871093
device = device.type
1088-
return device == "cpu" and compare_versions(torch.__version__, version) >= 0
1094+
return device == "cpu" and torch_version_at_least(version)
10891095

10901096

10911097
def check_xpu_version(device, version="2.8.0"):
10921098
if isinstance(device, torch.device):
10931099
device = device.type
1094-
return device == "xpu" and compare_versions(torch.__version__, version) >= 0
1100+
return device == "xpu" and torch_version_at_least(version)
10951101

10961102

10971103
def ceil_div(a, b):

0 commit comments

Comments
 (0)