Skip to content

Conversation

@jaimec00
Copy link
Contributor

@jaimec00 jaimec00 commented Oct 28, 2025

Description

Implement Selective Activation Checkpointing for LayerNormMLP by adding a "checkpoint" flag. If checkpoint=True, activations are recomputed in the backward pass, while skipping the recomputation of fc2, as it is not needed for the backward pass. This reduces memory significantly, allowing for larger MLPs without running into OOMs, while still keeping the functionality for SequenceParallel and TensorParallel. Only functionality that is changed is cpu offloading, since there are no more activations to offload when checkpoint=True.

When checkpoint=False, runs regular LayerNormMLP, and all tests in tests/pytorch/selective_layernorm_mlp (listed in "Changes") pass. When checkpoint=True, all tests pass, except for test_cuda_graphs.py.

NOTE: all tests pass now. The cuda_graphs.py issue was because the recomputation was being done outside of the autocast context. saving the autocast state (and quantizer states) via FP8GlobalStateManager in fwd, setting them to that in recomputation, and restoring for bwd fixed the issue.

Fixes #623

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • restructure transformer_engine/pytorch/module/layernorm_mlp.py to allow for selective activation checkpointing with checkpoint=True
  • Add tests for checkpoint=True (in tests/pytorch/selective_layernorm_mlp). Note, these are basically just refactored tests for LayerNormMLP:
    • test_numerics.py
    • test_sanity.py
    • test_cuda_graphs.py
    • test_deferred_init.py
    • test_recipe.py
    • distributed/test_numerics.py
    • compare.py, compares LayerNormMLP with and without checkpoint flag, here are a couple of example comparisons

Performance

with checkpoint=True, forward pass sees no notable change in runtime, but >6X reduction in memory. Backward pass is $\approx$ 20% slower, with larger memory than with checkpoint=False, but still significantly less memory overall.

#########################################

Model Config

seq hidden ffn_hidden layers
65536 2048 8192 12

TIME (ms)

Metric No Checkpointing Checkpointing
fwd time 166.687 167.204
bwd time 452.350 550.005

MEM (MB)

Metric No Checkpointing Checkpointing
fwd mem 64,430.801 10,737.943
bwd mem 2,751.767 7,584.130

MAX ABSOLUTE ERRORS

Tensor Max Abs Error
output 0.000e+00
layer_norm_weight.grad 0.000e+00
layer_norm_bias.grad 0.000e+00
fc1_weight.grad 0.000e+00
fc1_bias.grad 0.000e+00
fc2_weight.grad 0.000e+00
fc2_bias.grad 0.000e+00

#########################################

Model Config

seq hidden ffn_hidden layers
16384 1024 4096 12

TIME (ms)

Metric No Checkpointing Checkpointing
fwd time 11.924 12.428
bwd time 31.612 39.968

MEM (MB)

Metric No Checkpointing Checkpointing
fwd mem 8,054.637 1,342.308
bwd mem 352.343 956.453

MAX ABSOLUTE ERRORS

Tensor Max Abs Error
output 0.000e+00
layer_norm_weight.grad 0.000e+00
layer_norm_bias.grad 0.000e+00
fc1_weight.grad 0.000e+00
fc1_bias.grad 0.000e+00
fc2_weight.grad 0.000e+00
fc2_bias.grad 0.000e+00

#########################################

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers the entire PR, which introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The core implementation trades compute for memory by avoiding saving intermediate activations (ln_out, fc1_out, gelu_out) during forward passes and instead recomputing them during backward passes—crucially skipping fc2 recomputation since it's not needed for gradients. The module integrates with TransformerEngine's existing infrastructure (sequence_parallel, tensor_parallel, FP8 quantization) while explicitly disabling CPU offloading. The benchmark data shows dramatic memory reduction (64GB → 10GB forward memory in the example) at the cost of 20% slower backward passes. The change maintains full compatibility with existing LayerNormMLP configurations and is implemented as an additional module rather than modifying existing code.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 3/5 Implements the new SelectiveLayerNormMLP module with selective checkpointing; contains commented-out FSDP/offloading code that needs verification
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 2/5 Adds deferred init tests but has a critical bug: seq_length kwarg is not set for SelectiveLayerNormMLP (line 34 condition only checks LayerNormMLP)
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 4/5 Distributed test file that creates LayerNormMLP instances instead of SelectiveLayerNormMLP (lines 403-404), so doesn't actually test the new feature
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Adds CUDA graph tests with proper module instantiation and parametrization; PR description notes tests fail with 4e-4 errors
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Minor formatting changes only: function signature reformatted and blank lines removed for style consistency
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Comprehensive sanity tests for the new module with minor formatting adjustments to match project style
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Formatting cleanup: removed duplicate imports and adjusted parameter list formatting
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Benchmark script comparing LayerNormMLP vs SelectiveLayerNormMLP; formatting changes only (PEP8 compliance)

Confidence score: 2/5

  • This PR requires careful review due to critical bugs in test files and incomplete CUDA graph compatibility
  • Score reflects three major issues: (1) test_deferred_init.py will fail due to missing seq_length parameter handling, (2) distributed/run_numerics.py doesn't actually test the new module, (3) CUDA graph tests are failing per PR description, and (4) extensive commented-out code in the main implementation suggests incomplete feature development particularly around FSDP and CPU offloading
  • Pay close attention to transformer_engine/pytorch/module/selective_layernorm_mlp.py (commented sections lines 550-571, 915-929, 1545-1552), tests/pytorch/selective_layernorm_mlp/test_deferred_init.py (line 34 conditional logic), and tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (lines 403-404 model instantiation)

8 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the formatting and style changes made since the last review, not the entire PR. The changes consist entirely of whitespace adjustments: removing extraneous blank lines from seven test files (test_sanity.py, test_cuda_graphs.py, compare.py, test_recipe.py, test_numerics.py, and distributed/run_numerics.py) and the end-of-file blank line. These modifications align the new test suite with the project's formatting standards enforced by pre-commit hooks (black formatter, end-of-file-fixer, trailing-whitespace checker) and the .clang-format configuration that specifies 100-character column limits, 2-space indentation for C++, and strict whitespace rules. Since SelectiveLayerNormMLP is a memory-optimized variant of LayerNormMLP with identical API and forward-pass behavior, the test files are appropriately adapted copies of the existing LayerNormMLP test infrastructure, ensuring comprehensive coverage of FP8 recipes, distributed configurations, deferred initialization, and numerical accuracy without duplicating test logic.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Removed two extraneous blank lines (lines 131,161) for formatting consistency
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Removed trailing blank line at end of file
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Removed leading blank line and reformatted long print statement to fit 100-char limit
tests/pytorch/selective_layernorm_mlp/test_recipe.py 4/5 Contains duplicate import of transformer_engine_torch (lines 22 and 29)
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 No visible formatting changes in this diff
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Removed three blank lines (lines 202, 273, 531) and reformatted function signature (lines 406-408)
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Removed three blank lines (lines 33, 383, 491) for consistent spacing
transformer_engine/pytorch/module/selective_layernorm_mlp.py 5/5 No whitespace-related changes; core implementation is unchanged

Confidence score: 4/5

  • This PR is generally safe to merge with respect to the formatting changes; however, the test_cuda_graphs.py failure and the duplicate import in test_recipe.py require attention before merging.
  • Score reflects that the formatting changes are correct and align with project standards, but the duplicate import (test_recipe.py lines 22 and 29) should be removed, and the test_cuda_graphs.py failures mentioned in the PR description remain unresolved and may indicate numerical instability or a CUDA graph compatibility issue that requires investigation.
  • Pay close attention to tests/pytorch/selective_layernorm_mlp/test_recipe.py (duplicate import) and confirm that test_cuda_graphs.py failures are understood and either resolved or deemed acceptable before merging.

8 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the most recent docstring correction in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (line 391), updating the reference from "LayerNormMLP" to "SelectiveLayerNormMLP". This is a trivial documentation fix that completes the refactoring of existing tests to validate the new SelectiveLayerNormMLP module introduced by this PR. All substantive issues (duplicate import, missing condition logic, extra blank lines) were already flagged in earlier reviews and are not repeated here.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Corrected docstring to reference SelectiveLayerNormMLP instead of LayerNormMLP (line 391).

