diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py
index bfc930c579..4aefb3874e 100644
--- a/test/prototype/mx_formats/test_mx_dtensor.py
+++ b/test/prototype/mx_formats/test_mx_dtensor.py
@@ -68,24 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     )
 
 
-def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
-    # TODO(future PR): assert that the K dim must be divisible by block size,
-    # today this is silently incorrect if block_size is greater than K
     config.block_size = 16
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
-
-    # TODO(future PR): compile
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=True, allgather_in_lowp=False
+    )
 
 
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
-        # TODO(next PR): enable this (current PR got too large, so splitting)
-        # _test_mxfp8_mlp_tensor_parallelism_eager,
+        _test_mxfp8_mlp_tensor_parallelism,
     ]
 
     for test in tqdm(tests, desc="Running tests"):
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index bfb6742d14..b48b21bbf9 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
 # TODO(future): enable compile support
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
-    input_shape = (2, 4)
-    grad_shape = (2, 8)
+    input_shape = (16, 4)
+    grad_shape = (16, 8)
     elem_dtype = torch.float8_e4m3fn
 
     m = nn.Sequential(
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 6dfd33f9c7..f0124dd47b 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index f96e73a55a..72cbba1802 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
     # effective mx block size since we're packing 2 fp4 into 1 uint8
     packed_mx_block_size = 3 * mx_block_size // 4
 
-    packed_shape = [uint8_data.shape[0], packed_mx_block_size]
+    packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
     n_mx_blocks = uint8_data.numel() // mx_block_size
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
 
@@ -1337,7 +1337,9 @@ def triton_to_mxfp8_dim1(
 
     # Create scale tensors
     col_scale = torch.empty(
-        (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
+        (n_cols, n_rows // inner_block_size, 1),
+        dtype=torch.uint8,
+        device=x.device,
     )
 
     # Calculate grid dimensions based on tile size
@@ -1374,7 +1376,7 @@ def triton_to_mxfp8_dim1_reference(
     scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
         x_hp_d1, torch.float8_e4m3fn, block_size
     )
-    scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
+    scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
     return (
         x_hp_d1_normalized.t(),
         scale_e8m0_dim1,
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index ef9ae42fcd..e98878af77 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -25,7 +25,6 @@
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
-    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
@@ -62,7 +61,6 @@
 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
-EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
     )
 
     # scale and saturated cast the data elements to max of target dtype
-    data_lp = torch.clamp(
-        data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
-    )
+    data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
     return exponent, data_lp
 
 
@@ -160,22 +156,33 @@ def to_mx(
         torch.float,
     ), f"{data_hp.dtype} is not supported yet"
     # TODO(future PR): consider supporting padding
-    assert data_hp.numel() % block_size == 0, "unsupported"
+    assert data_hp.shape[-1] % block_size == 0, (
+        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    )
     assert data_hp.is_contiguous(), "unsupported"
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"
 
-    # calculate the scale in e8m0 format
-
     orig_shape = data_hp.shape
-    # TODO(future PR): fix this line for TP, currently this reshape does not work
-    # for rank 3 tensor where dim1 is sharded
-    data_hp = data_hp.reshape(-1, block_size)
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )
 
     # find max value of the data
     # Note: this only implements the `minimally supported` version of
     # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3.
-    max_abs = torch.amax(torch.abs(data_hp), 1)
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    # We cast to float32 here because
+    # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
+    # if tensor parallel is enabled then the resulting shape is 2x larger
+    # than it should be under some conditions, likely because of a bug in
+    # the `view` op with DTensor and target dtype int16. I reproduce in
+    # torchtitan but not in a unit test, so not enough info to file a good
+    # issue in pytorch/pytorch. For now, work around. In the future we should
+    # debug and fix this properly.
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)
 
     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
     # in the element data type
@@ -206,17 +213,11 @@ def to_mx(
     if scaling_mode == ScaleCalculationMode.RCEIL:
         scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
     else:
-        if data_hp.dtype is torch.float32:
-            hp_int_dtype = torch.int32
-            hp_mbits = MBITS_F32
-            hp_ebits = EBITS_F32
-            hp_exp_bias = F32_EXP_BIAS
-        else:
-            assert data_hp.dtype is torch.bfloat16
-            hp_int_dtype = torch.int16
-            hp_mbits = MBITS_BF16
-            hp_ebits = EBITS_BF16
-            hp_exp_bias = BF16_EXP_BIAS
+        assert data_hp.dtype is torch.float32
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS
 
         # rounding before calculating the largest power of 2
         # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
@@ -285,7 +286,7 @@ def to_mx(
     scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
 
     # scale and saturated cast the data elements to max of target dtype
-    data_lp = data_hp / scale_fp32.unsqueeze(1)
+    data_lp = data_hp / scale_fp32
 
     if (
         elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -511,7 +512,6 @@ def __new__(
         assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
             f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         )
-        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
index 815ee20969..7ebf67d53c 100644
--- a/torchao/testing/training/dtensor_utils.py
+++ b/torchao/testing/training/dtensor_utils.py
@@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model2 = torch.compile(sp_model2)
 
     x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
+    go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+    go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])
 
     tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
+    tp_out.backward(go_fp32_tp)
     sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
+    sp_out.backward(go_fp32_sp)
     global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
+    global_out.backward(go_fp32)
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
     torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
@@ -169,7 +172,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     )
 
     sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
+    sp_out2.backward(go_fp32_sp)
     torch.testing.assert_close(sp_out2.full_tensor(), global_out)
     torch.testing.assert_close(
         tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
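Note on the core change in mx_tensor.py above: `to_mx` now blocks only along the last dimension instead of flattening to `(-1, block_size)`, so the cast works for tensors of any rank (including sharded DTensors) and the per-block max broadcasts without an explicit `unsqueeze(1)`. The snippet below is a minimal standalone sketch of that reshape/amax pattern, not torchao code; the helper name `blockwise_amax` is hypothetical and the actual e8m0 scale computation is omitted.

import torch


def blockwise_amax(data_hp: torch.Tensor, block_size: int):
    # Split only the last dimension into blocks; leading dims are untouched,
    # which is what makes the layout rank-agnostic.
    assert data_hp.shape[-1] % block_size == 0
    orig_shape = data_hp.shape
    data_blocked = data_hp.reshape(
        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
    )
    # Reduce over the block axis and keep a trailing singleton dim so the
    # result broadcasts against data_blocked directly.
    max_abs = torch.amax(torch.abs(data_blocked), -1).unsqueeze(-1)
    return data_blocked, max_abs


# Works the same for rank-2 and rank-3 inputs.
for x in (torch.randn(8, 64), torch.randn(2, 8, 64)):
    blocked, amax = blockwise_amax(x, block_size=32)
    print(blocked.shape, amax.shape)  # (..., 2, 32) and (..., 2, 1)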