[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag #2311
Conversation
Greptile Overview
Greptile Summary
This review covers the entire PR, which introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The core implementation trades compute for memory by avoiding saving intermediate activations (ln_out, fc1_out, gelu_out) during forward passes and instead recomputing them during backward passes—crucially skipping fc2 recomputation since it's not needed for gradients. The module integrates with TransformerEngine's existing infrastructure (sequence_parallel, tensor_parallel, FP8 quantization) while explicitly disabling CPU offloading. The benchmark data shows dramatic memory reduction (64GB → 10GB forward memory in the example) at the cost of 20% slower backward passes. The change maintains full compatibility with existing LayerNormMLP configurations and is implemented as an additional module rather than modifying existing code.
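For intuition about where the savings described above come from, here is a rough back-of-the-envelope sketch (not taken from the PR; bf16 activations, ffn_hidden = 4 × hidden, and no tensor/sequence parallelism are assumptions) of the per-layer activation memory kept for backward:

```python
# Rough per-layer activation memory kept for backward (illustrative assumptions:
# bf16 activations, ffn_hidden = 4 * hidden, no tensor/sequence parallelism).
def activation_bytes(batch, seq, hidden, dtype_bytes=2, ffn_mult=4):
    tokens = batch * seq
    ln_out = tokens * hidden * dtype_bytes                 # saved by the standard path
    fc1_out = tokens * hidden * ffn_mult * dtype_bytes     # saved by the standard path
    gelu_out = tokens * hidden * ffn_mult * dtype_bytes    # saved by the standard path
    standard = ln_out + fc1_out + gelu_out
    selective = tokens * hidden * dtype_bytes              # only the layer input is kept
    return standard, selective

std, sel = activation_bytes(batch=8, seq=4096, hidden=8192)
print(f"standard: {std / 2**30:.1f} GiB/layer, selective: {sel / 2**30:.1f} GiB/layer")
```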
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 3/5 | Implements the new SelectiveLayerNormMLP module with selective checkpointing; contains commented-out FSDP/offloading code that needs verification |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 2/5 | Adds deferred init tests but has a critical bug: seq_length kwarg is not set for SelectiveLayerNormMLP (line 34 condition only checks LayerNormMLP) |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 4/5 | Distributed test file that creates LayerNormMLP instances instead of SelectiveLayerNormMLP (lines 403-404), so doesn't actually test the new feature |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Adds CUDA graph tests with proper module instantiation and parametrization; PR description notes tests fail with 4e-4 errors |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Minor formatting changes only: function signature reformatted and blank lines removed for style consistency |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Comprehensive sanity tests for the new module with minor formatting adjustments to match project style |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Formatting cleanup: removed duplicate imports and adjusted parameter list formatting |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Benchmark script comparing LayerNormMLP vs SelectiveLayerNormMLP; formatting changes only (PEP8 compliance) |
Confidence score: 2/5
- This PR requires careful review due to critical bugs in test files and incomplete CUDA graph compatibility
- Score reflects four major issues: (1) test_deferred_init.py will fail due to missing seq_length parameter handling, (2) distributed/run_numerics.py doesn't actually test the new module, (3) CUDA graph tests are failing per the PR description, and (4) extensive commented-out code in the main implementation suggests incomplete feature development, particularly around FSDP and CPU offloading
- Pay close attention to transformer_engine/pytorch/module/selective_layernorm_mlp.py (commented sections, lines 550-571, 915-929, 1545-1552), tests/pytorch/selective_layernorm_mlp/test_deferred_init.py (line 34 conditional logic), and tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (lines 403-404 model instantiation)
8 files reviewed, 2 comments
Greptile Overview
Greptile Summary
This review covers only the formatting and style changes made since the last review, not the entire PR. The changes consist entirely of whitespace adjustments: removing extraneous blank lines from seven test files (test_sanity.py, test_cuda_graphs.py, compare.py, test_recipe.py, test_numerics.py, and distributed/run_numerics.py) and the end-of-file blank line. These modifications align the new test suite with the project's formatting standards enforced by pre-commit hooks (black formatter, end-of-file-fixer, trailing-whitespace checker) and the .clang-format configuration that specifies 100-character column limits, 2-space indentation for C++, and strict whitespace rules. Since SelectiveLayerNormMLP is a memory-optimized variant of LayerNormMLP with identical API and forward-pass behavior, the test files are appropriately adapted copies of the existing LayerNormMLP test infrastructure, ensuring comprehensive coverage of FP8 recipes, distributed configurations, deferred initialization, and numerical accuracy without duplicating test logic.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Removed two extraneous blank lines (lines 131,161) for formatting consistency |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 5/5 | Removed trailing blank line at end of file |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Removed leading blank line and reformatted long print statement to fit 100-char limit |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 4/5 | Contains duplicate import of transformer_engine_torch (lines 22 and 29) |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | No visible formatting changes in this diff |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Removed three blank lines (lines 202, 273, 531) and reformatted function signature (lines 406-408) |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Removed three blank lines (lines 33, 383, 491) for consistent spacing |
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 5/5 | No whitespace-related changes; core implementation is unchanged |
Confidence score: 4/5
- This PR is generally safe to merge with respect to the formatting changes; however, the test_cuda_graphs.py failure and the duplicate import in test_recipe.py require attention before merging.
- Score reflects that the formatting changes are correct and align with project standards, but the duplicate import (test_recipe.py lines 22 and 29) should be removed, and the test_cuda_graphs.py failures mentioned in the PR description remain unresolved and may indicate numerical instability or a CUDA graph compatibility issue that requires investigation.
- Pay close attention to tests/pytorch/selective_layernorm_mlp/test_recipe.py (duplicate import) and confirm that test_cuda_graphs.py failures are understood and either resolved or deemed acceptable before merging.
8 files reviewed, 2 comments
Greptile Overview
Greptile Summary
This review covers only the most recent docstring correction in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (line 391), updating the reference from "LayerNormMLP" to "SelectiveLayerNormMLP". This is a trivial documentation fix that completes the refactoring of existing tests to validate the new SelectiveLayerNormMLP module introduced by this PR. All substantive issues (duplicate import, missing condition logic, extra blank lines) were already flagged in earlier reviews and are not repeated here.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Corrected docstring to reference SelectiveLayerNormMLP instead of LayerNormMLP (line 391). |
Confidence score: 5/5
- This single‑line docstring fix is safe to merge with no risk; it improves documentation accuracy.
- Score remains perfect because the change is purely cosmetic—no logic, imports, or functionality is affected.
- No files require additional attention; earlier reviews have already identified all substantive issues in this test file.
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. A duplicate import of transformer_engine_torch as tex has been removed from the test file tests/pytorch/selective_layernorm_mlp/test_recipe.py, cleaning up the code. However, the file still contains the same import on line 22, indicating either an incomplete change or a subsequent revert not captured in the diff.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Removed duplicate import statement, no functional impact on test logic |
Confidence score: 5/5
- This change is safe to merge with minimal risk
- The duplicate import removal is a simple cleanup with no functional impact; the test logic remains identical and validation of FP8 quantizer behavior is unchanged
- No files require special attention
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent updates focus on cleaning up test files for the new SelectiveLayerNormMLP module by removing unused imports and fixing minor formatting issues. The changes address previous review feedback by eliminating duplicate imports (like transformer_engine_torch as tex) and removing extraneous blank lines. However, several issues remain unresolved from earlier reviews, including the incomplete fix in test_deferred_init.py where the conditional logic in get_module_args still references unused module types (te.Linear, te.LayerNormLinear, te.MultiheadAttention, te.TransformerLayer), creating unreachable dead code. The CUDA graph test failures mentioned in the PR description (errors around 4e-4 and 0.1 absolute errors) persist and are acknowledged as incomplete work.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Removed numerous unused imports while preserving FP8 recipe test functionality |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Removed blank line but left dead code in get_module_args referencing unused module types |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 3/5 | Simplified to test only SelectiveLayerNormMLP; tests currently failing per PR description |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 4/5 | Cleaned up imports for focused SelectiveLayerNormMLP testing with comprehensive parameterization |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 4/5 | Removed unused imports and functions; minor formatting inconsistency remains |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Removed unused imports with no functional changes |
Confidence score: 3/5
- This PR requires careful review due to incomplete CUDA graph functionality and remaining code quality issues.
- Score reflects acknowledged test failures (test_cuda_graphs.py), dead code in helper methods (test_deferred_init.py, lines 25-43), and minor formatting inconsistencies. The changes are primarily test infrastructure cleanup, but the core feature is not yet production-ready.
- Pay close attention to test_cuda_graphs.py (numerical precision failures) and test_deferred_init.py (unreachable conditional branches for te.Linear, te.MultiheadAttention, and te.TransformerLayer).
6 files reviewed, 3 comments
Force-pushed from d09f38b to b601d83
Greptile Overview
Greptile Summary
This PR introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The forward pass saves only input tensors and weights rather than intermediate activations (fc1_out, gelu_out), reducing memory consumption by ~80% during training. During backward, FC1 and activation are recomputed, but FC2 recomputation is skipped since it's unnecessary for gradient calculation. The implementation maintains compatibility with TransformerEngine's quantization systems (FP8, MXFP8, NVFP4) and distributed features (tensor parallel, sequence parallel), though CPU offloading is explicitly disabled. The module integrates into the existing PyTorch API alongside LayerNormMLP and includes comprehensive test coverage mirroring the standard module's test suite. CUDA graph tests currently fail with ~4e-4 numerical differences, indicating potential interactions between graph capture and the selective recomputation strategy.
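As a short sanity check on why FC2 does not need to be recomputed, write the block as $\hat{x} = \mathrm{LN}(x)$, $h = W_1 \hat{x}$, $a = \mathrm{act}(h)$, $y = W_2 a$ (biases and parallelism omitted for this sketch); the gradients that backward needs are:

$$
\frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial y}\, a^{\top}, \qquad
\frac{\partial L}{\partial a} = W_2^{\top} \frac{\partial L}{\partial y}, \qquad
\frac{\partial L}{\partial h} = \frac{\partial L}{\partial a} \odot \mathrm{act}'(h), \qquad
\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial h}\, \hat{x}^{\top}
$$

None of these terms involve $y$ itself, so only $\hat{x}$, $h$, and $a$ have to be rebuilt during recomputation, which is exactly the FC2 skip the summary describes.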
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 4/5 | New 1000+ line module implementing selective checkpointing with bifurcated forward logic for training vs. recomputation |
| transformer_engine/pytorch/module/__init__.py | 5/5 | Adds SelectiveLayerNormMLP import to module's public API |
| transformer_engine/pytorch/__init__.py | 5/5 | Exposes SelectiveLayerNormMLP in top-level PyTorch API |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 1/5 | Critical bug: reference implementations initialize normalization weights to zeros instead of ones when zero_centered_gamma=False |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 4/5 | Missing import warnings causes NameError when zero-tensors are detected; otherwise sound distributed validation |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 4/5 | CUDA graph tests currently failing per PR description; unused import and incorrect return type annotation present |
| tests/pytorch/selective_layernorm_mlp/compare.py | 3/5 | Performance comparison script with reversed weight-copy direction (copies from SLN to LN instead of vice versa) |
| tests/pytorch/selective_layernorm_mlp/utils.py | 4/5 | Test utilities with incorrect return type annotation (declares 2-tuple but returns 3 elements) |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Comprehensive sanity tests covering dtypes, recipes, activations, normalizations, and microbatching |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Deferred initialization test with dead code for untested modules |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 4/5 | FP8 recipe validation test with unused capability-check imports |
| tests/pytorch/selective_layernorm_mlp/distributed/test_numerics.py | 4/5 | Distributed test wrapper with typo in docstring and unused variable |
Confidence score: 2/5
- This PR cannot be merged safely due to critical bugs in the test reference implementations and missing imports that will cause runtime failures.
- Score reflects: (1) test_numerics.py initializes reference normalization weights to zeros instead of ones, making all tests invalid; (2) distributed/run_numerics.py uses warnings.warn() without importing warnings; (3) CUDA graph tests are explicitly failing per the PR description; (4) compare.py has a reversed weight-copy direction; (5) multiple type annotation mismatches will cause type-checking failures.
- Pay close attention to tests/pytorch/selective_layernorm_mlp/test_numerics.py (lines 144-145, 170-171), distributed/run_numerics.py (line 34), and the core module's commented-out FSDP code, which may indicate incomplete distributed functionality.
Sequence Diagram
sequenceDiagram
participant User
participant SelectiveLayerNormMLP
participant _SelectiveLayerNormMLP
participant ForwardPass
participant BackwardPass
participant Quantizers
participant GEMM
User->>SelectiveLayerNormMLP: forward(inp, is_first_microbatch)
SelectiveLayerNormMLP->>SelectiveLayerNormMLP: prepare_forward()
SelectiveLayerNormMLP->>SelectiveLayerNormMLP: _get_quantizers()
SelectiveLayerNormMLP->>Quantizers: Initialize quantizers
Quantizers-->>SelectiveLayerNormMLP: Return quantizers
SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=False)
Note over _SelectiveLayerNormMLP: Save tensors for backward (inp, weights, etc.)
_SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
alt sequence_parallel
_SelectiveLayerNormMLP->>ForwardPass: gather_along_first_dim()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out_total
end
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
Note over GEMM: FC1 GEMM with optional gelu fusion
GEMM-->>_SelectiveLayerNormMLP: fc1_out
_SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
ForwardPass-->>_SelectiveLayerNormMLP: act_out
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc2_weight, act_out)
Note over GEMM: FC2 GEMM
GEMM-->>_SelectiveLayerNormMLP: fc2_out
alt sequence_parallel
_SelectiveLayerNormMLP->>ForwardPass: reduce_scatter_along_first_dim()
ForwardPass-->>_SelectiveLayerNormMLP: fc2_out
end
_SelectiveLayerNormMLP-->>SelectiveLayerNormMLP: fc2_out
SelectiveLayerNormMLP-->>User: Return output
User->>BackwardPass: loss.backward()
BackwardPass->>_SelectiveLayerNormMLP: _recompute(ctx)
Note over _SelectiveLayerNormMLP: Restore saved tensors
_SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=True)
Note over _SelectiveLayerNormMLP: Recompute activations WITHOUT saving fc2
_SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
_SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
GEMM-->>_SelectiveLayerNormMLP: fc1_out
_SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
ForwardPass-->>_SelectiveLayerNormMLP: act_out
Note over _SelectiveLayerNormMLP: Skip FC2 forward (not needed for backward)
_SelectiveLayerNormMLP->>BackwardPass: backward(grad_outputs)
BackwardPass->>GEMM: general_gemm(fc2_weight, grad_output)
Note over GEMM: FC2 DGRAD
GEMM-->>BackwardPass: fc2_dgrad
BackwardPass->>ForwardPass: activation_func_bwd(fc2_dgrad, fc1_out)
ForwardPass-->>BackwardPass: dact
alt fc2_weight_requires_grad
BackwardPass->>GEMM: general_gemm(act_out, grad_output)
Note over GEMM: FC2 WGRAD
GEMM-->>BackwardPass: fc2_wgrad, fc2_bias_grad
end
BackwardPass->>GEMM: general_gemm(fc1_weight, dact)
Note over GEMM: FC1 DGRAD
GEMM-->>BackwardPass: fc1_dgrad
alt fc1_weight_requires_grad
BackwardPass->>GEMM: general_gemm(ln_out_total, dact)
Note over GEMM: FC1 WGRAD
GEMM-->>BackwardPass: fc1_wgrad, fc1_bias_grad
end
BackwardPass->>ForwardPass: layernorm_bwd/rmsnorm_bwd()
ForwardPass-->>BackwardPass: dgrad, dgamma, dbeta
BackwardPass-->>User: Return gradients
12 files reviewed, 13 comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer removed redundant nested if-condition logic in the deferred initialization test file for SelectiveLayerNormMLP. Previously, the code checked if module == te.SelectiveLayerNormMLP twice: once at line 29 and again at line 33 within the same block. This review captures the cleanup of that duplication, streamlining the conditional logic. The test file verifies that SelectiveLayerNormMLP can be instantiated on a 'meta' device with zero memory allocation and then properly materialized via reset_parameters().
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | Removed redundant nested if-check for SelectiveLayerNormMLP module condition |
Confidence score: 5/5
- This PR is safe to merge with minimal risk
- Score reflects straightforward refactoring with no functional changes or edge cases introduced
- No files require special attention
1 file reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest commit fixes a critical bug in the test reference implementations (TorchLayerNorm and TorchRMSNorm) within tests/pytorch/selective_layernorm_mlp/test_numerics.py. Previously, the weight initialization logic was inverted: when zero_centered_gamma=False, the gamma parameter was incorrectly initialized to zeros instead of ones, which would cause the normalized output to be scaled incorrectly. The fix swaps the conditional branches so that zero_centered_gamma=False correctly initializes weights to ones (standard LayerNorm behavior), and zero_centered_gamma=True initializes to zeros (zero-centered variant). This correction ensures that the torch-based reference models produce the expected numerical outputs, allowing accurate validation of the SelectiveLayerNormMLP module against known-correct baselines. This change is isolated to the test infrastructure and does not affect the production implementation of SelectiveLayerNormMLP.
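A minimal sketch of the corrected convention described above (an illustrative reference module, not the PR's test code; the class name and forward details are assumptions):

```python
import torch

class TorchRMSNormRef(torch.nn.Module):
    """Illustrative reference RMSNorm; the point is the gamma initialization."""

    def __init__(self, hidden_size: int, eps: float = 1e-5, zero_centered_gamma: bool = False):
        super().__init__()
        self.eps = eps
        self.zero_centered_gamma = zero_centered_gamma
        # Fixed logic: the standard parameterization starts gamma at 1; the zero-centered
        # variant stores (gamma - 1) and therefore initializes the parameter to 0.
        init = torch.zeros if zero_centered_gamma else torch.ones
        self.weight = torch.nn.Parameter(init(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gamma = self.weight + 1 if self.zero_centered_gamma else self.weight
        rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * rms).to(x.dtype) * gamma
```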
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Fixed inverted weight initialization logic in TorchLayerNorm and TorchRMSNorm reference implementations |
Confidence score: 5/5
- This change is safe to merge—it fixes a clear bug in test reference code with no risk to production paths
- The fix corrects an inverted conditional that was causing reference models to initialize gamma weights incorrectly, ensuring accurate numerical validation going forward
- No files require special attention—the change is a straightforward two-line correction in test utilities
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has addressed previous feedback by fixing the parameter-setting logic in test_deferred_init.py, correcting the weight-copy direction in compare.py, and refining timing/measurement code. However, several critical issues remain: a type annotation bug in test_cuda_graphs.py (get_nvfp4_inp_supported_dtypes returns List[torch.dtype] but is annotated as bool), unused imports (warnings in run_numerics.py, time in compare.py), and a typo in the PR description (ffn_fidden on line 261 of compare.py). The test files validate that SelectiveLayerNormMLP maintains numerical parity with LayerNormMLP across distributed and non-distributed configurations while significantly reducing memory usage by recomputing activations instead of caching them. The PR description notes that test_cuda_graphs.py is failing with numerical errors (typically 4e-4, occasionally 0.1 absolute), which correlates with the type annotation bug in that file.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 2/5 | Critical type annotation bug on line 71: function returns List[torch.dtype] but annotated as bool, causing type-checking failures and possibly contributing to test failures |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 4/5 | Adds comprehensive numerical validation comparing SelectiveLayerNormMLP against torch reference implementations across dtypes, activations, and FP8/NVFP4 recipes |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 3/5 | Validates distributed numerics for tensor/sequence parallelism; unused warnings import added but comment updated correctly |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 4/5 | Simplified parameter-setting logic by removing redundant conditional check; correct for single-module test file |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 4/5 | Standard sanity tests adapted for SelectiveLayerNormMLP; comprehensive coverage of configurations and FP8 recipes |
| tests/pytorch/selective_layernorm_mlp/compare.py | 3/5 | Benchmark script refactored to use CUDA events and per-config stats; unused time import remains and typo ffn_fidden on line 261 |
Confidence score: 3/5
- This PR requires careful attention due to a critical type annotation bug and known test failures (cuda_graphs), though the core numerical validation appears sound
- Score reflects: (1) a critical type annotation error in test_cuda_graphs.py that breaks type checking and may contribute to the reported test failures; (2) unused imports in two files suggesting incomplete cleanup; (3) the PR description acknowledges cuda_graphs tests are failing with numerical errors, indicating incomplete functionality
- Pay close attention to test_cuda_graphs.py (line 71 type annotation) and investigate why CUDA graph tests fail with 4e-4 errors; the annotation bug may be masking logic errors in the supported-dtypes check
6 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent changes apply code formatting to tests/pytorch/selective_layernorm_mlp/compare.py, reformatting a multi-line f-string descriptor to comply with line-length and style guidelines. While spacing, operator placement, and list formatting improvements are beneficial, a critical syntax error was introduced: the descriptor string is now missing its closing parenthesis after the f-string definition, preventing the script from running. No other changes were made to imports, logic, or functionality in this update.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/compare.py | 0/5 | Applied code formatting but introduced syntax error: missing closing parenthesis after multi-line f-string definition (line 261) |
Confidence score: 0/5
- This PR will fail to run due to a syntax error that prevents the comparison script from executing
- Score is zero because a closing parenthesis is missing after the f-string definition on line 261, causing a guaranteed SyntaxError when the script is imported or executed
- The file tests/pytorch/selective_layernorm_mlp/compare.py requires immediate attention to add the missing closing parenthesis after line 262
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This update addresses a single-character syntax fix in the benchmark comparison script for SelectiveLayerNormMLP. The change adds a missing closing parenthesis on line 262 of tests/pytorch/selective_layernorm_mlp/compare.py that completes the multi-line f-string definition for the desc variable. This variable formats configuration metadata (sequence length, hidden size, FFN hidden size, and number of layers) that labels benchmark results comparing LayerNormMLP and SelectiveLayerNormMLP. Without this closing parenthesis, the script would raise a SyntaxError and fail to execute. The fix ensures the benchmarking tool can run properly and display properly formatted configuration descriptions in its output.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Added missing closing parenthesis to complete multi-line f-string definition |
Confidence score: 5/5
- This PR is safe to merge with minimal risk
- The change fixes a syntax error that would have prevented the script from running, and introduces no new logic or behavioral changes
- No files require special attention; this is a trivial syntax correction
1 file reviewed, no comments
Hi @jaimec00, thank you for the contribution!

Thanks for the comment, @ptrendx!
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The most recent update adds a new SelectiveLayerNormMLP module that implements selective activation checkpointing, trading ~20% additional backward compute time for dramatic memory savings (up to 6x reduction in forward peak memory). The implementation introduces a checkpoint parameter that gates three key behaviors: (1) which tensors are saved during forward (only inputs/weights vs. full activations), (2) whether FC2 is recomputed during backward (it's skipped since FC2's output isn't needed for weight gradients), and (3) CPU offloading availability (disabled when checkpointing is active). The module closely mirrors LayerNormMLP's structure, reusing the same FP8/tensor-parallel/sequence-parallel infrastructure while adding conditional logic through save_for_checkpoint and is_recomputation flags. This new module integrates with the existing PyTorch module hierarchy under transformer_engine/pytorch/module/, alongside sibling modules like LayerNormMLP and LayerNormLinear.
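The flag-gated save/recompute split described above can be illustrated with a stripped-down autograd function. This is a sketch of the checkpoint=True path only, under simplifying assumptions (plain torch ops, 2-D input, no FP8, no bias, no parallelism, hypothetical names); it is not the PR's implementation:

```python
import torch
import torch.nn.functional as F

class _SelectiveLNMLPSketch(torch.autograd.Function):
    """checkpoint=True path only: save inputs/weights, recompute LN/FC1/act in backward."""

    @staticmethod
    def forward(ctx, x, ln_w, ln_b, w1, w2):
        ln_out = F.layer_norm(x, x.shape[-1:], ln_w, ln_b)
        act_out = F.gelu(ln_out @ w1.t())          # FC1 + activation
        out = act_out @ w2.t()                     # FC2
        # Only inputs and weights are saved; ln_out / fc1_out / act_out are discarded.
        ctx.save_for_backward(x, ln_w, ln_b, w1, w2)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, ln_w, ln_b, w1, w2 = ctx.saved_tensors
        # Recompute up to the activation output; FC2 is never re-run because no
        # gradient below depends on its output.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            ln_w_ = ln_w.detach().requires_grad_(True)
            ln_b_ = ln_b.detach().requires_grad_(True)
            w1_ = w1.detach().requires_grad_(True)
            ln_out = F.layer_norm(x_, x_.shape[-1:], ln_w_, ln_b_)
            act_out = F.gelu(ln_out @ w1_.t())
        grad_w2 = grad_out.t() @ act_out           # FC2 wgrad needs act_out, not fc2_out
        grad_act = grad_out @ w2                   # FC2 dgrad needs only the weight
        grad_x, grad_ln_w, grad_ln_b, grad_w1 = torch.autograd.grad(
            act_out, (x_, ln_w_, ln_b_, w1_), grad_outputs=grad_act
        )
        return grad_x, grad_ln_w, grad_ln_b, grad_w1, grad_w2

# Tiny smoke test with 2-D input [tokens, hidden]
x = torch.randn(16, 64, requires_grad=True)
ln_w = torch.ones(64, requires_grad=True)
ln_b = torch.zeros(64, requires_grad=True)
w1 = torch.randn(256, 64, requires_grad=True)      # ffn_hidden = 256
w2 = torch.randn(64, 256, requires_grad=True)
_SelectiveLNMLPSketch.apply(x, ln_w, ln_b, w1, w2).sum().backward()
```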
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 3/5 | Adds new SelectiveLayerNormMLP module with selective activation checkpointing that skips FC2 recomputation during backward, reducing memory at the cost of additional compute |
Confidence score: 3/5
- This PR requires careful review due to incomplete testing and a known CUDA graphs compatibility issue that causes numerical errors (4e-4 typical, 0.1 absolute in some cases)
- Score reflects the unresolved CUDA graphs test failures and the significant complexity of the selective checkpointing logic that conditionally saves/recomputes tensors across forward/backward passes—while the core implementation appears sound, the interaction between checkpoint flags, FP8 metadata tracking, and CUDA graph capture needs verification
- Pay close attention to the CUDA graphs test failures in test_cuda_graphs.py (lines documenting 4e-4 typical errors with occasional 0.1 absolute errors), the conditional tensor saving logic in lines 621-691 (which determines what gets saved vs. recomputed based on the checkpoint/is_grad_enabled/save_for_checkpoint flags), and the _recompute method's dual-path behavior (lines 920-925) that either reruns forward with recompute_for_bwd=True or loads saved tensors
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has addressed the review feedback by correcting module references, removing dead code, fixing initialization logic for RMSNorm/LayerNorm gamma parameters, refactoring the comparison script to test SelectiveLayerNormMLP against itself (with/without checkpointing), and adding missing imports. The core implementation of selective activation checkpointing in SelectiveLayerNormMLP._forward remains unchanged. Test infrastructure has been updated to parametrize the new checkpoint flag across all test suites, though CUDA graph tests remain skipped when checkpoint=True.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/selective_layernorm_mlp.py | 5/5 | Core implementation of selective activation checkpointing for LayerNormMLP; no changes since last review |
| tests/pytorch/selective_layernorm_mlp/compare.py | 5/5 | Refactored to compare SelectiveLayerNormMLP(checkpoint=False) vs SelectiveLayerNormMLP(checkpoint=True) instead of comparing against LayerNormMLP |
| tests/pytorch/selective_layernorm_mlp/test_numerics.py | 5/5 | Added checkpoint parameter to test matrix; fixed gamma initialization logic for RMSNorm/LayerNorm |
| tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py | 5/5 | Added test cases for checkpoint=True and checkpoint=False; added missing warnings import |
| tests/pytorch/selective_layernorm_mlp/test_sanity.py | 5/5 | Added checkpoint parameter to sanity test matrix |
| tests/pytorch/selective_layernorm_mlp/test_recipe.py | 5/5 | Added checkpoint parameter to quantizer update test |
| tests/pytorch/selective_layernorm_mlp/test_deferred_init.py | 5/5 | Added checkpoint parameter to deferred initialization tests |
| tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py | 4/5 | Added checkpoint parameter but explicitly skips tests when checkpoint=True due to known failures |
Confidence score: 5/5
- This PR is safe to merge with minimal risk; the selective checkpointing implementation is sound and well-tested
- Score reflects that all previous review issues were addressed, tests pass (except known CUDA graph limitations), and the implementation follows TE patterns for custom autograd functions with memory management
- Pay close attention to test_cuda_graphs.py; the developer explicitly asks for help with CUDA graph failures when checkpoint=True, which remain unresolved and are currently skipped
8 files reviewed, 3 comments
Hi @jaimec00, sorry for the delay - I will look at it tomorrow and I think we should be ready to merge soon :-).

@ptrendx perfect, thank you so much!

/te-ci pytorch L0 L1
ksivaman left a comment
Looks good, thanks @jaimec00 !!
We can merge once the CI is complete
8 files reviewed, 1 comment
Hi @ksivaman @ptrendx, thanks for your help! I just saw the outputs of the CI. From the failed L0 tests, it looks like I missed something when I merged the latest CPU offloading changes from 'main' into this version, since the failures are for LayerNormMLP and Transformer; I am looking into that now. For the L1 onnx tests, it looks like it might be an issue with RMSNorm in general, since the failures are but let me know if you think this is due to my changes and I can start trying to fix it. Same for the thunder integration tests, as from the logs they use te.Linear, which I haven't touched. For the distributed tests, it looks like all the LayerNormMLP modules passed, including with checkpoint=True (I searched the logs for "layernorm_mlp" and "LayerNormMLP", and all those tests passed). So I will start working on the fix for the CPU offloading, but let me know if it is possible that I am wrong about any of the other failures being independent, and I can start working on those too. Thanks!

The CPU offloading v1 test that you mentioned and the CUDA graphs tests are the relevant failing tests. The L1 thunder, onnx, and distributed tests are failing in

@ksivaman sounds good! Looking into both right now. I missed the CUDA graphs error because I was developing on H100, and it looks like the failure happens on B200 only. You are probably right about it being an RNG state issue; it is just a little weird that it passes on the other hardware. Will try to get a fix for both of these soon!
8 files reviewed, no comments
@ksivaman the test_cpu_offloading_v1.py test now passes; it was a mistake on my part when I merged main into this branch, but it is fixed now. The CUDA graphs issue is still in progress. I made a change that could fix it, specifically resetting the FP8GlobalStateManager.IS_FIRST_FP8_MODULE attribute while saving the recomputed activations when selective activation checkpointing is active and we are recomputing. While it could be the issue you mentioned about RNG states, I am inclined to think that it has to do with the CUDA version or the compute capability, since the CUDA graphs test passed for the A100, L40, and H100 runners, but failed only on the B200 runner. An additional note is that the failed tests all used the DelayedScaling recipe. I have been developing on an H100 with CUDA 12.8, so as of right now I have not been able to reproduce the error (all tests pass on my machine). However, I will see what I can do tomorrow to get access to a B200 and try to reproduce the error then. It could also be the CUDA version, but I was not able to find out whether the different runners use different CUDA versions. I know that CUDA graphs in CUDA 12.9 show much better performance and have more features, so it is possible that I need to take that into account. It could of course be an RNG state issue, like you said, so I will check that tomorrow once I can reproduce the error. Could you clarify what CUDA versions each of the runners uses? Thank you!
8 files reviewed, no comments
Hi @ksivaman, unfortunately, I do not think I will be able to access a B200 to replicate the error. For now, I simply skip the CUDA graphs test when checkpoint=True, the recipe is DelayedScaling, and the compute capability is greater than 10.0. This should pass the CI, but let me know if this is acceptable or not. Thanks!
8 files reviewed, no comments
/te-ci pytorch L0 L1
8 files reviewed, 1 comment
Thanks again @jaimec00 😃
Description
Implement Selective Activation Checkpointing for LayerNormMLP by adding a "checkpoint" flag. If checkpoint=True, activations are recomputed in the backward pass, while skipping the recomputation of fc2, as it is not needed for the backward pass. This reduces memory significantly, allowing for larger MLPs without running into OOMs, while still keeping the functionality for SequenceParallel and TensorParallel. The only functionality that changes is CPU offloading, since there are no activations left to offload when checkpoint=True.
When checkpoint=False, the module runs as a regular LayerNormMLP, and all tests in tests/pytorch/selective_layernorm_mlp (listed in "Changes") pass. When checkpoint=True, all tests pass except for test_cuda_graphs.py.
NOTE: all tests pass now. The cuda_graphs.py issue was because the recomputation was being done outside of the autocast context. Saving the autocast state (and quantizer states) via FP8GlobalStateManager in the forward pass, setting them to that state during recomputation, and restoring them for the backward pass fixed the issue.
Fixes #623
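A minimal usage sketch of the flag described above (assuming it is exposed on the LayerNormMLP constructor exactly as the description states; the sizes and dtype below are arbitrary):

```python
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=16384,
    params_dtype=torch.bfloat16,
    checkpoint=True,  # flag from this PR: recompute LN/FC1/activation in backward, skip FC2
).cuda()

x = torch.randn(2048, 4, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = mlp(x)            # forward keeps only inputs/weights instead of intermediate activations
y.sum().backward()    # activations are recomputed here, trading backward compute for memory
```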
Type of change
Changes
Performance
With checkpoint=True, the forward pass sees no notable change in runtime but a >6X reduction in memory. The backward pass is $\approx$ 20% slower, with larger memory than with checkpoint=False, but still significantly less memory overall.
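The numbers in the summary above (and in the collapsed tables below) can be gathered with a small harness along these lines (a sketch, not the PR's compare.py; CUDA events for timing and torch.cuda.max_memory_allocated for peak memory):

```python
import torch

def bench(module, x, iters=20):
    """Return (fwd_ms, fwd_peak_MB, fwd_bwd_ms, fwd_bwd_peak_MB) averaged over iters."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Forward only: peak memory here reflects what is kept for backward.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start.record()
    for _ in range(iters):
        y = module(x)
    end.record()
    torch.cuda.synchronize()
    fwd_ms = start.elapsed_time(end) / iters
    fwd_mb = torch.cuda.max_memory_allocated() / 2**20
    del y

    # Forward + backward: checkpointing recomputes activations inside backward.
    torch.cuda.reset_peak_memory_stats()
    start.record()
    for _ in range(iters):
        module(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    fwd_bwd_ms = start.elapsed_time(end) / iters
    fwd_bwd_mb = torch.cuda.max_memory_allocated() / 2**20
    return fwd_ms, fwd_mb, fwd_bwd_ms, fwd_bwd_mb
```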
(Benchmark output collapsed in the original: for each model config, the script reports TIME (ms), MEM (MB), and MAX ABSOLUTE ERRORS.)
Checklist: