[PyTorch] FSDP2 Support for TE #2245
base: main
Conversation
# Detect if we're within fp8_autocast scope. We'll be in
# forward pass when within the fp8_autocast scope. Backward
# pass when outside of the fp8_autocast scope.
is_in_fp8_autocast = FP8GlobalStateManager.is_fp8_enabled()
Assuming that the forward pass is within an autocast and that the backward pass is outside is quite tenuous. I don't see a logical reason why this has to be true. For example, te.ops.Linear supports the case with FP8 params and non-FP8 compute.
Fixed now. Using the FSDP module's training state to decide whether we are in the forward pass.
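For reference, a minimal sketch of that kind of check, assuming the pre-all-gather hook can reach the owning FSDP2 module; the `TrainingState` import path and the private state accessor below are assumptions, not the PR's actual code:

```python
# Hypothetical sketch: infer forward vs. backward from the FSDP2 module's
# training state rather than from fp8_autocast being active.
# NOTE: the import path and the private accessor are assumptions.
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState

def in_forward_pass(fsdp_module) -> bool:
    # FSDP2 tracks the current phase on the module's FSDP state object;
    # during the backward all-gather it is PRE_BACKWARD, not FORWARD.
    state = fsdp_module._get_fsdp_state()._training_state
    return state == TrainingState.FORWARD
```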
# fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
fp8_recipe = MXFP8BlockScaling(fp8_format=fp8_format)
The quantization recipe should be configurable so that we can test all of them.
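One possible way to make that selectable from the test script, using the recipes already referenced in the snippet above; the CLI flag name and the default format are illustrative, not the PR's actual interface:

```python
# Illustrative sketch: pick the recipe from a command-line flag so the test
# can sweep delayed scaling, current scaling, and MXFP8 block scaling.
import argparse
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
    Format,
)

def build_recipe(name, fp8_format=Format.HYBRID):
    if name == "delayed_scaling":
        return DelayedScaling(
            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
        )
    if name == "current_scaling":
        return Float8CurrentScaling(fp8_format=fp8_format)
    if name == "mx_fp8_block_scaling":
        return MXFP8BlockScaling(fp8_format=fp8_format)
    raise ValueError(f"unknown recipe: {name}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--recipe",
    default="delayed_scaling",
    choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
)
fp8_recipe = build_recipe(parser.parse_args().recipe)
```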
for (scale_inv, split_size) in zip(scale_invs, split_sizes_for_scale):
    scale_inv_out = scale_inv.__torch_dispatch__(
        func,
        types,
        [scale_inv, split_size] + list(args[2:]),
        kwargs,
    )
    out_data.append(scale_inv_out)
We should take the scale-inv padding into account. Alternatively, we could fall back to the dequantize-requantize impl if the dims are not multiples of 512.
Shouldn't the required multiple be 128 instead of 512? I am already checking that at the moment, and when the dims are not multiples of 128 it does fall back to the dequantize-requantize path.
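A rough sketch of that fallback pattern, with the alignment treated as a constant; the helper `split_quantized_buffers` and the exact quantizer call are placeholders, not the PR's code:

```python
# Sketch of the fallback discussed above: take the fast path only when the
# split sizes respect the scaling-block alignment, otherwise dequantize,
# split in high precision, and requantize each chunk.
ALIGNMENT = 128  # the multiple under discussion (128 vs. 512)

def split_with_fallback(tensor, split_sizes, dim=0):
    if all(size % ALIGNMENT == 0 for size in split_sizes):
        # Fast path: split the quantized data and scale-inv buffers directly
        # (placeholder helper standing in for the dispatch-based handling).
        return split_quantized_buffers(tensor, split_sizes, dim)
    # Slow path: dequantize -> split -> requantize with the tensor's quantizer.
    chunks = tensor.dequantize().split(split_sizes, dim=dim)
    return [tensor._quantizer(chunk) for chunk in chunks]
```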
Greptile Overview
Greptile Summary
This PR enables FSDP2 (Fully Sharded Data Parallel v2) support for Transformer Engine's quantized tensor types (Float8Tensor, MXFP8Tensor) and modules (Linear, LayerNormLinear, LayerNormMLP, TransformerLayer). The integration allows TE to work with PyTorch's DTensor-based sharding by implementing custom FSDP2 lifecycle hooks (fsdp_pre_all_gather, fsdp_post_all_gather) that handle gathering/sharding of quantized weights with proper transpose cache management. Key architectural changes include detecting when weights are already quantized after FSDP2 all-gather and extracting their embedded quantizers rather than reconfiguring them, handling DTensor parameters during deferred initialization from meta device, and conditionally skipping dgrad computation when inputs don't require gradients. The PR also extends test coverage to validate all three quantization recipes (delayed_scaling, current_scaling, mx_fp8_block_scaling) and adds comprehensive integration tests in run_fsdp2_model.py.
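For orientation, the FSDP2 tensor-subclass contract that these hooks plug into looks roughly like the skeleton below; signatures are simplified (the real pre-all-gather hook also receives the original size/stride, owning module, and mixed-precision policy), and `rebuild_quantized_tensor` is a stand-in for TE's actual reconstruction logic:

```python
import torch

class QuantizedParamSketch(torch.Tensor):
    """Skeleton of the FSDP2 extension points named in the summary above."""

    def fsdp_pre_all_gather(self, mesh, *args, **kwargs):
        # Hand FSDP2 the raw shard(s) to all-gather plus any metadata needed
        # to rebuild the unsharded tensor afterwards (scales, quantizer, ...).
        return (self._data,), (self._scale_inv, self._quantizer)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (data,) = all_gather_outputs
        scale_inv, quantizer = metadata
        if out is not None:
            # FSDP2 is reusing a previously built unsharded tensor; refresh
            # it in place and return nothing new.
            return None
        # Otherwise build the full quantized tensor and also return the inner
        # tensors FSDP2 must keep alive until the next reshard.
        full = rebuild_quantized_tensor(data, scale_inv, quantizer, param_dtype)
        return full, (data,)
```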
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Implements FSDP2 hooks for Float8Tensor with transpose buffer management and training-state-based usage configuration, but relies on brittle training state detection that may not work for all use cases (e.g., FP8 params with BF16 compute) |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2/5 | Adds FSDP2 support and torch dispatch handlers for MXFP8Tensor, but contains critical bugs including None pointer dereferences in copy_/split operations, incorrect scale_strides construction in as_strided, and silent failure mode in view operation |
| transformer_engine/pytorch/module/base.py | 3/5 | Adds DTensor parameter handling for FSDP2 with proper meta-device materialization and amax reduction group configuration, but the param_init_meta re-initialization guard relies on internal FSDP2 behavior |
| transformer_engine/pytorch/module/linear.py | 4/5 | Detects and extracts quantizers from FSDP2-allgathered QuantizedTensors to avoid redundant configuration, cleanly integrating with existing quantization flow |
| transformer_engine/pytorch/module/layernorm_linear.py | 4/5 | Adds quantizer extraction for FSDP2-allgathered weights and conditionally skips dgrad computation when input doesn't require gradients, properly optimizing for FSDP2's sharding patterns |
| transformer_engine/pytorch/module/layernorm_mlp.py | 2/5 | Skips quantizer state initialization for pre-quantized weights but removes columnwise usage update before backward pass without clear justification, potentially breaking gradient computation |
| transformer_engine/pytorch/module/grouped_linear.py | 3/5 | Similar quantizer extraction changes as Linear module for FSDP2 compatibility |
| transformer_engine/pytorch/quantized_tensor.py | 3/5 | Adds recursive list handling for FSDP2's list-based in-place ops and removes data kwarg from make_like factory method, which is a breaking API change that may affect existing callers |
| tests/pytorch/distributed/run_fsdp2_model.py | 4/5 | Comprehensive test harness refactor that validates FSDP2 with multiple layer types and quantization recipes, includes FP8 all-gather correctness check, but contains a commented-out assertion and workaround for TransformerLayer compatibility |
| tests/pytorch/distributed/test_torch_fsdp2.py | 4/5 | Extends test matrix to cover three quantization recipes, providing good coverage for FSDP2 integration validation |
| transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py | 5/5 | Cosmetic whitespace cleanup with no functional changes |
Confidence score: 2/5
- This PR requires careful review due to multiple critical implementation issues and potential breaking changes in core quantization logic.
- Score lowered due to: (1) critical bugs in MXFP8Tensor operations (None pointer dereferences, incorrect stride calculation) that will cause runtime failures, (2) brittle assumptions about training state detection in Float8Tensor hooks that don't hold for all TE use cases, (3) breaking API change in make_like removing data kwarg without validation that existing callers are updated, (4) removed logic in LayerNormMLP that updated columnwise usage before backward without clear replacement, and (5) workarounds in test code (TransformerLayer sharding limitation, commented-out assertions) indicating incomplete integration.
- Pay close attention to transformer_engine/pytorch/tensor/mxfp8_tensor.py (lines 433-436, 337, 402), transformer_engine/pytorch/tensor/float8_tensor.py (lines 736-777 training state detection), transformer_engine/pytorch/module/layernorm_mlp.py (removed lines 541-548), and transformer_engine/pytorch/quantized_tensor.py (lines 504-508 breaking change). All FSDP2 lifecycle hooks and torch dispatch handlers require thorough integration testing with real distributed workloads.
Sequence Diagram
sequenceDiagram
participant User
participant FSDP2
participant TEModule as TE Module (Linear/LayerNormLinear/LayerNormMLP)
participant Float8Tensor
participant MXFP8Tensor
participant Quantizer
participant CUDA as CUDA Kernels
User->>FSDP2: Initialize model with TE modules
FSDP2->>TEModule: Wrap modules for sharding
Note over User,CUDA: Forward Pass - Weight All-Gather
User->>FSDP2: forward()
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
Float8Tensor->>Quantizer: Update amax_reduction_group with mesh
Float8Tensor->>Quantizer: Set usage (rowwise/columnwise) based on training state
Float8Tensor-->>FSDP2: Return (sharded_tensors, metadata)
FSDP2->>FSDP2: All-gather sharded tensors across ranks
FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
Float8Tensor->>Float8Tensor: Reconstruct full tensor from gathered data
Float8Tensor->>Float8Tensor: update_usage() with quantizer settings
Float8Tensor-->>FSDP2: Return full Float8Tensor
FSDP2->>TEModule: forward(input)
TEModule->>Quantizer: quantize(input)
Quantizer->>CUDA: Launch FP8 quantization kernel
CUDA-->>Quantizer: Quantized input
TEModule->>CUDA: Execute GEMM (y = xW^T + b)
CUDA-->>TEModule: Output tensor
TEModule-->>User: forward output
Note over User,CUDA: Backward Pass - Gradient Computation & Weight Reduce-Scatter
User->>TEModule: backward(grad_output)
TEModule->>TEModule: Compute dgrad (dx = dy * W)
TEModule->>TEModule: Compute wgrad (dW = dy^T * x)
TEModule->>FSDP2: Return gradients
FSDP2->>FSDP2: Reduce-scatter weight gradients
Note over User,CUDA: Parameter Update with FP8 Weights
FSDP2->>Float8Tensor: Update sharded parameters
Float8Tensor->>TEModule: reset_parameters()
alt Primary weights in FP8
TEModule->>Quantizer: Configure for DTensor (if FSDP2)
Quantizer->>Quantizer: Set amax_reduction_group from device_mesh
TEModule->>Quantizer: quantize(initialized_weight)
Quantizer->>CUDA: Launch FP8 quantization with amax reduction
CUDA-->>TEModule: FP8 quantized parameter (DTensor)
else High precision weights
TEModule->>TEModule: Initialize in high precision
end
Note over User,CUDA: Special Handling for DTensors (FSDP2)
alt Weight is DTensor
TEModule->>Float8Tensor: Access _local_tensor
Float8Tensor->>Quantizer: Update with local shard
Quantizer->>CUDA: Quantize local shard
CUDA-->>Float8Tensor: Quantized local tensor
Float8Tensor->>Float8Tensor: Wrap as DTensor with placements
Float8Tensor-->>TEModule: DTensor with FP8 data
end
Note over User,CUDA: Custom DTensor Operations
FSDP2->>Float8Tensor: split.Tensor / slice.Tensor / as_strided
Float8Tensor->>Float8Tensor: Apply operation to _data
Float8Tensor->>Float8Tensor: Compute corresponding transpose operation
Float8Tensor-->>FSDP2: New Float8Tensor with updated shapes
FSDP2->>Float8Tensor: copy_(src)
alt Both tensors are Float8Tensor
Float8Tensor->>Float8Tensor: Copy _data, _scale_inv, _transpose
else Mixed tensor types
Float8Tensor->>Float8Tensor: Dequantize src, then copy
end
11 files reviewed, 10 comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review. The PR continues implementing FSDP2 support for TransformerEngine by addressing code formatting, documentation, and type reference updates. The changes include reformatting an `amax_reduction_group` assignment in base.py to comply with the 100-character line limit, updating type checks in grouped_linear.py from `QuantizedTensor` to `QuantizedTensorStorage`, and adding comprehensive docstrings to the FSDP2 integration methods (`fsdp_pre_all_gather` and `fsdp_post_all_gather`) in both float8_tensor.py and mxfp8_tensor.py. Additionally, minor stylistic improvements were made, such as replacing `or` conditions with `in` tuple checks. The modifications primarily enhance code readability and documentation while aligning the codebase with FSDP2's custom tensor interface requirements.
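For example, the stylistic change referred to is of this form (values are illustrative):

```python
# Values here are illustrative; only the shape of the check matters.
recipe_name = "delayed_scaling"

# Before: chained equality comparisons
if recipe_name == "delayed_scaling" or recipe_name == "current_scaling":
    uses_fp8_params = True

# After: a single membership test against a tuple
if recipe_name in ("delayed_scaling", "current_scaling"):
    uses_fp8_params = True
```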
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/base.py | 5/5 | Reformatted amax_reduction_group assignment for DTensor handling, breaking a ternary expression across multiple lines to meet style guidelines |
| transformer_engine/pytorch/module/grouped_linear.py | 4/5 | Updated type checks from QuantizedTensor to QuantizedTensorStorage in weight quantization logic |
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Added docstrings for FSDP2 methods, refactored conditional logic, and restructured fsdp_post_all_gather control flow |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Added comprehensive documentation for FSDP2 integration methods and improved conditional expression readability |
Confidence score: 3/5
- This PR contains primarily documentation and formatting changes, with moderate risk due to control flow restructuring in `float8_tensor.py`
- Score reflects concerns about the control flow change in `fsdp_post_all_gather` (lines 813-833 in float8_tensor.py), which replaces an early-return pattern with if/else; the original early return may have intentionally returned None instead of the tuple, and the new structure always returns `(out, (data,))`, potentially changing FSDP2 behavior if that was intentional
- Pay close attention to transformer_engine/pytorch/tensor/float8_tensor.py, particularly the `fsdp_post_all_gather` method changes
4 files reviewed, 4 comments
if out is not None:
    out.update_usage(
        rowwise_usage=quantizer.rowwise_usage,
        columnwise_usage=quantizer.columnwise_usage,
    )
logic: control flow changed: previously the function returned early when `out` exists, but now both branches call `out.update_usage()`. Verify this doesn't cause duplicate updates or side effects. Is the `out.update_usage()` call at lines 814-817 guaranteed to be idempotent, or could calling it twice cause issues?
else:
    fp8_args = {
        "shape": orig_shape,
        "dtype": param_dtype,
        "fp8_scale_inv": fp8_scale_inv,
        "fp8_dtype": fp8_dtype,
        "quantizer": quantizer,
        "requires_grad": False,
        "data": data,
    }
    out = Float8Tensor(**fp8_args)
    out.update_usage(
        rowwise_usage=quantizer.rowwise_usage,
        columnwise_usage=quantizer.columnwise_usage,
    )
return out, (data,)
logic: indent level changed from top level to the else block; now both branches return the same tuple `(out, (data,))` instead of potentially returning None in the early-return case. What was the original behavior when `out` was not None? Did it return None or the same tuple?
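To make the question concrete, a toy illustration (not TE code) of the two control-flow shapes being compared, following the reviewer's description of the before/after behavior:

```python
class _Out:
    """Stand-in for the all-gathered Float8Tensor."""
    def update_usage(self, **usage):
        self.usage = usage  # assumed idempotent for this illustration

def post_all_gather_before(data, out=None):
    # Early-return shape: when FSDP2 supplies `out`, no tuple is returned.
    if out is not None:
        out.update_usage(rowwise_usage=True, columnwise_usage=False)
        return None
    out = _Out()
    out.update_usage(rowwise_usage=True, columnwise_usage=False)
    return out, (data,)

def post_all_gather_after(data, out=None):
    # If/else shape: both branches fall through to the same return value.
    if out is not None:
        out.update_usage(rowwise_usage=True, columnwise_usage=False)
    else:
        out = _Out()
        out.update_usage(rowwise_usage=True, columnwise_usage=False)
    return out, (data,)
```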
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
    (For us same as self.stride())
style: typo in docs: 'orig_size' description ends with '.(For us same as self.shape)' - should have space after period
Greptile Overview
Greptile Summary
This PR adds FSDP2 (Fully Sharded Data Parallel 2) support for Transformer Engine's quantized tensors (Float8 and MXFP8). The implementation enables distributed training with FP8/MXFP8 weights by implementing FSDP2's pre/post all-gather hooks that handle sharding and reconstruction of quantized tensors.
Key Changes:
- Added `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods to `Float8Tensor` and `MXFP8Tensor` for FSDP2 integration
- Enhanced `__torch_dispatch__` in `Float8Tensor` to properly handle transpose caching for view, split, new_zeros, and as_strided operations
- Fixed bug where transpose cache wasn't being reshaped during view/reshape operations
- Updated `grouped_linear.py` and `layernorm_linear.py` to handle pre-quantized weights from FSDP2 all-gather
- Added amax reduction across the FSDP mesh for `Float8CurrentScalingQuantizer` to ensure consistent scaling (see the sketch after this list)
- Moved columnwise weight usage updates to the backward pass, where they're actually needed
- Comprehensive test coverage with multiple recipes (delayed_scaling, current_scaling, mx_fp8_block_scaling)
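As noted in the list above, a hedged sketch of wiring the current-scaling quantizer's amax reduction to the FSDP mesh inside the pre-all-gather hook; the attribute name follows this summary and the quoted snippets, not a verified copy of the PR:

```python
from torch.distributed.device_mesh import DeviceMesh

def configure_amax_reduction(quantizer, mesh: DeviceMesh):
    # Copy so the module's own quantizer is left untouched, then point the
    # amax all-reduce at the process group backing the FSDP shard mesh so
    # every rank ends up with the same scale.
    quantizer = quantizer.copy()
    quantizer.amax_reduction_group = mesh.get_group()
    return quantizer
```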
Confidence Score: 4/5
- This PR is largely safe to merge with comprehensive testing, though edge cases around FSDP2 state management should be monitored
- The implementation is well-structured with proper FSDP2 integration patterns. The transpose caching bug fix is important. However, the complexity of managing quantizer state across FSDP2 operations (rowwise/columnwise usage, amax reduction groups) introduces some risk. The test coverage is excellent covering multiple recipes and configurations.
- Pay attention to `transformer_engine/pytorch/tensor/float8_tensor.py` for the complex FSDP2 state management and transpose handling logic
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Added FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather methods, enhanced torch_dispatch for view/split/as_strided ops to properly handle transpose caching |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5/5 | Added FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather methods for MXFP8 tensors, handles rowwise/columnwise data based on forward/backward pass |
| transformer_engine/pytorch/module/grouped_linear.py | 4/5 | Updated quantizer handling for pre-quantized weights (FSDP2 all-gathered weights), moved columnwise weight usage updates from forward to backward pass where needed |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2 Framework
participant Module as TE Module
participant Tensor as Float8Tensor/MXFP8Tensor
participant Quantizer as Quantizer
Note over FSDP2,Quantizer: Forward Pass - Reshard After Forward Enabled
FSDP2->>Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
Tensor->>Tensor: Copy quantizer, set rowwise usage
Tensor->>Quantizer: Set amax_reduction_group for current scaling
Tensor-->>FSDP2: Return (sharded_data, metadata)
FSDP2->>FSDP2: All-gather sharded data across mesh
FSDP2->>Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
Tensor->>Tensor: Reconstruct full Float8Tensor with rowwise data
Tensor->>Tensor: update_usage(rowwise=True, columnwise=False)
Tensor-->>FSDP2: Return (full_tensor, internal_tensors)
FSDP2->>Module: forward(input)
Module->>Tensor: Use rowwise quantized weights
Module-->>FSDP2: forward output
FSDP2->>FSDP2: Reshard weights after forward
Note over FSDP2,Quantizer: Backward Pass
FSDP2->>Tensor: fsdp_pre_all_gather(mesh, orig_size, ...)
Tensor->>Tensor: Copy quantizer, set columnwise usage
Tensor-->>FSDP2: Return (sharded_data_transpose, metadata)
FSDP2->>FSDP2: All-gather sharded transpose data
FSDP2->>Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, ...)
Tensor->>Tensor: Reconstruct Float8Tensor with columnwise data
Tensor->>Tensor: update_usage(rowwise=False, columnwise=True)
Tensor-->>FSDP2: Return (full_tensor, internal_tensors)
FSDP2->>Module: backward(grad_output)
Module->>Tensor: Use columnwise quantized weights for dgrad
Module->>Tensor: update_usage(columnwise=True) on weights
Module-->>FSDP2: gradients
FSDP2->>FSDP2: Reduce-scatter gradients and update shards
4 files reviewed, 1 comment
quantizer = tensor._quantizer.copy()
out_tensor = Float8Tensor(
    data=func_out,
    shape=data.shape,
logic: should use `func_out.shape` instead of `data.shape`, since `new_zeros` creates a tensor with a different shape, specified in args[1]
Suggested change: replace `shape=data.shape,` with `shape=func_out.shape,`
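A small standalone illustration of the point (plain PyTorch, no TE types): `new_zeros` takes its shape from the call arguments, so it generally differs from the source tensor's shape.

```python
import torch

data = torch.zeros(4, 8, dtype=torch.uint8)
func_out = data.new_zeros((2, 16))   # shape comes from the new_zeros args,
                                     # i.e. args[1] in the dispatch handler
assert tuple(func_out.shape) == (2, 16)
assert tuple(data.shape) == (4, 8)   # wrapping the result with data.shape would be wrong
```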
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: