
Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title FSDP2 Weight Update Fix [Pytorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [Pytorch] FSDP2 Weight Update Fix [PyTorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [PyTorch] FSDP2 Weight Update Fix [PyTorch] TE FSDP2 Support for FP8/MXFP8 Oct 17, 2025
Comment on lines 788 to 791
# Detect if we're within fp8_autocast scope. We'll be in
# forward pass when within the fp8_autocast scope. Backward
# pass when outside of the fp8_autocast scope.
is_in_fp8_autocast = FP8GlobalStateManager.is_fp8_enabled()
Collaborator

Assuming that the forward pass is within an autocast and that the backward pass is outside is quite tenuous. I don't see a logical reason why this has to be true. For example, te.ops.Linear supports the case with FP8 params and non-FP8 compute.

Collaborator Author

@vthumbe1503 vthumbe1503 Oct 25, 2025

Fixed it now. Using the FSDP module's training state to decide whether we are in the forward pass.
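For reference, a minimal sketch of that decision logic, assuming the hook can see the wrapping FSDP2 module's training state; the state value and how it is obtained are assumptions about private FSDP2 internals, not the actual TE code:

```python
# Illustrative only: map the FSDP2 module's training state to which copy of
# the quantized weight should be all-gathered. "FORWARD" here stands in for
# the real (private) FSDP2 training-state enum value.
def select_weight_usage(fsdp_training_state: str) -> dict:
    in_forward = fsdp_training_state == "FORWARD"
    return {
        # the forward GEMM consumes the rowwise (non-transposed) FP8 data
        "rowwise_usage": in_forward,
        # dgrad/wgrad in the backward pass consume the columnwise (transposed) data
        "columnwise_usage": not in_forward,
    }
```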

Comment on lines 113 to 115
# fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
fp8_recipe = MXFP8BlockScaling(fp8_format=fp8_format)
Collaborator

The quantization recipe should be configurable so that we can test all of them.
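One way to make it configurable in the test harness, as a sketch; the `--recipe` choice strings are illustrative, while the recipe classes and constructor arguments are those already shown in the commented-out lines above:

```python
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
)

def build_recipe(name: str, fp8_format):
    """Map a --recipe flag value to a quantization recipe (sketch)."""
    if name == "delayed_scaling":
        return DelayedScaling(
            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
        )
    if name == "current_scaling":
        return Float8CurrentScaling(fp8_format=fp8_format)
    if name == "mx_fp8_block_scaling":
        return MXFP8BlockScaling(fp8_format=fp8_format)
    raise ValueError(f"Unknown recipe: {name}")
```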

Comment on lines 347 to 354
for (scale_inv, split_size) in zip(scale_invs, split_sizes_for_scale):
    scale_inv_out = scale_inv.__torch_dispatch__(
        func,
        types,
        [scale_inv, split_size] + list(args[2:]),
        kwargs,
    )
    out_data.append(scale_inv_out)
Collaborator

We should take the scale-inv padding into account. Alternatively, we could fall back to the dequantize-requantize impl if the dims are not multiples of 512.

Collaborator Author

Shouldn't the dims need to be multiples of 128 instead of 512? I am already checking that at the moment, and in that case it does fall back to the dequantize-requantize path.
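For illustration, the fallback being discussed might look roughly like this; the alignment constant and the quantizer call are assumptions, not the actual implementation:

```python
SCALE_INV_ALIGNMENT = 128  # per this thread; 512 was the originally suggested granularity

def split_with_fallback(tensor, split_size, dim=0):
    """Split an MXFP8 tensor, falling back to dequantize -> split -> requantize
    when the dims are not multiples of the scale-inv padding granularity."""
    if tensor.shape[dim] % SCALE_INV_ALIGNMENT or split_size % SCALE_INV_ALIGNMENT:
        # Slow but always-correct path.
        chunks = tensor.dequantize().split(split_size, dim=dim)
        return [tensor._quantizer(chunk) for chunk in chunks]
    # Aligned case: split the packed data and scale-inv buffers directly,
    # as in the __torch_dispatch__ snippet quoted above.
    return tensor.split(split_size, dim=dim)
```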

@vthumbe1503 vthumbe1503 changed the title [PyTorch] TE FSDP2 Support for FP8/MXFP8 [PyTorch] FSDP2 Support for TE Oct 28, 2025
@vthumbe1503 vthumbe1503 marked this pull request as ready for review October 28, 2025 18:59
@vthumbe1503 vthumbe1503 requested review from denera and ptrendx October 28, 2025 18:59
@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enables FSDP2 (Fully Sharded Data Parallel v2) support for Transformer Engine's quantized tensor types (Float8Tensor, MXFP8Tensor) and modules (Linear, LayerNormLinear, LayerNormMLP, TransformerLayer). The integration allows TE to work with PyTorch's DTensor-based sharding by implementing custom FSDP2 lifecycle hooks (fsdp_pre_all_gather, fsdp_post_all_gather) that handle gathering/sharding of quantized weights with proper transpose cache management. Key architectural changes include detecting when weights are already quantized after FSDP2 all-gather and extracting their embedded quantizers rather than reconfiguring them, handling DTensor parameters during deferred initialization from meta device, and conditionally skipping dgrad computation when inputs don't require gradients. The PR also extends test coverage to validate all three quantization recipes (delayed_scaling, current_scaling, mx_fp8_block_scaling) and adds comprehensive integration tests in run_fsdp2_model.py.
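For context, the two lifecycle hooks mentioned above have roughly the following shape; the parameter names are taken from elsewhere in this thread (mesh, orig_size, all_gather_outputs, metadata, param_dtype, out), and the exact signatures belong to a private PyTorch FSDP2 extension protocol that varies by version, so treat this outline as an assumption rather than the actual TE implementation:

```python
class QuantizedTensorFSDP2Hooks:
    """Rough outline of the FSDP2 tensor-extension hooks (not the real code)."""

    def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, *args, **kwargs):
        # Return the raw shard buffers FSDP2 should all-gather (FP8 data, plus
        # scale-inv buffers for MXFP8) and the metadata needed to rebuild the
        # quantized tensor afterwards (quantizer copy, dtype, scale).
        raise NotImplementedError

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        # Rebuild a full Float8Tensor/MXFP8Tensor from the gathered buffers, or
        # refresh `out` in place when FSDP2 hands back a reusable tensor, then
        # return the tensor together with the inner tensors FSDP2 may free.
        raise NotImplementedError
```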

Important Files Changed

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Implements FSDP2 hooks for Float8Tensor with transpose buffer management and training-state-based usage configuration, but relies on brittle training-state detection that may not work for all use cases (e.g., FP8 params with BF16 compute) |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2/5 | Adds FSDP2 support and torch dispatch handlers for MXFP8Tensor, but contains critical bugs including None pointer dereferences in copy_/split operations, incorrect scale_strides construction in as_strided, and a silent failure mode in the view operation |
| transformer_engine/pytorch/module/base.py | 3/5 | Adds DTensor parameter handling for FSDP2 with proper meta-device materialization and amax reduction group configuration, but the param_init_meta re-initialization guard relies on internal FSDP2 behavior |
| transformer_engine/pytorch/module/linear.py | 4/5 | Detects and extracts quantizers from FSDP2-allgathered QuantizedTensors to avoid redundant configuration, cleanly integrating with the existing quantization flow |
| transformer_engine/pytorch/module/layernorm_linear.py | 4/5 | Adds quantizer extraction for FSDP2-allgathered weights and conditionally skips dgrad computation when the input doesn't require gradients, properly optimizing for FSDP2's sharding patterns |
| transformer_engine/pytorch/module/layernorm_mlp.py | 2/5 | Skips quantizer state initialization for pre-quantized weights but removes the columnwise usage update before the backward pass without clear justification, potentially breaking gradient computation |
| transformer_engine/pytorch/module/grouped_linear.py | 3/5 | Similar quantizer extraction changes as the Linear module for FSDP2 compatibility |
| transformer_engine/pytorch/quantized_tensor.py | 3/5 | Adds recursive list handling for FSDP2's list-based in-place ops and removes the data kwarg from the make_like factory method, a breaking API change that may affect existing callers |
| tests/pytorch/distributed/run_fsdp2_model.py | 4/5 | Comprehensive test harness refactor that validates FSDP2 with multiple layer types and quantization recipes and includes an FP8 all-gather correctness check, but contains a commented-out assertion and a workaround for TransformerLayer compatibility |
| tests/pytorch/distributed/test_torch_fsdp2.py | 4/5 | Extends the test matrix to cover three quantization recipes, providing good coverage for FSDP2 integration validation |
| transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py | 5/5 | Cosmetic whitespace cleanup with no functional changes |

Confidence score: 2/5

  • This PR requires careful review due to multiple critical implementation issues and potential breaking changes in core quantization logic.
  • Score lowered due to: (1) critical bugs in MXFP8Tensor operations (None pointer dereferences, incorrect stride calculation) that will cause runtime failures, (2) brittle assumptions about training state detection in Float8Tensor hooks that don't hold for all TE use cases, (3) breaking API change in make_like removing data kwarg without validation that existing callers are updated, (4) removed logic in LayerNormMLP that updated columnwise usage before backward without clear replacement, and (5) workarounds in test code (TransformerLayer sharding limitation, commented-out assertions) indicating incomplete integration.
  • Pay close attention to transformer_engine/pytorch/tensor/mxfp8_tensor.py (lines 433-436, 337, 402), transformer_engine/pytorch/tensor/float8_tensor.py (lines 736-777 training state detection), transformer_engine/pytorch/module/layernorm_mlp.py (removed lines 541-548), and transformer_engine/pytorch/quantized_tensor.py (lines 504-508 breaking change). All FSDP2 lifecycle hooks and torch dispatch handlers require thorough integration testing with real distributed workloads.

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2
    participant TEModule as TE Module (Linear/LayerNormLinear/LayerNormMLP)
    participant Float8Tensor
    participant MXFP8Tensor
    participant Quantizer
    participant CUDA as CUDA Kernels

    User->>FSDP2: Initialize model with TE modules
    FSDP2->>TEModule: Wrap modules for sharding
    
    Note over User,CUDA: Forward Pass - Weight All-Gather
    
    User->>FSDP2: forward()
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
    Float8Tensor->>Quantizer: Update amax_reduction_group with mesh
    Float8Tensor->>Quantizer: Set usage (rowwise/columnwise) based on training state
    Float8Tensor-->>FSDP2: Return (sharded_tensors, metadata)
    
    FSDP2->>FSDP2: All-gather sharded tensors across ranks
    
    FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
    Float8Tensor->>Float8Tensor: Reconstruct full tensor from gathered data
    Float8Tensor->>Float8Tensor: update_usage() with quantizer settings
    Float8Tensor-->>FSDP2: Return full Float8Tensor
    
    FSDP2->>TEModule: forward(input)
    TEModule->>Quantizer: quantize(input)
    Quantizer->>CUDA: Launch FP8 quantization kernel
    CUDA-->>Quantizer: Quantized input
    TEModule->>CUDA: Execute GEMM (y = xW^T + b)
    CUDA-->>TEModule: Output tensor
    TEModule-->>User: forward output
    
    Note over User,CUDA: Backward Pass - Gradient Computation & Weight Reduce-Scatter
    
    User->>TEModule: backward(grad_output)
    TEModule->>TEModule: Compute dgrad (dx = dy * W)
    TEModule->>TEModule: Compute wgrad (dW = dy^T * x)
    
    TEModule->>FSDP2: Return gradients
    FSDP2->>FSDP2: Reduce-scatter weight gradients
    
    Note over User,CUDA: Parameter Update with FP8 Weights
    
    FSDP2->>Float8Tensor: Update sharded parameters
    Float8Tensor->>TEModule: reset_parameters()
    
    alt Primary weights in FP8
        TEModule->>Quantizer: Configure for DTensor (if FSDP2)
        Quantizer->>Quantizer: Set amax_reduction_group from device_mesh
        TEModule->>Quantizer: quantize(initialized_weight)
        Quantizer->>CUDA: Launch FP8 quantization with amax reduction
        CUDA-->>TEModule: FP8 quantized parameter (DTensor)
    else High precision weights
        TEModule->>TEModule: Initialize in high precision
    end
    
    Note over User,CUDA: Special Handling for DTensors (FSDP2)
    
    alt Weight is DTensor
        TEModule->>Float8Tensor: Access _local_tensor
        Float8Tensor->>Quantizer: Update with local shard
        Quantizer->>CUDA: Quantize local shard
        CUDA-->>Float8Tensor: Quantized local tensor
        Float8Tensor->>Float8Tensor: Wrap as DTensor with placements
        Float8Tensor-->>TEModule: DTensor with FP8 data
    end
    
    Note over User,CUDA: Custom DTensor Operations
    
    FSDP2->>Float8Tensor: split.Tensor / slice.Tensor / as_strided
    Float8Tensor->>Float8Tensor: Apply operation to _data
    Float8Tensor->>Float8Tensor: Compute corresponding transpose operation
    Float8Tensor-->>FSDP2: New Float8Tensor with updated shapes
    
    FSDP2->>Float8Tensor: copy_(src)
    alt Both tensors are Float8Tensor
        Float8Tensor->>Float8Tensor: Copy _data, _scale_inv, _transpose
    else Mixed tensor types
        Float8Tensor->>Float8Tensor: Dequantize src, then copy
    end

11 files reviewed, 10 comments


@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review. The PR continues implementing FSDP2 support for Transformer Engine by addressing code formatting, documentation, and type reference updates. The changes include reformatting an amax_reduction_group assignment in base.py to comply with the 100-character line limit, updating type checks in grouped_linear.py from QuantizedTensor to QuantizedTensorStorage, and adding comprehensive docstrings to the FSDP2 integration methods (fsdp_pre_all_gather and fsdp_post_all_gather) in both float8_tensor.py and mxfp8_tensor.py. Additionally, minor stylistic improvements were made, such as replacing `or` conditions with `in`-tuple checks. The modifications primarily enhance code readability and documentation while aligning the codebase with FSDP2's custom tensor interface requirements.

Important Files Changed

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/module/base.py | 5/5 | Reformatted the amax_reduction_group assignment for DTensor handling, breaking a ternary expression across multiple lines to meet style guidelines |
| transformer_engine/pytorch/module/grouped_linear.py | 4/5 | Updated type checks from QuantizedTensor to QuantizedTensorStorage in weight quantization logic |
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Added docstrings for FSDP2 methods, refactored conditional logic, and restructured fsdp_post_all_gather control flow |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Added comprehensive documentation for FSDP2 integration methods and improved conditional expression readability |

Confidence score: 3/5

  • This PR contains primarily documentation and formatting changes with moderate risk due to control flow restructuring in float8_tensor.py
  • Score reflects concerns about the control-flow change in fsdp_post_all_gather (lines 813-833 in float8_tensor.py), which replaces an early-return pattern with if/else. The original early return may have intentionally returned None instead of the tuple, whereas the new structure always returns (out, (data,)), which could change FSDP2 behavior if the earlier behavior was intentional.
  • Pay close attention to transformer_engine/pytorch/tensor/float8_tensor.py, particularly the fsdp_post_all_gather method changes

4 files reviewed, 4 comments


Comment on lines +813 to +817
if out is not None:
    out.update_usage(
        rowwise_usage=quantizer.rowwise_usage,
        columnwise_usage=quantizer.columnwise_usage,
    )

logic: logic flow changed: previously the code returned early when out exists; now both branches call out.update_usage() - verify this doesn't cause duplicate updates or side effects. Is the out.update_usage() call at lines 814-817 guaranteed to be idempotent, or could calling it twice cause issues?

Comment on lines +818 to +833
else:
    fp8_args = {
        "shape": orig_shape,
        "dtype": param_dtype,
        "fp8_scale_inv": fp8_scale_inv,
        "fp8_dtype": fp8_dtype,
        "quantizer": quantizer,
        "requires_grad": False,
        "data": data,
    }
    out = Float8Tensor(**fp8_args)
    out.update_usage(
        rowwise_usage=quantizer.rowwise_usage,
        columnwise_usage=quantizer.columnwise_usage,
    )
return out, (data,)

logic: indent level changed from top-level to else block - now both branches return the same tuple (out, (data,)) instead of potentially returning None in the early-return case. What was the original behavior when out was not None? Did it return None or the same tuple?

Comment on lines +532 to +534
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride())

style: typo in docs: 'orig_size' description ends with '.(For us same as self.shape)' - should have space after period

Signed-off-by: Varun Thumbe <[email protected]>
Add default usage handling for rowwise and columnwise data.

Signed-off-by: vthumbe1503 <[email protected]>
@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds FSDP2 (Fully Sharded Data Parallel 2) support for Transformer Engine's quantized tensors (Float8 and MXFP8). The implementation enables distributed training with FP8/MXFP8 weights by implementing FSDP2's pre/post all-gather hooks that handle sharding and reconstruction of quantized tensors.

Key Changes:

  • Added fsdp_pre_all_gather and fsdp_post_all_gather methods to Float8Tensor and MXFP8Tensor for FSDP2 integration
  • Enhanced __torch_dispatch__ in Float8Tensor to properly handle transpose caching for view, split, new_zeros, and as_strided operations
  • Fixed bug where transpose cache wasn't being reshaped during view/reshape operations
  • Updated grouped_linear.py and layernorm_linear.py to handle pre-quantized weights from FSDP2 all-gather
  • Added amax reduction across the FSDP mesh for Float8CurrentScalingQuantizer to ensure consistent scaling (see the sketch after this list)
  • Moved columnwise weight usage updates to backward pass where they're actually needed
  • Comprehensive test coverage with multiple recipes (delayed_scaling, current_scaling, mx_fp8_block_scaling)
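A compact sketch of the amax-reduction wiring referenced in the list above; the quantizer attribute names follow what this thread reports and should be treated as assumptions rather than the exact TE API:

```python
from torch.distributed.device_mesh import DeviceMesh
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

def configure_amax_reduction(quantizer, mesh: DeviceMesh) -> None:
    """Point current-scaling amax reduction at the FSDP mesh's process group
    so every shard agrees on a single scale before quantizing (sketch)."""
    if isinstance(quantizer, Float8CurrentScalingQuantizer):
        quantizer.with_amax_reduction = True              # assumed attribute name
        quantizer.amax_reduction_group = mesh.get_group()  # FSDP shard process group
```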

Confidence Score: 4/5

  • This PR is largely safe to merge with comprehensive testing, though edge cases around FSDP2 state management should be monitored
  • The implementation is well-structured with proper FSDP2 integration patterns, and the transpose-caching bug fix is important. However, the complexity of managing quantizer state across FSDP2 operations (rowwise/columnwise usage, amax reduction groups) introduces some risk. The test coverage is excellent, covering multiple recipes and configurations.
  • Pay attention to transformer_engine/pytorch/tensor/float8_tensor.py for the complex FSDP2 state management and transpose handling logic

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Added FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather methods; enhanced __torch_dispatch__ for view/split/as_strided ops to properly handle transpose caching |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5/5 | Added FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather methods for MXFP8 tensors; handles rowwise/columnwise data based on forward/backward pass |
| transformer_engine/pytorch/module/grouped_linear.py | 4/5 | Updated quantizer handling for pre-quantized (FSDP2 all-gathered) weights; moved columnwise weight usage updates from forward to backward pass where needed |

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2 Framework
    participant Module as TE Module
    participant Tensor as Float8Tensor/MXFP8Tensor
    participant Quantizer as Quantizer
    
    Note over FSDP2,Quantizer: Forward Pass - Reshard After Forward Enabled
    
    FSDP2->>Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
    Tensor->>Tensor: Copy quantizer, set rowwise usage
    Tensor->>Quantizer: Set amax_reduction_group for current scaling
    Tensor-->>FSDP2: Return (sharded_data, metadata)
    
    FSDP2->>FSDP2: All-gather sharded data across mesh
    
    FSDP2->>Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
    Tensor->>Tensor: Reconstruct full Float8Tensor with rowwise data
    Tensor->>Tensor: update_usage(rowwise=True, columnwise=False)
    Tensor-->>FSDP2: Return (full_tensor, internal_tensors)
    
    FSDP2->>Module: forward(input)
    Module->>Tensor: Use rowwise quantized weights
    Module-->>FSDP2: forward output
    
    FSDP2->>FSDP2: Reshard weights after forward
    
    Note over FSDP2,Quantizer: Backward Pass
    
    FSDP2->>Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
    Tensor->>Tensor: Copy quantizer, set columnwise usage
    Tensor-->>FSDP2: Return (sharded_data_transpose, metadata)
    
    FSDP2->>FSDP2: All-gather sharded transpose data
    
    FSDP2->>Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
    Tensor->>Tensor: Reconstruct Float8Tensor with columnwise data
    Tensor->>Tensor: update_usage(rowwise=False, columnwise=True)
    Tensor-->>FSDP2: Return (full_tensor, internal_tensors)
    
    FSDP2->>Module: backward(grad_output)
    Module->>Tensor: Use columnwise quantized weights for dgrad
    Module->>Tensor: update_usage(columnwise=True) on weights
    Module-->>FSDP2: gradients
    
    FSDP2->>FSDP2: Reduce-scatter gradients and update shards

4 files reviewed, 1 comment


quantizer = tensor._quantizer.copy()
out_tensor = Float8Tensor(
    data=func_out,
    shape=data.shape,

logic: should use func_out.shape instead of data.shape since new_zeros creates a tensor with a different shape specified in args[1]

Suggested change
-    shape=data.shape,
+    shape=func_out.shape,

