Commit db6b39a

Enable HSDP

This PR enables HSDP.

ghstack-source-id: c85046a
Pull Request resolved: #518

1 parent 1923ce4 · commit db6b39a
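
For reference, the new flags replace the old `data_parallel_degree` / `data_parallel_type` pair and map onto the data-parallel flavors as follows (this summary paraphrases the help text added in torchtitan/config_manager.py below):

- FSDP: `data_parallel_shard_degree` > 1 (or -1 for leftover ranks) with `data_parallel_replicate_degree` = 1
- DDP: `data_parallel_replicate_degree` > 1 with `data_parallel_shard_degree` = 1
- HSDP: both degrees > 1; weights are sharded within each `dp_shard` group and replicated across `dp_replicate` groups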

File tree: 13 files changed, +133 −47 lines

estimation.py
Lines changed: 2 additions & 2 deletions

@@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig):
     job_config.experimental.enable_compiled_autograd = False
 
     parallel_dims = ParallelDims(
-        dp=job_config.training.data_parallel_degree,
+        dp_shard=job_config.training.data_parallel_shard_degree,
+        dp_replicate=job_config.training.data_parallel_replicate_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
-        dp_type=job_config.training.data_parallel_type,
     )
 
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

test_runner.py
Lines changed: 32 additions & 8 deletions

@@ -157,7 +157,7 @@ def build_test_list():
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule 1f1b",
-                    "--training.data_parallel_degree 1",
+                    "--training.data_parallel_shard_degree 1",
                 ],
             ],
             "PP 1D test 1f1b",
@@ -172,7 +172,7 @@ def build_test_list():
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule gpipe",
-                    "--training.data_parallel_degree 1",
+                    "--training.data_parallel_shard_degree 1",
                 ],
             ],
             "PP 1D test gpipe",
@@ -187,7 +187,7 @@ def build_test_list():
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule 1f1b",
-                    "--training.data_parallel_degree 2",
+                    "--training.data_parallel_shard_degree 2",
                 ],
             ],
             "PP+DP 1f1b 2D test",
@@ -201,7 +201,7 @@ def build_test_list():
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule gpipe",
-                    "--training.data_parallel_degree 2",
+                    "--training.data_parallel_shard_degree 2",
                 ],
             ],
             "PP+DP gpipe 2D test",
@@ -227,15 +227,15 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--training.data_parallel_degree 2",
+                    "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                 ],
                 [
                     "--training.steps 20",
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--training.data_parallel_degree 2",
+                    "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                 ],
             ],
@@ -249,7 +249,7 @@ def build_test_list():
                 [
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--training.data_parallel_degree 2",
+                    "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                     "--training.compile",
                 ],
@@ -285,13 +285,37 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.data_parallel_type ddp",
+                    "--training.data_parallel_shard_degree=1",
+                    "--training.data_parallel_replicate_degree=4",
                 ]
             ],
             "DDP",
             "ddp",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.data_parallel_replicate_degree=2",
+                ]
+            ],
+            "HSDP",
+            "hsdp",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.data_parallel_replicate_degree=2",
+                    "--training.tensor_parallel_degree=2",
+                ]
+            ],
+            "HSDP+TP",
+            "hsdp+tp",
+            ngpu=8,
+        ),
         OverrideDefinitions(
             [
                 [
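
The new "HSDP" and "HSDP+TP" entries run through the same integration-test machinery as the existing flavors. As a rough guide, the 4-GPU HSDP case corresponds to something like the following launch; the CONFIG_FILE/NGPU/run_llama_train.sh invocation is the repo's usual test-runner pattern and is assumed here, not shown in this diff:

CONFIG_FILE=./train_configs/debug_model.toml NGPU=4 ./run_llama_train.sh \
    --training.data_parallel_shard_degree=2 \
    --training.data_parallel_replicate_degree=2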

torchtitan/config_manager.py
Lines changed: 26 additions & 8 deletions

@@ -224,10 +224,34 @@ def __init__(self):
             help="How many train steps to run",
         )
         self.parser.add_argument(
-            "--training.data_parallel_degree",
+            "--training.data_parallel_replicate_degree",
+            type=int,
+            default=1,
+            help="""
+            The `data_parallel_replicate_degree` argument specifies the degree of
+            data parallelism for weight replication. When this value is greater
+            than 1, weights will be replicated across `data_parallel_replicate_degree`
+            ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
+            method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
+            parallelism method used is DDP (Distributed Data Parallelism).
+            1 means disabled.""",
+        )
+        self.parser.add_argument(
+            "--training.data_parallel_shard_degree",
             type=int,
             default=-1,
-            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
+            help="""
+            The `data_parallel_shard_degree` argument specifies the degree of data
+            parallelism for weight sharding. When this value is greater than 1, weights
+            will be sharded across `data_parallel_shard_degree` ranks. If
+            `data_parallel_replicate_degree` is also greater than 1, the parallelism
+            method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
+            parallelism method used is FSDP (Fully Sharded Data Parallelism).
+
+            -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
+            only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree`
+            can be negative.
+            1 means disabled.""",
         )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
@@ -297,12 +321,6 @@ def __init__(self):
             The default value will be the number of pipeline stages, if unspecified.
             """,
         )
-        self.parser.add_argument(
-            "--training.data_parallel_type",
-            type=str,
-            default="fsdp",
-            help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
-        )
         self.parser.add_argument(
             "--experimental.enable_compiled_autograd",
             action="store_true",
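
To make the "-1 means leftover ranks" rule in the new help text concrete, here is a minimal illustrative sketch (not code from this commit; the real resolution happens in ParallelDims._validate in parallel_dims.py below):

# Minimal sketch of the "-1 means leftover ranks" rule described in the help
# text above; illustrative only, the real logic lives in ParallelDims._validate.
def resolve_dp_shard(world_size: int, dp_replicate: int, dp_shard: int,
                     tp: int, pp: int) -> int:
    if dp_shard == -1:
        # whatever is left after dp_replicate / tp / pp goes to sharding
        dp_shard = world_size // (dp_replicate * tp * pp)
    assert dp_replicate * dp_shard * tp * pp == world_size
    return dp_shard

# e.g. 8 GPUs, dp_replicate=2, tp=2, pp=1  ->  dp_shard resolves to 2 (HSDP)
print(resolve_dp_shard(world_size=8, dp_replicate=2, dp_shard=-1, tp=2, pp=1))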

torchtitan/parallelisms/parallel_dims.py
Lines changed: 48 additions & 15 deletions

@@ -13,45 +13,78 @@
 
 @dataclass
 class ParallelDims:
-    dp: int
+    dp_replicate: int
+    dp_shard: int
     tp: int
     pp: int
     world_size: int
     enable_loss_parallel: bool
-    dp_type: str
 
     def __post_init__(self):
-        self.dp_type = self.dp_type.lower()
         self._validate()
 
     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
-        if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
-        assert dp >= 1, dp
+        dp_replicate, dp_shard, tp, pp = (
+            self.dp_replicate,
+            self.dp_shard,
+            self.tp,
+            self.pp,
+        )
+        for d in (dp_replicate, tp, pp):
+            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
+        assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
+
+        dp = dp_replicate * dp_shard
+        if dp < 0:
+            dp = self.world_size // (tp * pp)
+            self.dp_shard = dp_shard = dp // dp_replicate
+
+        assert dp_replicate >= 1
+        assert dp_shard >= 1
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
-        assert self.dp_type in ("fsdp", "ddp")
+        assert dp_replicate * dp_shard * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
+            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        )
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp_replicate, self.dp_shard, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "tp"],
+            strict=True,
         ):
             if d > 1:
                 dims.append(d)
-                names.append(name)
+                if (name == "dp_replicate" and self.dp_shard == 1) or (
+                    name == "dp_shard" and self.dp_replicate == 1
+                ):
+                    names.append("dp")
+                else:
+                    names.append(name)
+
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        # Create all the submesh here to ensure all required process groups are
+        # initialized
+        if self.dp_replicate > 1 and self.dp_shard > 1:
+            mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+        return mesh
 
     @property
     def dp_enabled(self):
-        return self.dp > 1
+        return self.dp_replicate > 1 or self.dp_shard > 1
+
+    @property
+    def dp_replicate_enabled(self):
+        return self.dp_replicate > 1
+
+    @property
+    def dp_shard_enabled(self):
+        return self.dp_shard > 1
 
     @property
     def tp_enabled(self):
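
A small usage sketch of the reworked ParallelDims (illustrative, not part of this commit; it assumes an 8-rank distributed launch, e.g. via torchrun, since build_mesh initializes process groups):

# Illustrative: 8 ranks, 2-way replicate x 2-way shard (HSDP) + 2-way TP.
from torchtitan.parallelisms.parallel_dims import ParallelDims

parallel_dims = ParallelDims(
    dp_replicate=2,
    dp_shard=2,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=False,
)
world_mesh = parallel_dims.build_mesh("cuda")
# pp == 1 is skipped; since both dp degrees are > 1, the mesh keeps separate
# "dp_replicate" and "dp_shard" dims and additionally flattens them into a
# joint "dp" dim, so the rest of the stack can keep addressing a single
# data-parallel group.
dp_mesh = world_mesh["dp"]

Dropping dp_replicate to 1 (pure FSDP) or dp_shard to 1 (pure DDP) collapses the mesh to a single "dp" dim, which is what parallelize_llama.py relies on below.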

torchtitan/parallelisms/parallelize_llama.py
Lines changed: 9 additions & 5 deletions

@@ -73,9 +73,11 @@ def parallelize_llama(
         apply_compile(model)
 
     if parallel_dims.dp_enabled:
-        if parallel_dims.dp_type == "fsdp":
-            dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-            assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        if parallel_dims.dp_shard_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mesh = world_mesh["dp_replicate", "dp_shard"]
+            else:
+                dp_mesh = world_mesh["dp"]
 
             apply_fsdp(
                 model,
@@ -87,6 +89,10 @@ def parallelize_llama(
                 tp_enabled=parallel_dims.tp_enabled,
                 pp_enabled=parallel_dims.pp_enabled,
             )
+            if parallel_dims.dp_replicate_enabled:
+                logger.info("Applied HSDP to the model")
+            else:
+                logger.info("Applied FSDP to the model")
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
@@ -322,8 +328,6 @@ def apply_fsdp(
     )
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    logger.info("Applied FSDP to the model")
-
 
 def apply_ddp(
     model: nn.Module,
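
The branch above only decides which mesh is handed to apply_fsdp; it is the mesh's dimensionality that makes FSDP2 behave as HSDP. A minimal standalone sketch of that idea (assumes an 8-rank torchrun launch and PyTorch's composable fully_shard API; not code from this commit, and the toy model/shapes are placeholders):

# Sketch: with a 2D mesh, fully_shard shards parameters across the last dim
# ("dp_shard") and replicates them across the other dim ("dp_replicate"),
# i.e. HSDP. A 1D mesh would give plain FSDP.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh(
    "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
)
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

for layer in model:
    fully_shard(layer, mesh=mesh_2d)  # per-layer parameter groups
fully_shard(model, mesh=mesh_2d)      # root wrap

In the diff above, world_mesh["dp_replicate", "dp_shard"] is exactly such a 2D mesh, so apply_fsdp needs no HSDP-specific changes beyond the mesh selection.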

train.py
Lines changed: 2 additions & 2 deletions

@@ -59,12 +59,12 @@ def main(job_config: JobConfig):
     # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
-        dp=job_config.training.data_parallel_degree,
+        dp_shard=job_config.training.data_parallel_shard_degree,
+        dp_replicate=job_config.training.data_parallel_replicate_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
-        dp_type=job_config.training.data_parallel_type,
     )
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)

train_configs/debug_model.toml
Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,8 @@ seq_len = 2048
 warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
-data_parallel_degree = -1
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 compile = false
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

train_configs/llama2_13b.toml
Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@ seq_len = 4096
 warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-data_parallel_degree = -1
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 compile = false
 dataset = "c4"

train_configs/llama2_70b.toml
Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@ seq_len = 4096
 warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-data_parallel_degree = -1
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
 tensor_parallel_degree = 8  # 8-way TP
 compile = false
 dataset = "c4"

train_configs/llama2_7b.toml
Lines changed: 2 additions & 1 deletion

@@ -30,7 +30,8 @@ seq_len = 2048
 warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-data_parallel_degree = -1
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
 tensor_parallel_degree = 1  # dp-only would be sufficient for 7B
 compile = false
 dataset = "c4"
