diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index 4298d17c9c..bad09bf32a 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -7,19 +7,9 @@
 import subprocess
 from pathlib import Path
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-import torch
-from packaging.version import Version as PkgVersion
-
-
-def get_torch_version():
-    """Get PyTorch version from __version__"""
+from transformer_engine.pytorch.utils import torch_version
 
-    def get_torch_version_str():
-        import torch
-
-        return str(torch.__version__)
-
-    return PkgVersion(get_torch_version_str())
+import torch
 
 
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -44,7 +34,7 @@ def _run_test(fp_init, sharding_dims):
 
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
 def test_distributed(fp8_init, sharding_dims):
diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py
index bb826e552e..b4631eb9a7 100644
--- a/transformer_engine/pytorch/ops/_common.py
+++ b/transformer_engine/pytorch/ops/_common.py
@@ -16,6 +16,7 @@
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
+    torch_version,
 )
 
 
@@ -98,8 +99,13 @@ def maybe_autocast_dtype(
     default_dtype: Optional[torch.dtype] = None,
 ) -> torch.dtype:
     """Get autocast dtype if enabled"""
-    if torch.is_autocast_enabled(device_type):
-        return torch.get_autocast_dtype(device_type)
+
+    if torch_version() >= (2, 4, 0):
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
+    else:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
     return canonicalize_dtype(default_dtype)
 
 
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 1922a7e867..4678097dc4 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -8,6 +8,7 @@
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
+from packaging.version import Version as PkgVersion
 
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -386,3 +387,9 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
 
     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
+
+
+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
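
The sketch below is a minimal, standalone illustration of the pattern this patch introduces: a cached release-tuple helper used to gate the device-aware autocast API added in PyTorch 2.4 against the older GPU-only query. The helper name autocast_dtype() and the torch.get_default_dtype() fallback are assumptions made for this example and are not part of the patch; only torch_version() mirrors the patched code.

# Illustrative usage sketch (not part of the patch); autocast_dtype() and the
# torch.get_default_dtype() fallback are assumptions for this example.
import functools

import torch
from packaging.version import Version as PkgVersion


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """Release tuple of the installed PyTorch, e.g. (2, 4, 1)."""
    return PkgVersion(str(torch.__version__)).release


def autocast_dtype(device_type: str = "cuda") -> torch.dtype:
    """Return the active autocast dtype, using the pre-2.4 API on older torch."""
    if torch_version() >= (2, 4, 0):
        # PyTorch 2.4+ exposes device-aware autocast queries.
        if torch.is_autocast_enabled(device_type):
            return torch.get_autocast_dtype(device_type)
    elif torch.is_autocast_enabled():
        # Older releases only expose the GPU-specific query.
        return torch.get_autocast_gpu_dtype()
    return torch.get_default_dtype()


if __name__ == "__main__":
    if torch.cuda.is_available():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            print(autocast_dtype("cuda"))  # torch.bfloat16 while autocast is active

Caching the release tuple with functools.lru_cache keeps repeated version checks cheap, and comparing plain integer tuples avoids parsing version strings at every call site.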