
enable tensor parallelism for MXLinear #2434

Merged (21 commits, Jun 24, 2025)
12 changes: 5 additions & 7 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -68,24 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
)


def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
# TODO(future PR): assert that the K dim must be divisible by block size,
# today this is silently incorrect if block_size is greater than K
config.block_size = 16
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)

# TODO(future PR): compile
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=True, allgather_in_lowp=False
)


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
# TODO(next PR): enable this (current PR got too large, so splitting)
# _test_mxfp8_mlp_tensor_parallelism_eager,
_test_mxfp8_mlp_tensor_parallelism,
]

for test in tqdm(tests, desc="Running tests"):
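For readers unfamiliar with the test helper, here is a minimal, self-contained sketch of the tensor-parallel MLP wiring that `_test_lowp_mlp_tensor_parallelism_base` presumably exercises, using the public `torch.distributed.tensor.parallel` API. The module names, sizes, and sharding plan below are illustrative assumptions, not torchao code; launch with `torchrun --nproc_per_node=N`.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    # Hypothetical two-layer MLP; only the sharding plan matters for this sketch.
    def __init__(self, dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    model = ToyMLP(dim=64).cuda()
    # Megatron-style plan: shard w1 column-wise and w2 row-wise.
    parallelize_module(model, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    x = torch.randn(16, 64, device="cuda")
    model(x).sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```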
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
# TODO(future): enable compile support
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 8)
input_shape = (16, 4)
Contributor Author: this was broken before, caught by enforcing that inner dim is divisible by block size (see the short divisibility sketch after this file's diff)

grad_shape = (16, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
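Related to the review note above and the shape assertion added to `to_mx` later in this PR: a small standalone check (hypothetical shapes and block size) showing why the inner dimension must be divisible by the MX block size, since each group of `block_size` contiguous elements along the last dim shares one scale.

```python
import torch

block_size = 32  # hypothetical; torchao configs choose their own value


def check_blockable(x: torch.Tensor, block_size: int) -> None:
    # Mirrors the check added in to_mx: the last dim is split into groups of
    # block_size elements, with one shared scale per group.
    assert x.shape[-1] % block_size == 0, (
        f"last dim of {tuple(x.shape)} must be divisible by block_size {block_size}"
    )
    blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
    print("block view:", tuple(blocks.shape))


check_blockable(torch.randn(16, 64), block_size)  # ok: 64 % 32 == 0
# check_blockable(torch.randn(16, 20), block_size)  # would raise: 20 % 32 != 0
```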
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
block_size = 4
_test_mx(data, elem_dtype, block_size)

8 changes: 5 additions & 3 deletions torchao/prototype/mx_formats/kernels.py
@@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:

# effective mx block size since we're packing 2 fp4 into 1 uint8
packed_mx_block_size = 3 * mx_block_size // 4
packed_shape = [uint8_data.shape[0], packed_mx_block_size]
packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
n_mx_blocks = uint8_data.numel() // mx_block_size

grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
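A quick shape-only illustration (hypothetical sizes) of the `packed_shape` change above: building the shape from `*uint8_data.shape[:-1]` preserves any leading dimensions of a rank-3 input, where the previous `shape[0]`-only form dropped the middle dimension.

```python
# Hypothetical rank-3 uint8 layout: (n_groups, n_blocks, mx_block_size).
shape = (8, 16, 32)
mx_block_size = shape[-1]
packed_mx_block_size = 3 * mx_block_size // 4  # 24

old_packed_shape = [shape[0], packed_mx_block_size]     # [8, 24]  -- middle dim lost
new_packed_shape = [*shape[:-1], packed_mx_block_size]  # [8, 16, 24]
print(old_packed_shape, new_packed_shape)
```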
@@ -1337,7 +1337,9 @@ def triton_to_mxfp8_dim1(

# Create scale tensors
col_scale = torch.empty(
(n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
(n_cols, n_rows // inner_block_size, 1),
dtype=torch.uint8,
device=x.device,
)

# Calculate grid dimensions based on tile size
Expand Down Expand Up @@ -1374,7 +1376,7 @@ def triton_to_mxfp8_dim1_reference(
scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
x_hp_d1, torch.float8_e4m3fn, block_size
)
scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
return (
x_hp_d1_normalized.t(),
scale_e8m0_dim1,
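A shape-only sketch (plain PyTorch with hypothetical sizes, not the Triton kernel) of the per-block scale layout for the dim1 cast after this change: the scale keeps its `(n_cols, n_rows // block_size, 1)` block structure, matching the new `col_scale` allocation and the reference path that no longer needs an `unsqueeze`.

```python
import torch

# Hypothetical sizes; block_size must divide n_rows, the dimension scaled by the dim1 cast.
n_rows, n_cols, block_size = 128, 64, 32

x = torch.randn(n_rows, n_cols, dtype=torch.bfloat16)

# dim1 cast: transpose so the scaled dimension is last, then view it as blocks.
x_d1 = x.t().contiguous().reshape(n_cols, n_rows // block_size, block_size)

# One max-abs (and hence one scale) per block, keeping the block structure
# instead of flattening to a 1-D scale tensor.
max_abs = x_d1.abs().amax(dim=-1, keepdim=True)
print(x_d1.shape, max_abs.shape)  # (64, 4, 32) and (64, 4, 1)
```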
50 changes: 25 additions & 25 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -25,7 +25,6 @@

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
BF16_EXP_BIAS,
BLOCK_SIZE_DEFAULT,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
@@ -62,7 +61,6 @@

# TODO(later): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_BF16, MBITS_BF16 = 8, 7
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
)

# scale and saturated cast the data elements to max of target dtype
data_lp = torch.clamp(
data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
)
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
return exponent, data_lp


@@ -160,22 +156,33 @@ def to_mx(
torch.float,
), f"{data_hp.dtype} is not supported yet"
# TODO(future PR): consider supporting padding
assert data_hp.numel() % block_size == 0, "unsupported"
assert data_hp.shape[-1] % block_size == 0, (
f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
)
assert data_hp.is_contiguous(), "unsupported"
assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

# calculate the scale in e8m0 format

orig_shape = data_hp.shape
# TODO(future PR): fix this line for TP, currently this reshape does not work
# for rank 3 tensor where dim1 is sharded
data_hp = data_hp.reshape(-1, block_size)
data_hp = data_hp.reshape(
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
)

# find max value of the data
# Note: this only implements the `minimally supported` version of
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3.
max_abs = torch.amax(torch.abs(data_hp), 1)
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

# We cast to float32 here because
# in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
# if tensor parallel is enabled then the resulting shape is 2x larger
# than it should be under some conditions, likely because of a bug in
# the `view` op with DTensor and target dtype int16. I reproduce in
# torchtitan but not in a unit test, so not enough info to file a good
# issue in pytorch/pytorch. For now, work around. In the future we should
# debug and fix this properly.
data_hp = data_hp.to(torch.float32)
Contributor Author: performance testing showed that with compile on, having this in float32 does not regress performance

max_abs = max_abs.to(torch.float32)

# Set X to be the largest power-of-two less than or equal to
# max_abs(v), divided by the largest power of two representable
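To make the rank-preserving decomposition above concrete, here is a minimal sketch with hypothetical shapes: the last dimension is split into blocks and the max-abs reduction keeps a trailing singleton dimension, so rank-3 sharded activations keep their leading dims and the later `data_hp / scale_fp32` divide broadcasts without an explicit `unsqueeze`.

```python
import torch

block_size = 32
# Hypothetical rank-3 activation, e.g. (batch, seq, hidden), with hidden % block_size == 0.
data_hp = torch.randn(2, 8, 128, dtype=torch.bfloat16)

orig_shape = data_hp.shape
# Rank-preserving: (..., hidden) -> (..., hidden // block_size, block_size)
data_blocks = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)

# One max-abs per block, with a trailing dim of 1 so it broadcasts directly
# against data_blocks in the later scale-and-divide step.
max_abs = torch.amax(torch.abs(data_blocks), -1).unsqueeze(-1)
print(data_blocks.shape, max_abs.shape)  # (2, 8, 4, 32) and (2, 8, 4, 1)
```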
@@ -206,17 +213,11 @@ def to_mx(
if scaling_mode == ScaleCalculationMode.RCEIL:
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
else:
if data_hp.dtype is torch.float32:
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS
else:
assert data_hp.dtype is torch.bfloat16
hp_int_dtype = torch.int16
hp_mbits = MBITS_BF16
hp_ebits = EBITS_BF16
hp_exp_bias = BF16_EXP_BIAS
assert data_hp.dtype is torch.float32
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS

# rounding before calculating the largest power of 2
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
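For intuition on the branch above: the e8m0 scale starts from the biased exponent of each block's max-abs value, read straight from the float bits via an integer `view` of matching width. That is why the code now asserts float32 and pairs it with int32, sidestepping the bfloat16/int16 path and the DTensor `view` issue mentioned earlier. A small sketch of the bit extraction, ignoring the rounding and max-exponent offset that `to_mx` applies afterwards:

```python
import torch

MBITS_F32 = 23  # mantissa bits of float32

# Hypothetical per-block max-abs values, already cast to float32 as in to_mx.
max_abs = torch.tensor([3.0, 0.75, 1024.0], dtype=torch.float32)

# Reinterpret the float32 bits as int32 (same element width), then shift out
# the mantissa to expose the biased exponent used for the e8m0 scale.
bits = max_abs.view(torch.int32)
biased_exp = (bits >> MBITS_F32) & 0xFF
print(biased_exp)  # tensor([128, 126, 137], dtype=torch.int32)
```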
@@ -285,7 +286,7 @@ def to_mx(
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

# scale and saturated cast the data elements to max of target dtype
data_lp = data_hp / scale_fp32.unsqueeze(1)
data_lp = data_hp / scale_fp32

if (
elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -511,7 +512,6 @@ def __new__(
assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
)
assert len(scale_e8m0_bits.shape) == 1, "unsupported"
assert data_bits.dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
11 changes: 7 additions & 4 deletions torchao/testing/training/dtensor_utils.py
@@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base(
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
go_fp32_tp = go_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
tp_out.sum().backward()
tp_out.backward(go_fp32_tp)
Contributor Author: to make sure grad flowing into the last linear is contiguous
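A standalone sketch (hypothetical shapes) of the distinction this comment is about: `out.sum().backward()` feeds an expanded all-ones gradient into the preceding linear, which is typically reported as non-contiguous, while `out.backward(go)` with a freshly allocated `go` keeps the incoming gradient contiguous, matching the `go_fp32*` tensors used here.

```python
import torch
import torch.nn as nn

lin = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

y = lin(x)
# Inspect the gradient that flows back into the linear's output.
y.register_hook(lambda g: print("incoming grad contiguous:", g.is_contiguous()))

# Backward via sum(): the upstream grad is a stride-0 expand of a scalar,
# so it is usually not contiguous.
y.sum().backward(retain_graph=True)

# Backward with an explicit, freshly allocated grad output: contiguous.
y.backward(torch.rand_like(y))
```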

sp_out = sp_model(x_fp32_sp_input)
sp_out.sum().backward()
sp_out.backward(go_fp32_sp)
global_out = toy_model_fp8(x_fp32)
global_out.sum().backward()
global_out.backward(go_fp32)
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
@@ -169,7 +172,7 @@
)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.sum().backward()
sp_out2.backward(go_fp32_sp)
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad