From 7f4dfdb0503babf08e9503ca5488d7e8fe54a7ff Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Tue, 25 Feb 2025 02:02:44 +0000
Subject: [PATCH] Support MXFP8 all-gather with only column-wise data

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/distributed.py | 99 ++++++++++++++---------
 1 file changed, 59 insertions(+), 40 deletions(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index fe023208d1..e23035ab8a 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -923,20 +923,27 @@ def _all_gather_mxfp8(
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]
 
-    # Gather MXFP8 data for row-wise usage
-    if quantizer.rowwise_usage and not quantizer.columnwise_usage:
+    # Cast input tensor to MXFP8 with required data
+    if not isinstance(input_, MXFP8TensorBase):
+        input_ = quantizer(input_)
+    elif (
+        input_._rowwise_data is None and quantizer.rowwise_usage
+        or input_._columnwise_data is None and quantizer.columnwise_usage
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to MXFP8."
+        )
+        input_ = quantizer(input_.dequantize())
 
-        # Cast input tensor to MXFP8 if needed
-        if not isinstance(input_, MXFP8TensorBase):
-            input_ = quantizer(input_)
+    # Construct MXFP8 output tensor
+    out = quantizer.make_empty(out_shape, dtype=input_.dtype, device=input_.device)
 
-        # Construct MXFP8 output tensor
-        dtype = torch.float32
-        device = "cuda"
-        if isinstance(input_, MXFP8Tensor):
-            dtype = input_.dtype
-            device = input_.device
-        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+    # Async op handle
+    handle = None
+
+    # Gather MXFP8 data for row-wise usage
+    if quantizer.rowwise_usage:
 
         # Remove padding from MXFP8 scale-inverses
         in_scale_inv = input_._rowwise_scale_inv
@@ -948,36 +955,48 @@ def _all_gather_mxfp8(
             out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
 
         # Launch all-gathers
-        with torch.distributed._coalescing_manager(
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
             group=process_group,
-            device=device,
-            async_ops=async_op,
-        ) as coalescing_manager:
-            torch.distributed.all_gather_into_tensor(
-                out._rowwise_data,
-                input_._rowwise_data,
-                group=process_group,
-            )
-            torch.distributed.all_gather_into_tensor(
-                out_scale_inv,
-                in_scale_inv,
-                group=process_group,
-            )
-        handle = coalescing_manager if async_op else None
-        return out, handle
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._rowwise_data,
+            input_._rowwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
 
-    # Gather in high precision and quantize for column-wise usage
-    if isinstance(input_, QuantizedTensor):
-        input_ = input_.dequantize(dtype=torch.bfloat16)
-    out = torch.empty(
-        out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
-        memory_format=torch.contiguous_format,
-    )
-    torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
-    out = quantizer(out)
-    return out, None
+    # Gather MXFP8 data for column-wise usage
+    if quantizer.columnwise_usage:
+
+        # Remove padding from MXFP8 scale-inverses
+        in_scale_inv = input_._columnwise_scale_inv
+        out_scale_inv = out._columnwise_scale_inv
+        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
+        if in_scale_inv.size(0) != flattened_in_shape0:
+            in_scale_inv = in_scale_inv[:flattened_in_shape0]
+            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
+            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
+
+        # Launch all-gathers
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._columnwise_data,
+            input_._columnwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    return out, handle
 
 
 def gather_along_first_dim(
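-- 
Note for reviewers: the patch replaces the coalescing manager with chained
torch.distributed.all_gather_into_tensor calls. Each block waits on the previous
async handle before launching its own gathers, so the gathers are serialized and
only the final handle is returned to the caller. Below is a minimal standalone
sketch of that handle-chaining pattern; chained_all_gather is a hypothetical
helper for illustration (not Transformer Engine API), and it assumes a process
group has already been initialized, e.g. via torchrun.

    import torch
    import torch.distributed as dist

    def chained_all_gather(tensors, group=None, async_op=False):
        # Gather each tensor along dim 0, waiting on the previous
        # async handle before launching the next gather. Only the
        # last handle is returned, mirroring _all_gather_mxfp8: with
        # async_op=True the caller waits on that handle before
        # reading any output.
        world_size = dist.get_world_size(group)
        outs, handle = [], None
        for t in tensors:
            if handle is not None:
                handle.wait()
            out = torch.empty(
                (t.size(0) * world_size, *t.shape[1:]),
                dtype=t.dtype,
                device=t.device,
            )
            handle = dist.all_gather_into_tensor(
                out, t, group=group, async_op=async_op
            )
            outs.append(out)
        return outs, handle

Dropping the private _coalescing_manager trades some overlap between the data
and scale-inverse gathers for a code path that works when only one of the
row-wise/column-wise buffers exists. The scale-inverse handling follows from
MXFP8 storing one scale per 32-element block: each rank contributes
math.prod(in_shape[:-1]) // 32 column-wise scale rows, so padding rows are
trimmed before the gather and zeroed in the padded region of the output.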