Confidence score: 5/5

  • This single‑line docstring fix is safe to merge with no risk; it improves documentation accuracy.
  • Score remains perfect because the change is purely cosmetic—no logic, imports, or functionality is affected.
  • No files require additional attention; earlier reviews have already identified all substantive issues in this test file.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. A duplicate import of transformer_engine_torch as tex has been removed from the test file tests/pytorch/selective_layernorm_mlp/test_recipe.py, cleaning up the code. However, the file still contains the same import on line 22, indicating either an incomplete change or a subsequent revert not captured in the diff.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Removed duplicate import statement, no functional impact on test logic

Confidence score: 5/5

  • This change is safe to merge with minimal risk
  • The duplicate import removal is a simple cleanup with no functional impact; the test logic remains identical and validation of FP8 quantizer behavior is unchanged
  • No files require special attention

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent updates focus on cleaning up test files for the new SelectiveLayerNormMLP module by removing unused imports and fixing minor formatting issues. The changes address previous review feedback by eliminating duplicate imports (like transformer_engine_torch as tex) and removing extraneous blank lines. However, several issues remain unresolved from earlier reviews, including the incomplete fix in test_deferred_init.py where the conditional logic in get_module_args still references unused module types (te.Linear, te.LayerNormLinear, te.MultiheadAttention, te.TransformerLayer), creating unreachable dead code. The CUDA graph test failures mentioned in the PR description (errors around 4e-4 and 0.1 absolute errors) persist and are acknowledged as incomplete work.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Removed numerous unused imports while preserving FP8 recipe test functionality
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Removed blank line but left dead code in get_module_args referencing unused module types
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 3/5 Simplified to test only SelectiveLayerNormMLP; tests currently failing per PR description
tests/pytorch/selective_layernorm_mlp/test_sanity.py 4/5 Cleaned up imports for focused SelectiveLayerNormMLP testing with comprehensive parameterization
tests/pytorch/selective_layernorm_mlp/test_numerics.py 4/5 Removed unused imports and functions; minor formatting inconsistency remains
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Removed unused imports with no functional changes

Confidence score: 3/5

  • This PR requires careful review due to incomplete CUDA graph functionality and remaining code quality issues.
  • Score reflects acknowledged test failures (test_cuda_graphs.py), dead code in helper methods (test_deferred_init.py line 25-43), and minor formatting inconsistencies. The changes are primarily test infrastructure cleanup, but the core feature is not yet production-ready.
  • Pay close attention to test_cuda_graphs.py (numerical precision failures) and test_deferred_init.py (unreachable conditional branches for te.Linear, te.MultiheadAttention, te.TransformerLayer).

6 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The forward pass saves only input tensors and weights rather than intermediate activations (fc1_out, gelu_out), reducing memory consumption by ~80% during training. During backward, FC1 and activation are recomputed, but FC2 recomputation is skipped since it's unnecessary for gradient calculation. The implementation maintains compatibility with TransformerEngine's quantization systems (FP8, MXFP8, NVFP4) and distributed features (tensor parallel, sequence parallel), though CPU offloading is explicitly disabled. The module integrates into the existing PyTorch API alongside LayerNormMLP and includes comprehensive test coverage mirroring the standard module's test suite. CUDA graph tests currently fail with ~4e-4 numerical differences, indicating potential interactions between graph capture and the selective recomputation strategy.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 4/5 New 1000+ line module implementing selective checkpointing with bifurcated forward logic for training vs. recomputation
transformer_engine/pytorch/module/init.py 5/5 Adds SelectiveLayerNormMLP import to module's public API
transformer_engine/pytorch/init.py 5/5 Exposes SelectiveLayerNormMLP in top-level PyTorch API
tests/pytorch/selective_layernorm_mlp/test_numerics.py 1/5 Critical bug: reference implementations initialize normalization weights to zeros instead of ones when zero_centered_gamma=False
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 4/5 Missing import warnings causes NameError when zero-tensors are detected; otherwise sound distributed validation
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 4/5 CUDA graph tests currently failing per PR description; unused import and incorrect return type annotation present
tests/pytorch/selective_layernorm_mlp/compare.py 3/5 Performance comparison script with reversed weight-copy direction (copies from SLN to LN instead of vice versa)
tests/pytorch/selective_layernorm_mlp/utils.py 4/5 Test utilities with incorrect return type annotation (declares 2-tuple but returns 3 elements)
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Comprehensive sanity tests covering dtypes, recipes, activations, normalizations, and microbatching
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Deferred initialization test with dead code for untested modules
tests/pytorch/selective_layernorm_mlp/test_recipe.py 4/5 FP8 recipe validation test with unused capability-check imports
tests/pytorch/selective_layernorm_mlp/distributed/test_numerics.py 4/5 Distributed test wrapper with typo in docstring and unused variable

