Commit

Make pp split points optional
ghstack-source-id: edf6483cd98bb584ea631120eb16ed147faee6b3
Pull Request resolved: pytorch#604
H-Huang committed Oct 10, 2024
1 parent a88dc41 commit cec7fb6
Showing 3 changed files with 53 additions and 12 deletions.
test_runner.py (10 changes: 0 additions & 10 deletions)
@@ -141,7 +141,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
],
],
@@ -155,7 +154,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1F1B",
"--training.data_parallel_shard_degree 1",
],
@@ -170,7 +168,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule GPipe",
"--training.data_parallel_shard_degree 1",
],
@@ -185,7 +182,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1F1B",
"--training.data_parallel_shard_degree 2",
],
@@ -199,7 +195,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule GPipe",
"--training.data_parallel_shard_degree 2",
],
@@ -213,7 +208,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.tensor_parallel_degree 2",
],
],
@@ -226,15 +220,13 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
@@ -248,7 +240,6 @@ def build_test_list():
[
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
"--training.compile",
@@ -264,7 +255,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule Interleaved1F1B",
],
],
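With the explicit split-point flags removed, the PP tests above now rely on the generated defaults. As a sanity check, the arithmetic behind generate_split_points (added in pipelining_utils.py below) reproduces the removed flag for the 4-way interleaved tests, assuming the integration-test debug model has 8 layers (an assumption, not stated in this diff):

# Hypothetical sanity check, not part of the commit: pp degree 4 with an
# interleaved schedule means 2 stages per rank, i.e. 8 stages in total.
pp_degree, stages_per_rank = 4, 2
n_layers = 8                                # assumed debug-model depth
total_stages = pp_degree * stages_per_rank  # 8 stages over 8 layers
base_interval = n_layers // total_stages    # 1 layer per stage, no remainder
splits = [f"layers.{i * base_interval}" for i in range(1, total_stages)]
print(splits)
# ['layers.1', 'layers.2', 'layers.3', 'layers.4', 'layers.5', 'layers.6', 'layers.7']
# i.e. exactly the value these tests used to pass via --experimental.pipeline_parallel_split_points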
torchtitan/parallelisms/pipeline_llama.py (6 changes: 5 additions & 1 deletion)
@@ -20,6 +20,7 @@
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
build_pipeline_schedule,
+generate_split_points,
stage_ids_this_rank,
)

@@ -83,7 +84,10 @@ def pipeline_llama_manual_split(
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
-splits = job_config.experimental.pipeline_parallel_split_points
+splits = (
+    job_config.experimental.pipeline_parallel_split_points
+    or generate_split_points(job_config, parallel_dims.pp, model_config)
+)

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
model = copy.deepcopy(whole_model)
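The fallback above keeps explicit split points authoritative and only generates defaults when none are configured. A minimal sketch of that behaviour, assuming pipeline_parallel_split_points defaults to an empty (falsy) list when unset:

# Illustrative values only, not part of the commit.
explicit = ["layers.4"]
generated = ["layers.2", "layers.4", "layers.6"]
assert (explicit or generated) == ["layers.4"]  # user-supplied splits still win
assert ([] or generated) == generated           # unset splits fall back to the generated ones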
torchtitan/parallelisms/pipelining_utils.py (49 changes: 48 additions & 1 deletion)
@@ -9,10 +9,57 @@
ScheduleFlexibleInterleaved1F1B,
ScheduleInterleaved1F1B,
)
-from torch.distributed.pipelining.schedules import get_schedule_class

+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    PipelineScheduleMulti,
+    PipelineScheduleSingle,
+)
from torchtitan.logging import logger


+def generate_split_points(job_config, pp_dim, model_config):
+    schedule_class = get_schedule_class(
+        job_config.experimental.pipeline_parallel_schedule
+    )
+    if issubclass(schedule_class, PipelineScheduleSingle):
+        num_stages_per_rank = 1
+    elif issubclass(schedule_class, PipelineScheduleMulti):
+        # Multi-stage schedules support more than 2 stages per rank; 2 is used as
+        # the default when no pipeline split points are specified.
+        num_stages_per_rank = 2
+    else:
+        raise ValueError(
+            f"Unsupported pipeline schedule: {job_config.experimental.pipeline_parallel_schedule}"
+        )
+    total_stages = pp_dim * num_stages_per_rank
+    num_layers = model_config.n_layers
+    if total_stages > num_layers:
+        raise ValueError("Total stages cannot be greater than the number of layers")
+
+    base_interval = num_layers // total_stages
+    extra_layers = num_layers % total_stages
+
+    splits = []
+    current_layer = 0
+    for i in range(total_stages - 1):
+        if i == 0:
+            current_layer += base_interval
+        else:
+            # Middle stages get an extra layer if there are any remaining
+            if extra_layers > 0:
+                current_layer += base_interval + 1
+                extra_layers -= 1
+            else:
+                current_layer += base_interval
+        splits.append("layers." + str(current_layer))
+    logger.info(
+        f"No 'pipeline_parallel_split_points' provided, so the generated splits are: {splits}. "
+        "This may be sub-optimal as the number of layers per stage may be unbalanced."
+    )
+    return splits


def build_pipeline_schedule(job_config, stages, loss_fn):
looped_schedule = False


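A worked example of the split arithmetic above, detached from job_config/model_config. The helper below is a hypothetical stand-in that mirrors the loop in generate_split_points; the layer counts are illustrative:

# Hypothetical stand-in for generate_split_points, keeping only the arithmetic.
def sketch_splits(num_layers: int, total_stages: int) -> list[str]:
    base_interval, extra_layers = divmod(num_layers, total_stages)
    splits, current_layer = [], 0
    for i in range(total_stages - 1):
        if i == 0 or extra_layers == 0:
            current_layer += base_interval
        else:
            # later boundaries absorb the remainder, one extra layer at a time
            current_layer += base_interval + 1
            extra_layers -= 1
        splits.append(f"layers.{current_layer}")
    return splits

# 8 layers, pp degree 2, 1F1B (one stage per rank) -> 2 stages:
print(sketch_splits(8, 2))   # ['layers.4'], the split the tests used to pass explicitly
# 10 layers, pp degree 2, Interleaved1F1B (two stages per rank) -> 4 stages:
print(sketch_splits(10, 4))  # ['layers.2', 'layers.5', 'layers.8'], stage sizes 2/3/3/2

Note that the first boundary never absorbs a remainder layer, which is one reason the generated splits can be slightly unbalanced, as the log message warns.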