Enable HSDP
This PR enables HSDP.

ghstack-source-id: c85046adbfa1a6537bcbcb45d98ef27c2ecb3044
Pull Request resolved: pytorch#518
fegin committed Sep 10, 2024
1 parent 1923ce4 commit db6b39a
Showing 13 changed files with 133 additions and 47 deletions.
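HSDP (hybrid sharded data parallelism) layers DDP-style replication on top of FSDP-style sharding: parameters are sharded inside each data-parallel group, and the groups keep replicated copies in sync. As a rough, illustrative sketch (not part of the diff; the degrees and the row-major rank numbering are assumptions):

# Illustrative only: with data_parallel_replicate_degree=2 and
# data_parallel_shard_degree=4 on 8 ranks, parameters are sharded 4 ways inside
# each group of ranks, and rank i of group 0 keeps its shard in sync with
# rank i of group 1 through gradient all-reduce.
dp_replicate, dp_shard = 2, 4
shard_groups = [list(range(g * dp_shard, (g + 1) * dp_shard)) for g in range(dp_replicate)]
print(shard_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]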
4 changes: 2 additions & 2 deletions estimation.py
@@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig):
job_config.experimental.enable_compiled_autograd = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
40 changes: 32 additions & 8 deletions test_runner.py
@@ -157,7 +157,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test 1f1b",
@@ -172,7 +172,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test gpipe",
@@ -187,7 +187,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP 1f1b 2D test",
@@ -201,7 +201,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP gpipe 2D test",
@@ -227,15 +227,15 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
],
@@ -249,7 +249,7 @@ def build_test_list():
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
"--training.compile",
],
@@ -285,13 +285,37 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
"--training.data_parallel_shard_degree=1",
"--training.data_parallel_replicate_degree=4",
]
],
"DDP",
"ddp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
]
],
"HSDP",
"hsdp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--training.tensor_parallel_degree=2",
]
],
"HSDP+TP",
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
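The three new test entries above (DDP, HSDP, HSDP+TP) choose degrees whose product matches the requested GPU count. A quick sanity check of those combinations, as a hedged sketch rather than anything that lives in test_runner.py:

# Values copied from the OverrideDefinitions above; the check itself is ours.
cases = {
    "ddp":     dict(dp_replicate=4, dp_shard=1, tp=1, ngpu=4),
    "hsdp":    dict(dp_replicate=2, dp_shard=2, tp=1, ngpu=4),
    "hsdp+tp": dict(dp_replicate=2, dp_shard=2, tp=2, ngpu=8),
}
for name, c in cases.items():
    assert c["dp_replicate"] * c["dp_shard"] * c["tp"] == c["ngpu"], name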
34 changes: 26 additions & 8 deletions torchtitan/config_manager.py
@@ -224,10 +224,34 @@ def __init__(self):
help="How many train steps to run",
)
self.parser.add_argument(
"--training.data_parallel_degree",
"--training.data_parallel_replicate_degree",
type=int,
default=1,
help="""
The `data_parallel_replicate_degree` argument specifies the degree of
data parallelism for weight replication. When this value is greater
than 1, weights will be replicated across `data_parallel_replicate_degree`
ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is DDP (Distributed Data Parallelism).
1 means disabled.""",
)
self.parser.add_argument(
"--training.data_parallel_shard_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
help="""
The `data_parallel_shard_degree` argument specifies the degree of data
parallelism for weight sharding. When this value is greater than 1, weights
will be sharded across `data_parallel_shard_degree` ranks. If
`data_parallel_replicate_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is FSDP (Fully Sharded Data Parallelism).
-1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree`
can be negative.
1 means disabled.""",
)
self.parser.add_argument(
"--training.tensor_parallel_degree",
@@ -297,12 +321,6 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
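Taken together, the two new help strings above determine the data-parallel flavor once dp_shard has been resolved to a concrete value. A compact restatement as a hedged sketch (the helper name is ours, not part of the PR):

# Hypothetical helper summarizing the help text; not part of config_manager.py.
def data_parallel_flavor(dp_replicate: int, dp_shard: int) -> str:
    if dp_replicate > 1 and dp_shard > 1:
        return "HSDP"  # shard within groups, replicate across groups
    if dp_shard > 1:
        return "FSDP"  # pure sharding
    if dp_replicate > 1:
        return "DDP"   # pure replication
    return "none"      # both degrees are 1: data parallelism disabled

assert data_parallel_flavor(2, 2) == "HSDP"
assert data_parallel_flavor(1, 8) == "FSDP"
assert data_parallel_flavor(4, 1) == "DDP"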
63 changes: 48 additions & 15 deletions torchtitan/parallelisms/parallel_dims.py
@@ -13,45 +13,78 @@

@dataclass
class ParallelDims:
dp: int
dp_replicate: int
dp_shard: int
tp: int
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
assert dp >= 1, dp
dp_replicate, dp_shard, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
strict=True,
):
if d > 1:
dims.append(d)
names.append(name)
if (name == "dp_replicate" and self.dp_shard == 1) or (
name == "dp_shard" and self.dp_replicate == 1
):
names.append("dp")
else:
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
# Create all the submeshes here to ensure all required process groups are
# initialized.
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
return mesh

@property
def dp_enabled(self):
return self.dp > 1
return self.dp_replicate > 1 or self.dp_shard > 1

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1

@property
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def tp_enabled(self):
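For a concrete feel of the new _validate arithmetic, here is a worked example under assumed degrees (16 ranks, tp=2, pp=1, dp_replicate=2, dp_shard=-1); it mirrors the code above rather than calling it:

world_size, tp, pp = 16, 2, 1
dp_replicate, dp_shard = 2, -1

dp = dp_replicate * dp_shard        # -2: negative, so dp_shard must be derived
if dp < 0:
    dp = world_size // (tp * pp)    # 16 // 2 = 8 ranks left for data parallelism
    dp_shard = dp // dp_replicate   # 8 // 2 = 4-way sharding per replica group

assert dp_replicate * dp_shard * tp * pp == world_size  # 2 * 4 * 2 * 1 == 16
# build_mesh above would then create dims (dp_replicate=2, dp_shard=4, tp=2) and
# flatten the two dp dims into a joint "dp" mesh for callers that only need a
# single data-parallel dimension.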
14 changes: 9 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -73,9 +73,11 @@ def parallelize_llama(
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_type == "fsdp":
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

apply_fsdp(
model,
@@ -87,6 +89,10 @@
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")
else:
if world_mesh.ndim > 1:
raise RuntimeError("DDP does not support > 1D parallelism")
@@ -322,8 +328,6 @@ def apply_fsdp(
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

logger.info("Applied FSDP to the model")


def apply_ddp(
model: nn.Module,
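The mesh selection above is all that switches apply_fsdp between FSDP and HSDP: fully_shard shards over a 1-D mesh and replicates over the outer dimension of a 2-D mesh. A minimal, hedged sketch of the call shape, assuming the FSDP2 fully_shard API this file already uses (not an exact copy of apply_fsdp):

from torch.distributed._composable.fsdp import fully_shard

def shard_model(model, dp_mesh, pp_enabled: bool = False):
    # A 1-D dp_mesh yields plain FSDP; a 2-D ("dp_replicate", "dp_shard") mesh
    # yields HSDP, which matches the logging added in parallelize_llama above.
    fully_shard(model, mesh=dp_mesh, reshard_after_forward=not pp_enabled)
    return model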
4 changes: 2 additions & 2 deletions train.py
@@ -59,12 +59,12 @@ def main(job_config: JobConfig):
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
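estimation.py and train.py now pass both degrees to ParallelDims. A hedged sketch of constructing it directly with the new fields (world size and degrees are illustrative; the import path follows the file location shown above):

from torchtitan.parallelisms.parallel_dims import ParallelDims

parallel_dims = ParallelDims(
    dp_shard=-1,        # resolved in _validate: 8 // (tp * pp) // dp_replicate = 4
    dp_replicate=1,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=True,
)
assert parallel_dims.dp_shard == 4 and parallel_dims.dp_shard_enabled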
3 changes: 2 additions & 1 deletion train_configs/debug_model.toml
@@ -35,7 +35,8 @@ seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
3 changes: 2 additions & 1 deletion train_configs/llama2_13b.toml
@@ -31,7 +31,8 @@ seq_len = 4096
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama2_70b.toml
@@ -31,7 +31,8 @@ seq_len = 4096
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama2_7b.toml
@@ -30,7 +30,8 @@ seq_len = 2048
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_405b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 3000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = true
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_70b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_8b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4"
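Note that every config above keeps the old default behavior: data_parallel_replicate_degree = 1 with data_parallel_shard_degree = -1 is plain FSDP over all leftover ranks, just as data_parallel_degree = -1 was. Enabling HSDP from a config is then a matter of raising both degrees so that their product, together with tensor_parallel_degree and pipeline_parallel_degree, equals the world size.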