Confidence score: 2/5

  • This PR cannot be merged safely due to critical bugs in the test reference implementations and missing imports that will cause runtime failures.
  • Score reflects: (1) test_numerics.py initializes reference normalization weights to zeros instead of ones, making all tests invalid; (2) distributed/run_numerics.py uses warnings.warn() without importing warnings; (3) CUDA graph tests are explicitly failing per PR description; (4) compare.py has reversed weight-copy direction; (5) multiple type annotation mismatches that will cause type-checking failures.
  • Pay close attention to tests/pytorch/selective_layernorm_mlp/test_numerics.py (lines 144-145, 170-171), distributed/run_numerics.py (line 34), and the core module's commented-out FSDP code which may indicate incomplete distributed functionality.

Sequence Diagram

sequenceDiagram
    participant User
    participant SelectiveLayerNormMLP
    participant _SelectiveLayerNormMLP
    participant ForwardPass
    participant BackwardPass
    participant Quantizers
    participant GEMM

    User->>SelectiveLayerNormMLP: forward(inp, is_first_microbatch)
    SelectiveLayerNormMLP->>SelectiveLayerNormMLP: prepare_forward()
    SelectiveLayerNormMLP->>SelectiveLayerNormMLP: _get_quantizers()
    SelectiveLayerNormMLP->>Quantizers: Initialize quantizers
    Quantizers-->>SelectiveLayerNormMLP: Return quantizers
    
    SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=False)
    Note over _SelectiveLayerNormMLP: Save tensors for backward (inp, weights, etc.)
    
    _SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
    ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
    
    alt sequence_parallel
        _SelectiveLayerNormMLP->>ForwardPass: gather_along_first_dim()
        ForwardPass-->>_SelectiveLayerNormMLP: ln_out_total
    end
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
    Note over GEMM: FC1 GEMM with optional gelu fusion
    GEMM-->>_SelectiveLayerNormMLP: fc1_out
    
    _SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
    ForwardPass-->>_SelectiveLayerNormMLP: act_out
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc2_weight, act_out)
    Note over GEMM: FC2 GEMM
    GEMM-->>_SelectiveLayerNormMLP: fc2_out
    
    alt sequence_parallel
        _SelectiveLayerNormMLP->>ForwardPass: reduce_scatter_along_first_dim()
        ForwardPass-->>_SelectiveLayerNormMLP: fc2_out
    end
    
    _SelectiveLayerNormMLP-->>SelectiveLayerNormMLP: fc2_out
    SelectiveLayerNormMLP-->>User: Return output
    
    User->>BackwardPass: loss.backward()
    BackwardPass->>_SelectiveLayerNormMLP: _recompute(ctx)
    Note over _SelectiveLayerNormMLP: Restore saved tensors
    
    _SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=True)
    Note over _SelectiveLayerNormMLP: Recompute activations WITHOUT saving fc2
    
    _SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
    ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
    GEMM-->>_SelectiveLayerNormMLP: fc1_out
    
    _SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
    ForwardPass-->>_SelectiveLayerNormMLP: act_out
    
    Note over _SelectiveLayerNormMLP: Skip FC2 forward (not needed for backward)
    
    _SelectiveLayerNormMLP->>BackwardPass: backward(grad_outputs)
    
    BackwardPass->>GEMM: general_gemm(fc2_weight, grad_output)
    Note over GEMM: FC2 DGRAD
    GEMM-->>BackwardPass: fc2_dgrad
    
    BackwardPass->>ForwardPass: activation_func_bwd(fc2_dgrad, fc1_out)
    ForwardPass-->>BackwardPass: dact
    
    alt fc2_weight_requires_grad
        BackwardPass->>GEMM: general_gemm(act_out, grad_output)
        Note over GEMM: FC2 WGRAD
        GEMM-->>BackwardPass: fc2_wgrad, fc2_bias_grad
    end
    
    BackwardPass->>GEMM: general_gemm(fc1_weight, dact)
    Note over GEMM: FC1 DGRAD
    GEMM-->>BackwardPass: fc1_dgrad
    
    alt fc1_weight_requires_grad
        BackwardPass->>GEMM: general_gemm(ln_out_total, dact)
        Note over GEMM: FC1 WGRAD
        GEMM-->>BackwardPass: fc1_wgrad, fc1_bias_grad
    end
    
    BackwardPass->>ForwardPass: layernorm_bwd/rmsnorm_bwd()
    ForwardPass-->>BackwardPass: dgrad, dgamma, dbeta
    
    BackwardPass-->>User: Return gradients
