
Commit 05bfa3f

Authored by jaimec00, pre-commit-ci[bot], greptile-apps[bot], and ksivaman
[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag (#2311)
* custom tests for selective activation checkpointing for layernorm mlp
* add selective layernorm mlp to te.pytorch
* update test and fix SLNMLP bug
* implement slnmlp
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix tests pointed out by greptile app bot, still pass
* minor formatting change in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (Co-authored-by: greptile-apps[bot])
* remove duplicate import in test/pytorch/selective_layernorm_mlp/test_recipe.py
* clean up tests, remove unused imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove unused paths in test_deffered_init
* fix issue with zero_centered_gamma in test_numerics reference implementation
* clean up tests
* make comparison.py more extensive, cleaner output
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py (Co-authored-by: greptile-apps[bot])
* fix typo by grepbot in compare.py
* make selective activation checkpointing optional in slnmlp via checkpoint flag
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add comments to clarify logic
* add checkpoint param to pytests, change compare.py to compare checkpoint=False vs checkpoint=True, skip cuda graph tests for checkpoint=True
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* refactor tests to call modified LayerNormMLP
* refactor to implement selective activation checkpointing directly into LayerNormMLP, also fix bug to reach cleanup logic in fwd
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix skip explanation for cuda_graphs.py
* make _recompute deal with lists instead of tuples
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix MOST cuda graph failures by initializing identical quantizers during fwd; Float8CurrentScaling with bf16 and fp16 still fail with checkpointing
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix cuda graphs issue, all tests pass now
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix small logic bugs, clean up
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* integrate tests into main testing scripts
* incorporate rng state tracking in checkpointing
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* clean up tests
* fix return type mismatches
* remove checkpoint test from test_recipe, add separate test in test_numerics
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor typo fix (Co-authored-by: greptile-apps[bot])
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* clear up assertions in tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add license and copyright info
* fix lint issues in layernorm_mlp
* fix cpu_offload_v1 error
* possibly fix recomputation in cuda graph bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* skip cuda graphs test for SLNMLP with SM>=10.0 and using delayed scaling
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix typo for setting IS_FIRST_FP8_MODULE

---------

Signed-off-by: Jaime Cardenas <[email protected]>
Signed-off-by: Jaime <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 30c0120 commit 05bfa3f
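
In short, this commit adds a checkpoint flag to transformer_engine.pytorch.LayerNormMLP that enables selective activation checkpointing (intermediates are recomputed in the backward pass instead of being saved). A minimal usage sketch, assuming only what the diffs below show (the constructor takes the hidden sizes, params_dtype, and the new checkpoint argument; the sizes here are placeholders):

import torch
import transformer_engine.pytorch as te

# Placeholder sizes; checkpoint=True is the flag added by this commit.
mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    params_dtype=torch.bfloat16,
    checkpoint=True,  # recompute intermediates in backward instead of keeping them
).cuda()

x = torch.randn(2048, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = mlp(x)          # forward stores fewer activations when checkpoint=True
out.sum().backward()  # backward recomputes the dropped activations (slower, less memory)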

File tree

8 files changed: +624 −114 lines

tests/pytorch/distributed/run_numerics.py

Lines changed: 1 addition & 0 deletions
@@ -1030,6 +1030,7 @@ def test_layernorm_mlp():
         {"return_bias": True},
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
+        {"checkpoint": True},
     ]

     for kwargs in kwargs_list:
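
For context, a rough sketch of how run_numerics.py consumes this list (an assumption about the surrounding code, suggested only by the loop in the hunk above; the sizes and the loop body below are placeholders, not the actual test code):

kwargs_list = [
    {"return_bias": True},
    {"return_layernorm_output": True},
    {"delay_wgrad_compute": True},
    {"checkpoint": True},  # new configuration exercised by this commit
]

for kwargs in kwargs_list:
    # Hypothetical shape of the loop body: build one distributed LayerNormMLP per
    # configuration and compare its outputs/gradients against a reference run.
    model = LayerNormMLP(hidden_size=128, ffn_hidden_size=512, **kwargs)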

tests/pytorch/distributed/test_numerics.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 """
 Distributed numerics tests

-These tests test the numerical corectness of the TransformerEngine layers.
+These tests test the numerical correctness of the TransformerEngine layers.
 Tests are parametrized by the layer and fp8 precision.
 One test consists of running multiple configurations from file run_numerics.py
 Such design is due to the fact the initialization of one test is long
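
To make that design concrete, a simplified sketch of the pattern (the launcher command, flag names, and world size below are illustrative placeholders, not the actual test harness):

import subprocess
import pytest

LAYERS = ["linear", "layernorm_linear", "layernorm_mlp", "transformer"]

@pytest.mark.parametrize("layer", LAYERS)
def test_distributed_numerics(layer):
    # One pytest case spawns run_numerics.py once; that script then iterates over
    # many kwargs configurations, so the costly distributed setup is paid only once.
    subprocess.run(
        ["torchrun", "--nproc_per_node=2", "run_numerics.py", "--layer", layer],
        check=True,
    )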
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+from transformer_engine.pytorch import LayerNormMLP
+import pytest
+
+torch.manual_seed(1234)
+device = torch.device("cuda")
+
+
+class _Sequential(torch.nn.Sequential):
+    """Sequential model that forwards keyword arguments to modules"""
+
+    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = input_
+        for module in self:
+            x = module(x, **kwargs)
+        return x
+
+
+class ModelConfig:
+    def __init__(
+        self,
+        hidden_size: int = 128,
+        ffn_hidden_size: int = 512,
+        layers: int = 1,
+    ):
+        self._hidden_size = hidden_size
+        self._ffn_hidden_size = ffn_hidden_size
+        self._layers = layers
+
+    def build(self):
+
+        ln_list, sln_list = [], []
+        for _ in range(self._layers):
+            ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device)
+            sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device)
+            with torch.no_grad():
+                sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone())
+                sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone())
+                sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone())
+                sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone())
+                sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone())
+                sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone())
+            ln_list.append(ln)
+            sln_list.append(sln)
+
+        ln_model = _Sequential(*ln_list)
+        sln_model = _Sequential(*sln_list)
+
+        return ln_model, sln_model
+
+
+config = {
+    "small": ModelConfig(128, 512, 12),
+    "medium": ModelConfig(512, 2048, 12),
+    "large": ModelConfig(1024, 4096, 12),
+    "huge": ModelConfig(2048, 8192, 12),
+}
+
+seq_sizes = [2**7, 2**10, 2**14, 2**16]
+
+
+def _warmup(model, tensor):
+    for _ in range(3):
+        model(tensor).sum().backward()
+
+
+def _run_fwd(model, tensor):
+
+    torch.cuda.reset_peak_memory_stats(device)
+    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
+        enable_timing=True
+    )
+
+    torch.cuda.synchronize()
+    start_mem = torch.cuda.memory_allocated(device)
+    start_time.record()
+    out = model(tensor)
+    end_time.record()
+    end_time.synchronize()
+    elapsed = start_time.elapsed_time(end_time)
+    peak_mem = torch.cuda.max_memory_allocated(device)
+    mem = float(peak_mem - start_mem)
+
+    return out, elapsed, mem
+
+
+def _run_bwd(model, out):
+
+    model.zero_grad(set_to_none=False)
+    loss = out.sum()
+
+    torch.cuda.reset_peak_memory_stats(device)
+    start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
+        enable_timing=True
+    )
+
+    torch.cuda.synchronize()
+    start_mem = torch.cuda.memory_allocated(device)
+    start_time.record()
+    loss.backward()
+    end_time.record()
+    end_time.synchronize()
+    elapsed = start_time.elapsed_time(end_time)
+    peak_mem = torch.cuda.max_memory_allocated(device)
+    mem = float(peak_mem - start_mem)
+
+    param_grads = _collect_param_grads(model)
+    return param_grads, elapsed, mem
+
+
+def _max_diff(ref, other):
+    """Return max absolute difference between two tensors or collections."""
+    if ref is None or other is None:
+        return 0.0
+    if isinstance(ref, (list, tuple)):
+        diffs = [_max_diff(r, o) for r, o in zip(ref, other)]
+        return max(diffs) if diffs else 0.0
+    return torch.max(torch.abs(ref.detach() - other.detach())).item()
+
+
+def _collect_param_grads(model):
+    grads = {}
+    for name, param in model.named_parameters():
+        if param.grad is None:
+            continue
+        key = _param_key(name)
+        if key is not None:
+            grads[key] = param.grad.detach().clone()
+    return grads
+
+
+def _param_key(name):
+    return name.split(".")[-1]
+
+
+@pytest.mark.parametrize("size", config.keys())
+@pytest.mark.parametrize("seq_size", seq_sizes)
+def test_selective_activation_checkpoint(size, seq_size):
+
+    ln_model, sln_model = config[size].build()
+    data = torch.randn((seq_size, config[size]._hidden_size), device=device)
+
+    _warmup(ln_model, data)
+    ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data)
+    ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out)
+
+    _warmup(sln_model, data)
+    sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data)
+    sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out)
+
+    assert ln_fwd_mem > 6 * sln_fwd_mem, (
+        "selective activation checkpointing does not reduce forward memory by 6X, only by"
+        f" {ln_fwd_mem/sln_fwd_mem}!"
+    )
+    assert ln_bwd_time < sln_bwd_time, (
+        "selective activation checkpointing backward pass is NOT slower than native!"
+        f" got Native LayerNormMLP Backward Time: {ln_bwd_time} ms and Selective Activation"
+        f" Checkpointed LayerNormMLP Backward Time: {sln_bwd_time} ms"
+    )
+    diff = _max_diff(ln_fwd_out, sln_fwd_out)
+    assert diff == 0.0, f"outputs are not equal! maximum difference {diff}"
+    for key in [
+        "layer_norm_weight",
+        "layer_norm_bias",
+        "fc1_weight",
+        "fc1_bias",
+        "fc2_weight",
+        "fc2_bias",
+    ]:
+        diff = _max_diff(ln_grads[key], sln_grads[key])
+        assert diff == 0.0, f"gradients for {key} are not equal! maximum difference: {diff}"

tests/pytorch/test_cuda_graphs.py

Lines changed: 27 additions & 3 deletions
@@ -190,7 +190,8 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
     # creating TMA descriptor for MXFP8 quantization.
     "linear",
     "transformer",
-    "layernorm_mlp",
+    "layernorm_mlp_nocheckpoint",
+    "layernorm_mlp_checkpoint",
     "layernorm_linear",
     "mha",
     "linear_op",
@@ -232,12 +233,23 @@ def _test_cuda_graphs(
         )
         for _ in range(num_layers)
     ]
-    elif module == "layernorm_mlp":
+    elif module == "layernorm_mlp_nocheckpoint":
        modules = [
            LayerNormMLP(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
+                checkpoint=False,
+            )
+            for _ in range(num_layers)
+        ]
+    elif module == "layernorm_mlp_checkpoint":
+        modules = [
+            LayerNormMLP(
+                model_config.hidden_size,
+                model_config.hidden_size,
+                params_dtype=dtype,
+                checkpoint=True,
            )
            for _ in range(num_layers)
        ]
@@ -376,6 +388,17 @@ def test_make_graphed_callables(
     )
     if fp8_params:
         pytest.skip("NVFP4 params not supported")
+    if (
+        fp8
+        and fp8_recipe.delayed()
+        and torch.cuda.get_device_capability() >= (10, 0)
+        and module == "layernorm_mlp_checkpoint"
+    ):
+        pytest.skip(
+            "CUDA graphs not supported for LayerNormMLP "
+            "with checkpoint=True, SM>=10, "
+            "and DelayedScaling recipe"
+        )

     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
@@ -402,7 +425,8 @@

 _test_make_graphed_callables_with_fp8_weight_caching_modules = [
     "transformer",
-    "layernorm_mlp",
+    "layernorm_mlp_nocheckpoint",
+    "layernorm_mlp_checkpoint",
     "layernorm_linear",
     "linear",
     "mha",

tests/pytorch/test_numerics.py

Lines changed: 59 additions & 3 deletions
@@ -185,7 +185,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
         return dict(rtol=1e-3, atol=1e-5)
     if dtype == torch.bfloat16:
         return dict(rtol=1.6e-2, atol=1e-5)
-    raise ValueError(f"Unsuppored dtype ({dtype})")
+    raise ValueError(f"Unsupported dtype ({dtype})")


 def assert_allclose(
@@ -1363,7 +1363,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
     te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

-    # Shoule be bit-wise match
+    # Should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

@@ -1696,7 +1696,11 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, bias, fuse_wgrad_accumulation
+    dtype,
+    bs,
+    model,
+    bias,
+    fuse_wgrad_accumulation,
 ):
     config = model_configs[model]

@@ -1747,6 +1751,58 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", [2])
+@pytest.mark.parametrize("model", ["small"])
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy_checkpoint(
+    dtype,
+    bs,
+    model,
+    bias,
+):
+    config = model_configs[model]
+
+    ln_mlp = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=True,
+    ).eval()
+
+    ln_mlp_ref = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=False,
+    ).eval()
+
+    # Share params
+    with torch.no_grad():
+        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+        if bias:
+            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+
+    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
+    te_outputs_ref = _test_granular_accuracy(
+        ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
+    )
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
 def _test_grouped_linear_accuracy(
     block,
     num_gemms,

tests/pytorch/test_recipe.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@
 )
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
-import transformer_engine_torch as tex

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

tests/pytorch/test_sanity.py

Lines changed: 3 additions & 0 deletions
@@ -525,6 +525,7 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("microbatching", all_boolean)
+@pytest.mark.parametrize("checkpoint", all_boolean)
 def test_sanity_layernorm_mlp(
     dtype,
     fp8_recipe,
@@ -535,6 +536,7 @@ def test_sanity_layernorm_mlp(
     activation,
     normalization,
     microbatching,
+    checkpoint,
 ):
     config = model_configs[model]

@@ -559,6 +561,7 @@
         normalization=normalization,
         params_dtype=dtype,
         device="cuda",
+        checkpoint=checkpoint,
     )
     _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
