From db6b39acd00027b100ab5c6a45890b03087c55f9 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 10 Sep 2024 10:48:38 -0700 Subject: [PATCH] Enable HSDP This PR enables HSDP. ghstack-source-id: c85046adbfa1a6537bcbcb45d98ef27c2ecb3044 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/518 --- estimation.py | 4 +- test_runner.py | 40 ++++++++++--- torchtitan/config_manager.py | 34 ++++++++--- torchtitan/parallelisms/parallel_dims.py | 63 +++++++++++++++----- torchtitan/parallelisms/parallelize_llama.py | 14 +++-- train.py | 4 +- train_configs/debug_model.toml | 3 +- train_configs/llama2_13b.toml | 3 +- train_configs/llama2_70b.toml | 3 +- train_configs/llama2_7b.toml | 3 +- train_configs/llama3_405b.toml | 3 +- train_configs/llama3_70b.toml | 3 +- train_configs/llama3_8b.toml | 3 +- 13 files changed, 133 insertions(+), 47 deletions(-) diff --git a/estimation.py b/estimation.py index 13ccd4c16..f58907c6f 100644 --- a/estimation.py +++ b/estimation.py @@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig): job_config.experimental.enable_compiled_autograd = False parallel_dims = ParallelDims( - dp=job_config.training.data_parallel_degree, + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, - dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/test_runner.py b/test_runner.py index 7d6d4063b..6d706a641 100755 --- a/test_runner.py +++ b/test_runner.py @@ -157,7 +157,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", - "--training.data_parallel_degree 1", + "--training.data_parallel_shard_degree 1", ], ], "PP 1D test 1f1b", @@ -172,7 +172,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", - "--training.data_parallel_degree 1", + "--training.data_parallel_shard_degree 1", ], ], "PP 1D test gpipe", @@ -187,7 +187,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", ], ], "PP+DP 1f1b 2D test", @@ -201,7 +201,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", ], ], "PP+DP gpipe 2D test", @@ -227,7 +227,7 @@ def build_test_list(): "--checkpoint.enable_checkpoint", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", "--training.tensor_parallel_degree 2", ], [ @@ -235,7 +235,7 @@ def build_test_list(): "--checkpoint.enable_checkpoint", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", 
"--training.tensor_parallel_degree 2", ], ], @@ -249,7 +249,7 @@ def build_test_list(): [ "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", "--training.tensor_parallel_degree 2", "--training.compile", ], @@ -285,13 +285,37 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.data_parallel_type ddp", + "--training.data_parallel_shard_degree=1", + "--training.data_parallel_replicate_degree=4", ] ], "DDP", "ddp", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.data_parallel_replicate_degree=2", + ] + ], + "HSDP", + "hsdp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.data_parallel_replicate_degree=2", + "--training.tensor_parallel_degree=2", + ] + ], + "HSDP+TP", + "hsdp+tp", + ngpu=8, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3ba1d1029..67c82d53f 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -224,10 +224,34 @@ def __init__(self): help="How many train steps to run", ) self.parser.add_argument( - "--training.data_parallel_degree", + "--training.data_parallel_replicate_degree", + type=int, + default=1, + help=""" + The `data_parallel_replicate_degree` argument specifies the degree of + data parallelism for weight replication. When this value is greater + than 1, weights will be replicated across `data_parallel_replicate_degree` + ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism + method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the + parallelism method used is DDP (Distributed Data Parallelism). + 1 means disabled.""", + ) + self.parser.add_argument( + "--training.data_parallel_shard_degree", type=int, default=-1, - help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.", + help=""" + The `data_parallel_shard_degree` argument specifies the degree of data + parallelism for weight sharding. When this value is greater than 1, weights + will be sharded across `data_parallel_shard_degree` ranks. If + `data_parallel_replicate_degree` is also greater than 1, the parallelism + method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the + parallelism method used is FSDP (Fully Sharded Data Parallelism). + + -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that + only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree` + can be negative. + 1 means disabled.""", ) self.parser.add_argument( "--training.tensor_parallel_degree", @@ -297,12 +321,6 @@ def __init__(self): The default value will be the number of pipeline stages, if unspecified. """, ) - self.parser.add_argument( - "--training.data_parallel_type", - type=str, - default="fsdp", - help="Data parallelism type. 
TorchTitan currently supports FSDP and DDP.", - ) self.parser.add_argument( "--experimental.enable_compiled_autograd", action="store_true", diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 22c114eda..2e2aacc75 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -13,45 +13,78 @@ @dataclass class ParallelDims: - dp: int + dp_replicate: int + dp_shard: int tp: int pp: int world_size: int enable_loss_parallel: bool - dp_type: str def __post_init__(self): - self.dp_type = self.dp_type.lower() self._validate() def _validate(self): - dp, tp, pp = self.dp, self.tp, self.pp - if dp == -1: - self.dp = dp = self.world_size // (tp * pp) - assert dp >= 1, dp + dp_replicate, dp_shard, tp, pp = ( + self.dp_replicate, + self.dp_shard, + self.tp, + self.pp, + ) + for d in (dp_replicate, tp, pp): + assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" + assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." + + dp = dp_replicate * dp_shard + if dp < 0: + dp = self.world_size // (tp * pp) + self.dp_shard = dp_shard = dp // dp_replicate + + assert dp_replicate >= 1 + assert dp_shard >= 1 assert tp >= 1, tp assert pp >= 1, pp - assert ( - dp * tp * pp == self.world_size - ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - assert self.dp_type in ("fsdp", "ddp") + assert dp_replicate * dp_shard * tp * pp == self.world_size, ( + f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " + f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + ) def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True + [self.pp, self.dp_replicate, self.dp_shard, self.tp], + ["pp", "dp_replicate", "dp_shard", "tp"], + strict=True, ): if d > 1: dims.append(d) - names.append(name) + if (name == "dp_replicate" and self.dp_shard == 1) or ( + name == "dp_shard" and self.dp_replicate == 1 + ): + names.append("dp") + else: + names.append(name) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) - return init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are + # initialized + if self.dp_replicate > 1 and self.dp_shard > 1: + mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + return mesh @property def dp_enabled(self): - return self.dp > 1 + return self.dp_replicate > 1 or self.dp_shard > 1 + + @property + def dp_replicate_enabled(self): + return self.dp_replicate > 1 + + @property + def dp_shard_enabled(self): + return self.dp_shard > 1 @property def tp_enabled(self): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index aa07f25fb..fc26703db 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -73,9 +73,11 @@ def parallelize_llama( apply_compile(model) if parallel_dims.dp_enabled: - if parallel_dims.dp_type == "fsdp": - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + if parallel_dims.dp_shard_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mesh = world_mesh["dp_replicate", "dp_shard"] + else: + dp_mesh = world_mesh["dp"] apply_fsdp( model, @@ -87,6 
+89,10 @@ def parallelize_llama( tp_enabled=parallel_dims.tp_enabled, pp_enabled=parallel_dims.pp_enabled, ) + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") else: if world_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") @@ -322,8 +328,6 @@ def apply_fsdp( ) fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - logger.info("Applied FSDP to the model") - def apply_ddp( model: nn.Module, diff --git a/train.py b/train.py index ffea00a9d..d1973b6d6 100644 --- a/train.py +++ b/train.py @@ -59,12 +59,12 @@ def main(job_config: JobConfig): # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( - dp=job_config.training.data_parallel_degree, + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, - dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index af5472148..bb3cd3537 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -35,7 +35,8 @@ seq_len = 2048 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 10 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index df2f6bb3d..3230b208b 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -31,7 +31,8 @@ seq_len = 4096 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 354ebe11f..e7c920c65 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -31,7 +31,8 @@ seq_len = 4096 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = false dataset = "c4" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index e2b0e78d2..5ffaaeca7 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -30,7 +30,8 @@ seq_len = 2048 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B compile = false dataset = "c4" diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index 1a83301fb..c7723ef31 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -31,7 +31,8 @@ seq_len = 8192 
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 3000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = true dataset = "c4" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 470149a58..fb6d5f50b 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -31,7 +31,8 @@ seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 3d0c5160d..e0c5bd03e 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -31,7 +31,8 @@ seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4"
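With the new knobs, at most one of `data_parallel_replicate_degree` and `data_parallel_shard_degree` may be set to -1; `ParallelDims._validate` then assigns the leftover ranks (after TP/PP) to that dimension and checks that the product of all degrees matches `WORLD_SIZE`. A minimal standalone sketch of that arithmetic (illustrative only; `resolve_dp_shard` is a made-up helper, not part of torchtitan):

```python
# Standalone sketch of the degree resolution done in ParallelDims._validate.
# `resolve_dp_shard` is a hypothetical helper, not a torchtitan API.
def resolve_dp_shard(world_size: int, dp_replicate: int, dp_shard: int, tp: int, pp: int) -> int:
    # dp_shard == -1 means "use whatever ranks are left after dp_replicate/TP/PP"
    if dp_shard == -1:
        dp_shard = world_size // (dp_replicate * tp * pp)
    assert dp_replicate * dp_shard * tp * pp == world_size, (
        f"dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
        f"tp({tp}) * pp({pp}) != WORLD_SIZE({world_size})"
    )
    return dp_shard

# 8 GPUs, dp_replicate=2, tp=2, pp=1 -> dp_shard resolves to 2, i.e. HSDP(2x2) + TP(2),
# matching the degrees used by the new "HSDP+TP" entry in test_runner.py (ngpu=8).
assert resolve_dp_shard(8, dp_replicate=2, dp_shard=-1, tp=2, pp=1) == 2
```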
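`build_mesh` now creates up to four dimensions (`pp`, `dp_replicate`, `dp_shard`, `tp`) and, when both data-parallel degrees are greater than 1, flattens the two DP dimensions into a single `dp` submesh so downstream code that only needs a data-parallel rank and degree (e.g. the dataloader) can keep addressing one `dp` dimension. An illustrative reproduction of the resulting mesh for the 8-GPU HSDP+TP case, run under `torchrun --nproc_per_node=8` (not torchtitan code; the `pp` dim is omitted because pp=1):

```python
# Illustrative: the mesh shape build_mesh produces for dp_replicate=2, dp_shard=2, tp=2.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)
# The patch flattens the two DP dims into "dp" so callers that only care about
# data parallelism see a single 4-way dimension.
dp_mesh = mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
print(dp_mesh.size())  # 4: each rank belongs to a 4-way data-parallel group
```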
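On the model side, the only change needed for HSDP is which mesh gets handed to `fully_shard`: a 1-D `dp` mesh gives plain FSDP, while the 2-D (`dp_replicate`, `dp_shard`) mesh makes FSDP2 replicate across the first dimension and shard across the second. A minimal sketch with a toy model, assuming 4 ranks under `torchrun --nproc_per_node=4` (in torchtitan, `apply_fsdp` additionally passes its `fsdp_config` with the mixed-precision policy and `reshard_after_forward`):

```python
# Sketch: a 2-D ("dp_replicate", "dp_shard") mesh turns fully_shard (FSDP2) into HSDP.
# Toy model for illustration; torchtitan shards per TransformerBlock instead.
import os
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard"))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for block in model:
    fully_shard(block, mesh=mesh)  # replicate over dim 0, shard over dim 1
fully_shard(model, mesh=mesh)      # root wrap, as in apply_fsdp
```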