Loading

12 files reviewed, 13 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer removed redundant nested if-condition logic in the deferred initialization test file for SelectiveLayerNormMLP. Previously, the code checked if module == te.SelectiveLayerNormMLP twice: once at line 29 and again at line 33 within the same block. This review captures the cleanup of that duplication, streamlining the conditional logic. The test file verifies that SelectiveLayerNormMLP can be instantiated on a 'meta' device with zero memory allocation and then properly materialized via reset_parameters().

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 Removed redundant nested if-check for SelectiveLayerNormMLP module condition

Confidence score: 5/5

  • This PR is safe to merge with minimal risk
  • Score reflects straightforward refactoring with no functional changes or edge cases introduced
  • No files require special attention

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest commit fixes a critical bug in the test reference implementations (TorchLayerNorm and TorchRMSNorm) within tests/pytorch/selective_layernorm_mlp/test_numerics.py. Previously, the weight initialization logic was inverted: when zero_centered_gamma=False, the gamma parameter was incorrectly initialized to zeros instead of ones, which would cause the normalized output to be scaled incorrectly. The fix swaps the conditional branches so that zero_centered_gamma=False correctly initializes weights to ones (standard LayerNorm behavior), and zero_centered_gamma=True initializes to zeros (zero-centered variant). This correction ensures that the torch-based reference models produce the expected numerical outputs, allowing accurate validation of the SelectiveLayerNormMLP module against known-correct baselines. This change is isolated to the test infrastructure and does not affect the production implementation of SelectiveLayerNormMLP.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Fixed inverted weight initialization logic in TorchLayerNorm and TorchRMSNorm reference implementations

Confidence score: 5/5

  • This change is safe to merge—it fixes a clear bug in test reference code with no risk to production paths
  • The fix corrects an inverted conditional that was causing reference models to initialize gamma weights incorrectly, ensuring accurate numerical validation going forward
  • No files require special attention—the change is a straightforward two-line correction in test utilities

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jaimec00 jaimec00 changed the title Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has addressed previous feedback by fixing the parameter-setting logic in test_deferred_init.py, correcting the weight-copy direction in compare.py, and refining timing/measurement code. However, several critical issues remain: a type annotation bug in test_cuda_graphs.py (get_nvfp4_inp_supported_dtypes returns List[torch.dtype] but is annotated as bool), unused imports (warnings in run_numerics.py, time in compare.py), and a typo in the PR description (ffn_fidden on line 261 of compare.py). The test files validate that SelectiveLayerNormMLP maintains numerical parity with LayerNormMLP across distributed and non-distributed configurations while significantly reducing memory usage by recomputing activations instead of caching them. The PR description notes that test_cuda_graphs.py is failing with numerical errors (typically 4e-4, occasionally 0.1 absolute), which correlates with the type annotation bug in that file.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 2/5 Critical type annotation bug on line 71: function returns List[torch.dtype] but annotated as bool, causing type-checking failures and possibly contributing to test failures
tests/pytorch/selective_layernorm_mlp/test_numerics.py 4/5 Adds comprehensive numerical validation comparing SelectiveLayerNormMLP against torch reference implementations across dtypes, activations, and FP8/NVFP4 recipes
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 3/5 Validates distributed numerics for tensor/sequence parallelism; unused warnings import added but comment updated correctly
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Simplified parameter-setting logic by removing redundant conditional check; correct for single-module test file
tests/pytorch/selective_layernorm_mlp/test_sanity.py 4/5 Standard sanity tests adapted for SelectiveLayerNormMLP; comprehensive coverage of configurations and FP8 recipes
tests/pytorch/selective_layernorm_mlp/compare.py 3/5 Benchmark script refactored to use CUDA events and per-config stats; unused time import remains and typo ffn_fidden on line 261

