enable tensor parallelism for MXLinear #2434
@@ -25,7 +25,6 @@
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
-    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
@@ -62,7 +61,6 @@
 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
-EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
     )

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = torch.clamp(
-        data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
-    )
+    data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
     return exponent, data_lp
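Context for the `.unsqueeze(1)` removal above: `max_abs` (and the descale derived from it) now arrives with a trailing singleton dimension, so it already broadcasts against the `(..., n_blocks, block_size)`-shaped data. A minimal sketch of the broadcast, with hypothetical shapes and a stand-in descale (the real one is derived from the e8m0 scale):

```python
import torch

block_size, max_pos = 32, 448.0  # 448.0 is the float8_e4m3fn max
data_hp = torch.randn(4, 2, block_size)                     # (..., n_blocks, block_size)
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)  # (..., n_blocks, 1)
descale_fp = max_pos / max_abs.clamp(min=1e-12)             # stand-in for the real descale
# (..., n_blocks, 1) broadcasts against (..., n_blocks, block_size), so no
# unsqueeze is needed at the clamp site anymore
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
assert data_lp.shape == data_hp.shape
```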
@@ -160,22 +156,33 @@ def to_mx(
         torch.float,
     ), f"{data_hp.dtype} is not supported yet"
     # TODO(future PR): consider supporting padding
-    assert data_hp.numel() % block_size == 0, "unsupported"
+    assert data_hp.shape[-1] % block_size == 0, (
+        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    )
     assert data_hp.is_contiguous(), "unsupported"
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

     # calculate the scale in e8m0 format

     orig_shape = data_hp.shape
-    # TODO(future PR): fix this line for TP, currently this reshape does not work
-    # for rank 3 tensor where dim1 is sharded
-    data_hp = data_hp.reshape(-1, block_size)
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )

     # find max value of the data
     # Note: this only implements the `minimally supported` version of
     # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3.
-    max_abs = torch.amax(torch.abs(data_hp), 1)
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    # We cast to float32 here because
+    # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
+    # if tensor parallel is enabled then the resulting shape is 2x larger
+    # than it should be under some conditions, likely because of a bug in
+    # the `view` op with DTensor and target dtype int16. I reproduce in
+    # torchtitan but not in a unit test, so not enough info to file a good
+    # issue in pytorch/pytorch. For now, work around. In the future we should
+    # debug and fix this properly.
+    data_hp = data_hp.to(torch.float32)
Review note: performance testing showed that with compile on, having this in float32 does not regress performance.
+    max_abs = max_abs.to(torch.float32)

     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
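The reshape change above is the core of the TP enablement: instead of flattening the whole tensor into `(num_blocks, block_size)`, which breaks when a leading dim is sharded as a DTensor, only the last dim is split into blocks. A minimal sketch with hypothetical shapes (plain tensors, no DTensor):

```python
import torch

block_size = 32
x = torch.randn(2, 8, 128)  # hypothetical (batch, seq, hidden) activation
orig_shape = x.shape

# old: flatten everything into (num_blocks, block_size); this folds leading
# (possibly sharded) dims into the block dim
x_old = x.reshape(-1, block_size)  # (64, 32)

# new: split only the last dim, leaving the leading dims intact
x_new = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)  # (2, 8, 4, 32)

# per-block amax is taken over the last dim and keeps a trailing singleton,
# so it broadcasts against x_new at later use sites without an unsqueeze
max_abs = torch.amax(torch.abs(x_new), -1).unsqueeze(-1)  # (2, 8, 4, 1)
print(x_old.shape, x_new.shape, max_abs.shape)
```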
@@ -206,17 +213,11 @@ def to_mx(
     if scaling_mode == ScaleCalculationMode.RCEIL:
         scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
     else:
-        if data_hp.dtype is torch.float32:
-            hp_int_dtype = torch.int32
-            hp_mbits = MBITS_F32
-            hp_ebits = EBITS_F32
-            hp_exp_bias = F32_EXP_BIAS
-        else:
-            assert data_hp.dtype is torch.bfloat16
-            hp_int_dtype = torch.int16
-            hp_mbits = MBITS_BF16
-            hp_ebits = EBITS_BF16
-            hp_exp_bias = BF16_EXP_BIAS
+        assert data_hp.dtype is torch.float32
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS

         # rounding before calculating the largest power of 2
         # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
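Since the input is now always cast to float32 earlier in `to_mx` (the DTensor `view` bug only showed up with bfloat16/int16), only the float32 bit layout (`hp_mbits=23`, `hp_ebits=8`, exponent bias 127) is needed for the exponent extraction the comment above refers to. A simplified, self-contained illustration of that extraction (not the library's exact code):

```python
import torch

MBITS_F32, F32_EXP_BIAS = 23, 127
max_abs = torch.tensor([0.75, 3.0, 100.0], dtype=torch.float32)
bits = max_abs.view(torch.int32)          # reinterpret the raw float32 bits
biased_exp = (bits >> MBITS_F32) & 0xFF   # 8 exponent bits
unbiased_exp = biased_exp - F32_EXP_BIAS  # floor(log2(max_abs))
print(unbiased_exp)  # -1, 1, 6
```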
@@ -285,7 +286,7 @@ def to_mx(
     scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = data_hp / scale_fp32.unsqueeze(1)
+    data_lp = data_hp / scale_fp32

     if (
         elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -511,7 +512,6 @@ def __new__(
         assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
             f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         )
-        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
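Removing the 1-D assert on `scale_e8m0_bits` follows from the new block layout: the scale no longer has to be flattened into a single vector. A rough sketch of the expected shapes, assuming the scale keeps the leading dims of the original tensor (the exact layout is defined elsewhere in the file and not shown in this diff):

```python
# hypothetical shapes: one e8m0 scale per 32-element block along the last dim
B, S, K, block_size = 2, 8, 128, 32
data_shape = (B, S, K)
scale_shape = (B, S, K // block_size)  # e.g. (2, 8, 4) instead of a flat 1-D vector
print(data_shape, scale_shape)
```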
@@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model2 = torch.compile(sp_model2)

     x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
+    go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+    go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])

     tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
+    tp_out.backward(go_fp32_tp)
Review note: this is to make sure the gradient flowing into the last linear is contiguous.
     sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
+    sp_out.backward(go_fp32_sp)
     global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
+    global_out.backward(go_fp32)
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
     torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
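The switch from `out.sum().backward()` to `out.backward(go)` gives the backward pass a non-trivial, contiguous incoming gradient (see the review note above) instead of the implicit all-ones gradient that `.sum().backward()` produces. A toy illustration of the difference, unrelated to the actual test models:

```python
import torch

lin = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

# out.sum().backward() is equivalent to backpropagating an all-ones gradient,
# which can hide layout bugs in grad_output handling
out = lin(x)
out.backward(torch.ones_like(out))

# passing an explicit random gradient exercises the backward with a
# non-trivial, contiguous grad_output, as the updated test does with go_fp32
lin.zero_grad()
out = lin(x)
out.backward(torch.rand_like(out))
```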
@@ -169,7 +172,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     )

     sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
+    sp_out2.backward(go_fp32_sp)
     torch.testing.assert_close(sp_out2.full_tensor(), global_out)
     torch.testing.assert_close(
         tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
Review note: this was broken before, and was caught by enforcing that the inner dim is divisible by the block size.
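For reference, the guard that caught it is the new assert at the top of `to_mx`: a shape whose last dim is not a multiple of `block_size` now fails loudly instead of silently producing wrong blocks. A small standalone sketch with hypothetical shapes:

```python
import torch

block_size = 32

def check_last_dim(data_hp: torch.Tensor) -> None:
    # mirrors the new guard in to_mx: the last dim must hold whole blocks
    assert data_hp.shape[-1] % block_size == 0, (
        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
    )

check_last_dim(torch.randn(2, 8, 128))  # ok: 128 == 4 * 32
try:
    check_last_dim(torch.randn(2, 8, 48))  # 48 is not a multiple of 32
except AssertionError as e:
    print(e)
```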