diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index e7220bff9f..5509eb1cc2 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -10,7 +10,6 @@
 TODO(future): make this run in CI
 """
 
-import copy
 import os
 
 import pytest
@@ -23,12 +22,6 @@
 
 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    parallelize_module,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -50,14 +43,11 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_tensor_parallel import (
-    Float8ColwiseParallel,
-    Float8RowwiseParallel,
-    PrepareFloat8ModuleInput,
-)
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.training.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)
 
 torch.set_float32_matmul_precision("high")
 
@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()
 
 
-def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
-):
-    device = mesh.device_type
-
-    if rowwise:
-        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
-        # hack around config being frozen
-        # TODO(future PR): we should make this nicer at the config level
-        object.__setattr__(config, "emulate", True)
-    else:
-        config = Float8LinearConfig(emulate=True)
-
-    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
-
-    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
-    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
-
-    # For tensorwise scaling, enable float8 all_gather.
-    # For rowwise scaling, keep high precision all_gather. Motivation for
-    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
-    # so for float8 all-gather we'd need to send two float8 copies per tensor,
-    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
-    if rowwise:
-        colwise_parallel_cls = ColwiseParallel
-        rowwise_parallel_cls = RowwiseParallel
-        prepare_input_cls = PrepareModuleInput
-    else:
-        colwise_parallel_cls = Float8ColwiseParallel
-        rowwise_parallel_cls = Float8RowwiseParallel
-        prepare_input_cls = PrepareFloat8ModuleInput
-
-    # vanilla TP
-    tp_model = parallelize_module(
-        tp_model,
-        mesh,
-        {
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(),
-        },
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
     )
 
-    # "sequence parallel" mlp computation
-    sp_model = parallelize_module(
-        sp_model,
-        mesh,
-        {
-            "ffn": prepare_input_cls(
-                input_layouts=Shard(1), desired_input_layouts=Replicate()
-            ),
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
     )
 
 
-    # prepare_input_cls with specific submodule fqn
-    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
-    if rowwise:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-        )
-    else:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-            fwd_config_submodule_fqn="w2",
-        )
-
-    sp_model2 = parallelize_module(
-        sp_model2,
-        mesh,
-        {
-            "ffn": prepare_input,
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    if compile:
-        tp_model = torch.compile(tp_model)
-        sp_model = torch.compile(sp_model)
-        sp_model2 = torch.compile(sp_model2)
-
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    x_fp32_tp_input = x_fp32.clone()
-    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
-
-    tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
-    sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
-    global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
-    torch.testing.assert_close(tp_out, global_out)
-    torch.testing.assert_close(sp_out.full_tensor(), global_out)
-    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
     )
 
-    sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
-    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
-    torch.testing.assert_close(
-        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
-    )
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
     )
 
 
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)
-
-
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)
-
-
 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     torch.manual_seed(42)
     model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
index 84e4095263..7ac0360363 100644
--- a/torchao/testing/training/dtensor_utils.py
+++ b/torchao/testing/training/dtensor_utils.py
@@ -3,9 +3,27 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed._tensor import Replicate, Shard, distribute_tensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)
 
 
 class FeedForward(nn.Module):
@@ -28,3 +46,123 @@ def __init__(self):
 
     def forward(self, x):
         return self.ffn(x)
+
+
+def _test_lowp_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh,
+    config: Float8LinearConfig,
+    size=16,
+    compile: bool = False,
+    allgather_in_lowp: bool = False,
+):
+    device = mesh.device_type
+
+    toy_model = ToyModel().to(device)
+    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+
+    tp_model = copy.deepcopy(toy_model)
+    tp_model = convert_to_float8_training(tp_model, config=config)
+    sp_model = copy.deepcopy(toy_model)
+    sp_model = convert_to_float8_training(sp_model, config=config)
+
+    # For tensorwise scaling, enable float8 all_gather.
+    # For rowwise scaling, keep high precision all_gather. Motivation for
+    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
+    # so for float8 all-gather we'd need to send two float8 copies per tensor,
+    # which is a similar number of bytes over the wire as just doing bfloat16 all-gather.
+    if not allgather_in_lowp:
+        colwise_parallel_cls = ColwiseParallel
+        rowwise_parallel_cls = RowwiseParallel
+        prepare_input_cls = PrepareModuleInput
+    else:
+        colwise_parallel_cls = Float8ColwiseParallel
+        rowwise_parallel_cls = Float8RowwiseParallel
+        prepare_input_cls = PrepareFloat8ModuleInput
+
+    # vanilla TP
+    tp_model = parallelize_module(
+        tp_model,
+        mesh,
+        {
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(),
+        },
+    )
+
+    # "sequence parallel" mlp computation
+    sp_model = parallelize_module(
+        sp_model,
+        mesh,
+        {
+            "ffn": prepare_input_cls(
+                input_layouts=Shard(1), desired_input_layouts=Replicate()
+            ),
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    # prepare_input_cls with specific submodule fqn
+    sp_model2 = copy.deepcopy(toy_model)
+    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+
+    if not allgather_in_lowp:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+        )
+    else:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+            fwd_config_submodule_fqn="w2",
+        )
+
+    sp_model2 = parallelize_module(
+        sp_model2,
+        mesh,
+        {
+            "ffn": prepare_input,
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    if compile:
+        tp_model = torch.compile(tp_model)
+        sp_model = torch.compile(sp_model)
+        sp_model2 = torch.compile(sp_model2)
+
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32_tp_input = x_fp32.clone()
+    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+
+    tp_out = tp_model(x_fp32_tp_input)
+    tp_out.sum().backward()
+    sp_out = sp_model(x_fp32_sp_input)
+    sp_out.sum().backward()
+    global_out = toy_model_fp8(x_fp32)
+    global_out.sum().backward()
+    torch.testing.assert_close(tp_out, global_out)
+    torch.testing.assert_close(sp_out.full_tensor(), global_out)
+    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+    )
+
+    sp_out2 = sp_model2(x_fp32_sp_input)
+    sp_out2.sum().backward()
+    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
+    torch.testing.assert_close(
+        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
+    )
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    )
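
Usage sketch (illustrative, not part of the patch): the shared helper is driven exactly as the updated test functions above show. Assuming a process group and a 1-D device mesh are already set up (e.g. launched with torchrun, as test_dtensor.py does), a caller could look like the following; the helper, config, and import names come from the diff, while the function name and world_size handling here are hypothetical.

from torch.distributed.device_mesh import init_device_mesh

from torchao.float8 import Float8LinearConfig
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)


def run_tensorwise_tp_smoke_test(world_size: int) -> None:
    # emulate=True keeps the numerics check runnable without native float8 kernels
    config = Float8LinearConfig(emulate=True)
    # assumes torch.distributed is already initialized (e.g. via torchrun)
    mesh = init_device_mesh("cuda", (world_size,))
    # tensorwise scaling uses float8 all-gather, hence allgather_in_lowp=True
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size=16, compile=False, allgather_in_lowp=True
    )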