Confidence score: 3/5

  • This PR requires careful attention due to a critical type annotation bug and known test failures (cuda_graphs), though the core numerical validation appears sound
  • Score reflects: (1) critical type annotation error in test_cuda_graphs.py that breaks type checking and may contribute to reported test failures; (2) unused imports in two files suggesting incomplete cleanup; (3) PR description acknowledges cuda_graphs tests are failing with numerical errors, indicating incomplete functionality
  • Pay close attention to test_cuda_graphs.py (line 71 type annotation) and investigate why CUDA graph tests fail with 4e-4 errors—the annotation bug may be masking logic errors in the supported-dtypes check

6 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent changes apply code formatting to tests/pytorch/selective_layernorm_mlp/compare.py, reformatting a multi-line f-string descriptor to comply with line-length and style guidelines. While spacing, operator placement, and list formatting improvements are beneficial, a critical syntax error was introduced: the descriptor string is now missing its closing parenthesis after the f-string definition, preventing the script from running. No other changes were made to imports, logic, or functionality in this update.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/compare.py 0/5 Applied code formatting but introduced syntax error: missing closing parenthesis after multi-line f-string definition (line 261)

Confidence score: 0/5

  • This PR will fail to run due to a syntax error that prevents the comparison script from executing
  • Score is zero because a closing parenthesis is missing after the f-string definition on line 261, causing a guaranteed SyntaxError when the script is imported or executed
  • The file tests/pytorch/selective_layernorm_mlp/compare.py requires immediate attention to add the missing closing parenthesis after line 262

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This update addresses a single-character syntax fix in the benchmark comparison script for SelectiveLayerNormMLP. The change adds a missing closing parenthesis on line 262 of tests/pytorch/selective_layernorm_mlp/compare.py that completes the multi-line f-string definition for the desc variable. This variable formats configuration metadata (sequence length, hidden size, FFN hidden size, and number of layers) that labels benchmark results comparing LayerNormMLP and SelectiveLayerNormMLP. Without this closing parenthesis, the script would raise a SyntaxError and fail to execute. The fix ensures the benchmarking tool can run properly and display properly formatted configuration descriptions in its output.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Added missing closing parenthesis to complete multi-line f-string definition

