Fix TE ops API compatibility with PyTorch versions < 2.4.3 (#1494)
* Fix TE Sequential for older PyTorch versions

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Feb 20, 2025
1 parent fceff07 commit b612cde
Showing 3 changed files with 18 additions and 15 deletions.
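For context, the incompatibility this commit guards against is the device-generic autocast query API: newer PyTorch accepts a device_type argument in torch.is_autocast_enabled and exposes torch.get_autocast_dtype, while older releases only provide the argument-less check and the CUDA-specific torch.get_autocast_gpu_dtype. A minimal sketch of the two call styles (illustrative only, not part of the commit; the helper name is hypothetical):

import torch

def _autocast_dtype_sketch(device_type: str = "cuda") -> torch.dtype:
    """Illustrative only: resolve the active autocast dtype on old and new PyTorch."""
    if hasattr(torch, "get_autocast_dtype"):
        # Newer PyTorch: device-generic queries take a device_type string.
        if torch.is_autocast_enabled(device_type):
            return torch.get_autocast_dtype(device_type)
    elif torch.is_autocast_enabled():
        # Older PyTorch: argument-less check plus the CUDA-specific getter.
        return torch.get_autocast_gpu_dtype()
    return torch.get_default_dtype()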
16 changes: 3 additions & 13 deletions tests/pytorch/distributed/test_torch_fsdp2.py
@@ -7,19 +7,9 @@
 import subprocess
 from pathlib import Path
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-import torch
-from packaging.version import Version as PkgVersion
-
-
-def get_torch_version():
-    """Get PyTorch version from __version__"""
-
-    def get_torch_version_str():
-        import torch
-
-        return str(torch.__version__)
-
-    return PkgVersion(get_torch_version_str())
+from transformer_engine.pytorch.utils import torch_version
+
+import torch


 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -44,7 +34,7 @@ def _run_test(fp_init, sharding_dims):

 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
 def test_distributed(fp8_init, sharding_dims):
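For reference, the torch_version() helper (added to transformer_engine/pytorch/utils.py below) returns the installed PyTorch release as a tuple, so the skipif gate above is a plain tuple comparison. A small illustration (printed values depend on the installed PyTorch):

from transformer_engine.pytorch.utils import torch_version

print(torch_version())               # e.g. (2, 5, 1) on a PyTorch 2.5.1 install
print(torch_version() >= (2, 4, 0))  # True on 2.4.0 and newer, False otherwise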
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/ops/_common.py
@@ -16,6 +16,7 @@
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
+    torch_version,
 )


@@ -98,8 +99,13 @@ def maybe_autocast_dtype(
     default_dtype: Optional[torch.dtype] = None,
 ) -> torch.dtype:
     """Get autocast dtype if enabled"""
-    if torch.is_autocast_enabled(device_type):
-        return torch.get_autocast_dtype(device_type)
+
+    if torch_version() >= (2, 4, 3):
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
+    else:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
     return canonicalize_dtype(default_dtype)


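A quick usage sketch of the updated maybe_autocast_dtype (illustrative; assumes a CUDA-enabled PyTorch build and keyword arguments as shown, since the full signature is outside this hunk):

import torch
from transformer_engine.pytorch.ops._common import maybe_autocast_dtype

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Inside an autocast region the autocast dtype wins on both code paths.
    print(maybe_autocast_dtype(device_type="cuda"))  # torch.bfloat16
# Outside autocast the explicit default is returned instead.
print(maybe_autocast_dtype(device_type="cuda", default_dtype=torch.float32))  # torch.float32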
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -8,6 +8,7 @@
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
+from packaging.version import Version as PkgVersion

 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -386,3 +387,9 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
+
+
+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
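Because the helper returns Version(...).release, pre-release and local-build suffixes common in PyTorch wheels and containers are stripped before comparison (the version strings below are examples):

from packaging.version import Version as PkgVersion

print(PkgVersion("2.6.0+cu124").release)                    # (2, 6, 0)
print(PkgVersion("2.4.0a0+somehash").release)               # (2, 4, 0)
print(PkgVersion("2.4.0a0+somehash").release >= (2, 4, 3))  # False -> pre-2.4.3 code path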
