
Commit e122173

[PyTorch] Cache RHT device tensors properly (#2395)
* Cache device tensors properly
  Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Fix annotation and add test
  Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* skip nvfp4 test if not supported
  Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
---------
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent d677a26 commit e122173
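
Context for the change: the RHT helpers in nvfp4_tensor.py are memoized with functools.lru_cache, but the cache key previously did not include a device while the tensors were allocated with device="cuda", i.e. on whichever GPU was current when the cache was first populated. Any NVFP4Quantizer created afterwards, on any device, got that same cached tensor back. A minimal sketch of the pattern and of the fix, using hypothetical helper names rather than Transformer Engine code and assuming at least two visible GPUs:

import functools

import torch

@functools.lru_cache(maxsize=None)
def get_matrix_buggy() -> torch.Tensor:
    # "cuda" resolves to whichever device is current on the *first* call;
    # lru_cache then hands that same tensor to every later caller.
    return torch.eye(16, dtype=torch.bfloat16, device="cuda")

@functools.lru_cache(maxsize=None)
def get_matrix_fixed(device: int) -> torch.Tensor:
    # Making the device part of the cache key keeps one entry per GPU.
    return torch.eye(16, dtype=torch.bfloat16, device=device)

with torch.cuda.device(torch.cuda.device_count() - 1):
    _ = get_matrix_buggy()         # cache is now pinned to the last GPU

print(get_matrix_buggy().device)   # still the last GPU, whatever is current now
print(get_matrix_fixed(0).device)  # cuda:0
print(get_matrix_fixed(1).device)  # cuda:1 (separate cache entry)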

File tree

2 files changed: 50 additions & 16 deletions


tests/pytorch/distributed/test_sanity.py

Lines changed: 33 additions & 1 deletion
@@ -7,7 +7,16 @@
 import pytest
 import torch
 import transformer_engine
-from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear
+from transformer_engine.pytorch import (
+    DotProductAttention,
+    TransformerLayer,
+    Linear,
+    GroupedLinear,
+    NVFP4Quantizer,
+    autocast,
+    is_nvfp4_available,
+)
+from transformer_engine.common import recipe
 
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
@@ -17,6 +26,8 @@
     "small": ModelConfig(2, 10, 2, 16),
 }
 
+nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
+
 
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize(
@@ -138,3 +149,24 @@ def test_current_device(model, module):
     assert (
         tensor_device_grad == tensor_device
     ), "The gradient tensor should be the same as the input tensors!"
+
+
+@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4)
+def test_nvfp4_rht_cache():
+    """Ensure correct RHT cache for NVFP4."""
+
+    num_devices = torch.cuda.device_count()
+    assert num_devices > 1, "This test requires more than one GPU!"
+
+    # Populate cache on last device.
+    with torch.cuda.device(num_devices - 1):
+        _ = NVFP4Quantizer()
+
+    hidden_size = 128
+    dtype = torch.bfloat16
+
+    model = Linear(hidden_size, hidden_size, params_dtype=dtype)
+    inp = torch.randn(hidden_size, hidden_size, device=torch.cuda.current_device(), dtype=dtype)
+    fp4_recipe = recipe.NVFP4BlockScaling()
+    with autocast(recipe=fp4_recipe):
+        _ = model(inp)

transformer_engine/pytorch/tensor/nvfp4_tensor.py

Lines changed: 17 additions & 15 deletions
@@ -28,9 +28,9 @@
 aten = torch.ops.aten
 
 
-def get_no_random_sign_vector() -> torch.Tensor:
+def get_no_random_sign_vector(device: int) -> torch.Tensor:
     """Non-random sign vector for Hadamard transform."""
-    return torch.tensor([1], dtype=torch.float32, device="cuda")
+    return torch.tensor([1], dtype=torch.float32, device=device)
 
 
 def get_sign_from_vector(vector: torch.Tensor) -> int:
@@ -45,7 +45,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
     return mask.item()
 
 
-def get_wgrad_sign_vector() -> torch.Tensor:
+def get_wgrad_sign_vector(device: int) -> torch.Tensor:
     """Hard-coded random signs for Hadamard transform.
 
     https://xkcd.com/221/
@@ -54,11 +54,11 @@ def get_wgrad_sign_vector() -> torch.Tensor:
     return torch.tensor(
         [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
         dtype=torch.float32,
-        device="cuda",
+        device=device,
     )
 
 
-def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
+def get_hadamard_matrix(hadamard_dimension: int, device: int) -> torch.Tensor:
     """Construct a 16x16 Hadamard matrix."""
     assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
     hadamard_scale = 1 / math.sqrt(hadamard_dimension)
@@ -83,30 +83,30 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
                 [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
             ],
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         * hadamard_scale
     )
 
 
 @functools.lru_cache(maxsize=None)
-def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
+def get_rht_matrix(with_random_sign_mask: bool, device: int) -> torch.Tensor:
     """Construct matrix used in random Hadamard transform."""
     hadamard_dimension = 16
     if with_random_sign_mask:
-        signs = get_wgrad_sign_vector()
+        signs = get_wgrad_sign_vector(device=device)
     else:
-        signs = get_no_random_sign_vector()
-    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
-    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
+        signs = get_no_random_sign_vector(device=device)
+    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device=device)
+    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension, device=device)
     return rht_matrix.to(dtype=torch.bfloat16)
 
 
 @functools.lru_cache(maxsize=None)
-def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
+def get_random_sign_mask_for_rht(with_random_sign_mask: bool, device: int) -> int:
     """Sign mask for random Hadamard transform."""
     if with_random_sign_mask:
-        return get_sign_from_vector(get_wgrad_sign_vector())
+        return get_sign_from_vector(get_wgrad_sign_vector(device=device))
     return 0
 
 
@@ -152,8 +152,10 @@ def __init__(
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
-        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
-        self.rht_matrix = get_rht_matrix(with_random_sign_mask)
+        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
+            with_random_sign_mask, torch.cuda.current_device()
+        )
+        self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device())
 
     def update_quantized(
         self,
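
With device now part of the lru_cache key, each GPU gets its own cached RHT matrix and sign mask, and NVFP4Quantizer looks them up for torch.cuda.current_device(). A minimal sketch of the resulting behaviour, assuming at least two GPUs and that the module imports on your build (signature taken from the diff above):

import torch
from transformer_engine.pytorch.tensor.nvfp4_tensor import get_rht_matrix

m0 = get_rht_matrix(True, 0)           # built and cached on cuda:0
m1 = get_rht_matrix(True, 1)           # separate cache entry, built on cuda:1
assert m0.device.index == 0 and m1.device.index == 1
assert get_rht_matrix(True, 0) is m0   # cache hit for the same (flag, device) pair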

0 commit comments