Confidence score: 5/5

  • This PR is safe to merge with minimal risk
  • The change fixes a syntax error that would have prevented the script from running, and introduces no new logic or behavioral changes
  • No files require special attention; this is a trivial syntax correction

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jaimec00 jaimec00 changed the title [PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [Draft][RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
@jaimec00 jaimec00 changed the title [Draft][RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
@ptrendx
Copy link
Member

ptrendx commented Oct 28, 2025

Hi @jaimec00, thank you for the contribution!
The first general question (before really diving into the actual review) is why was it necessary to make a completely new module for that functionality? We should be able to introduce this as an option to the regular LayerNormMLP. Or is the idea to first show the approach in its own module and then integrate it back into LayerNormMLP?

@jaimec00
Copy link
Contributor Author

jaimec00 commented Oct 28, 2025

Thanks for the comment, @ptrendx!
I started as a separate module to de-risk changes to LayerNormMLP and to enable clean A/B comparisons with LayerNormMLP while iterating on selective checkpointing (and CUDA Graphs/TP overlap specifics). Functionally it’s a refactor of LayerNormMLP, but as of right now, it only implements selective activation checkpointing (no option for standard LayerNormMLP).
The plan is to merge this back into LayerNormMLP behind a flag (e.g. checkpoint_policy={'none', 'selective'}), add tests for both policies, and then retire the standalone module.
If you’d rather skip the interim step, I can integrate it into LayerNormMLP in this PR and keep it gated. What do you think? If you have a different approach that you would prefer, I would be happy to follow that.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent update adds a new SelectiveLayerNormMLP module that implements selective activation checkpointing, trading ~20% additional backward compute time for dramatic memory savings (up to 6x reduction in forward peak memory). The implementation introduces a checkpoint parameter that gates three key behaviors: (1) which tensors are saved during forward (only inputs/weights vs. full activations), (2) whether FC2 is recomputed during backward (it's skipped since FC2's output isn't needed for weight gradients), and (3) CPU offloading availability (disabled when checkpointing is active). The module closely mirrors LayerNormMLP's structure, reusing the same FP8/tensor-parallel/sequence-parallel infrastructure while adding conditional logic through save_for_checkpoint and is_recomputation flags. This new module integrates with the existing PyTorch module hierarchy under transformer_engine/pytorch/module/, alongside sibling modules like LayerNormMLP and LayerNormLinear.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 3/5 Adds new SelectiveLayerNormMLP module with selective activation checkpointing that skips FC2 recomputation during backward, reducing memory at the cost of additional compute

Confidence score: 3/5

  • This PR requires careful review due to incomplete testing and a known CUDA graphs compatibility issue that causes numerical errors (4e-4 typical, 0.1 absolute in some cases)
  • Score reflects the unresolved CUDA graphs test failures and the significant complexity of the selective checkpointing logic that conditionally saves/recomputes tensors across forward/backward passes—while the core implementation appears sound, the interaction between checkpoint flags, FP8 metadata tracking, and CUDA graph capture needs verification
  • Pay close attention to the CUDA graphs test failures in test_cuda_graphs.py (lines documenting 4e-4 typical errors with occasional 0.1 absolute errors), the conditional tensor saving logic in lines 621-691 (which determines what gets saved vs. recomputed based on checkpoint/is_grad_enabled/save_for_checkpoint flags), and the _recompute method's dual-path behavior (lines 920-925) that either reruns forward with recompute_for_bwd=True or loads saved tensors

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has addressed the review feedback by correcting module references, removing dead code, fixing initialization logic for RMSNorm/LayerNorm gamma parameters, refactoring the comparison script to test SelectiveLayerNormMLP against itself (with/without checkpointing), and adding missing imports. The core implementation of selective activation checkpointing in SelectiveLayerNormMLP._forward remains unchanged. Test infrastructure has been updated to parametrize the new checkpoint flag across all test suites, though CUDA graph tests remain skipped when checkpoint=True.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 5/5 Core implementation of selective activation checkpointing for LayerNormMLP; no changes since last review
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Refactored to compare SelectiveLayerNormMLP(checkpoint=False) vs SelectiveLayerNormMLP(checkpoint=True) instead of comparing against LayerNormMLP
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Added checkpoint parameter to test matrix; fixed gamma initialization logic for RMSNorm/LayerNorm
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Added test cases for checkpoint=True and checkpoint=False; added missing warnings import
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Added checkpoint parameter to sanity test matrix
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Added checkpoint parameter to quantizer update test
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 Added checkpoint parameter to deferred initialization tests
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 4/5 Added checkpoint parameter but explicitly skips tests when checkpoint=True due to known failures

Confidence score: 5/5

  • This PR is safe to merge with minimal risk; the selective checkpointing implementation is sound and well-tested
  • Score reflects that all previous review issues were addressed, tests pass (except known CUDA graph limitations), and the implementation follows TE patterns for custom autograd functions with memory management
  • Pay close attention to test_cuda_graphs.py—the developer explicitly asks for help with CUDA graph failures when checkpoint=True, which remain unresolved and are currently skipped

8 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

jaimec00 and others added 11 commits October 29, 2025 01:22
Signed-off-by: Jaime Cardenas <[email protected]>
Signed-off-by: Jaime Cardenas <[email protected]>
…ributed/run_numerics.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <[email protected]>
Signed-off-by: Jaime Cardenas <[email protected]>
@ptrendx
Copy link
Member

ptrendx commented Nov 17, 2025

Hi @jaimec00, sorry for the delay - I will look at it tomorrow and I think we should be ready to merge soon :-).

@jaimec00
Copy link
Contributor Author

@ptrendx perfect, thank you so much!

@ksivaman
Copy link
Member

/te-ci pytorch L0 L1

ksivaman
ksivaman previously approved these changes Nov 17, 2025
Copy link
Member

@ksivaman ksivaman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks @jaimec00 !!
We can merge once the CI is complete

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@jaimec00
Copy link
Contributor Author

Hi @ksivaman @ptrendx , thanks for your help! I just saw the outputs of the CI. From the failed L0 tests, it look like I missed something when I merged the latest CPU offloading changes from 'main' into this version, since the failures are for LayerNormMLP and Transformer, I am looking into that now.

For the L1 onnx tests, it looks like it might be an issue with RMSNorm in general, since the failures are

FAILED ../../tests/pytorch/test_onnx_export.py::test_export_layernorm_normalization[RMSNorm]
FAILED ../../tests/pytorch/test_onnx_export.py::test_export_layernorm_linear_normalization[RMSNorm]
FAILED ../../tests/pytorch/test_onnx_export.py::test_export_layernorm_mlp_normalization[RMSNorm]

but let me know if you think this is due to my changes and I can start trying to fix it. Same for the thunder integration tests, as from the logs it uses te.Linear which I haven't touched.

for the distributed tests, it looks like all the LayerNormMLP modules passed, including with checkpoint=True (I searched the logs for "layernorm_mlp" and "LayerNormMLP", all those tests passed)

So I will start working on the fix for the CPU offloading, but let me know if it is possible that I am wrong about any of the other failures being independant, and I can start working on those too. Thanks!

@ksivaman
Copy link
Member

The cpu offloading v1 test that you mentioned and the cuda graphs tests are the relevant failing tests. The L1 thunder, onnx, and distributed tests are failing in main as well so don't worry about those!

@jaimec00
Copy link
Contributor Author

@ksivaman sounds good! Looking into both right now, I missed the CUDA graphs error because I was developing on H100, and it looks like the failure happens on B200 only. You are probably right about it being an RNG state issue, just a little weird that it passes for the other hardware. Will try to get a fix for both of these soon!

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, no comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@jaimec00
Copy link
Contributor Author

jaimec00 commented Nov 18, 2025

@ksivaman the test_cpu_offloading_v1.py now passes, it was a mistake on my part when I merged main into this branch, but it is fixed now.

The CUDA graphs issue is still in progress. I made a change that could fix it, particularly resetting the FP8GlobalStateManager.IS_FIRST_FP8_MODULE attribute while saving the recomputed activations if it is selective activation checkpointing and we are recomputing. While it could be the issue you mentioned about RNG states, I am inclined to think that it has to do with the CUDA version or the compute capability, since the CUDA graphs test passed for the A100, L40, and H100 runners, but only failed for the tests on the B200 runner. An additional note is that the failed tests all used DelayedScaling recipe.

I have been developing on an H100 with CUDA 12.8, so as of right now I was not able to reproduce the error (all tests pass on my machine). However, I will see what I can do tomorrow to get access to a B200 to see if I can reproduce the error then. It could also be the CUDA version, but I was not able to find if the different runners used different CUDA versions. I know that the CUDA graphs in CUDA 12.9 show much better performance and have more features, so it is possible that I need to take that into account.

It could of course be an RNG state issue, like you said, so I will check that tomorrow once I can reproduce the error. Could you clarify what CUDA versions each of the runners uses? Thank you!

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, no comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@jaimec00
Copy link
Contributor Author

Hi @ksivaman, unfortunately, I do not think I will be able to access a B200 to replicate the error. For now, I simply skip the CUDA graphs test when checkpoint=True, the recipe is DelayedScaling, and the compute capability is greater than 10.0. This should pass the CI, but let me know if this is acceptable or not. Thanks!

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, no comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@ksivaman
Copy link
Member

/te-ci pytorch L0 L1

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@ksivaman ksivaman merged commit 05bfa3f into NVIDIA:main Nov 18, 2025
16 of 21 checks passed
@ksivaman
Copy link
Member

Thanks again @jaimec00 😃

@jaimec00
Copy link
Contributor Author

Awesome, thanks for your help @ksivaman and @ptrendx !!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Selective Activation Checkpointing with LayerNormMLP

3 participants