Delete manual pp input/output tensors after adding shape inference
ghstack-source-id: 0ccee5e0d09e3b57d6ad78ba349f5d4d569d200c
Pull Request resolved: pytorch#616
wconstab committed Oct 14, 2024
1 parent c134345 commit 1629fb9
Showing 1 changed file with 1 addition and 46 deletions.

torchtitan/parallelisms/pipeline_llama.py
@@ -14,7 +14,7 @@
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
 
-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.parallel_dims import ParallelDims
@@ -46,23 +46,6 @@ def pipeline_llama(
     return pp_schedule, models
 
 
-def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
-    """Get meta tensors with the right input shapes used for tracing"""
-    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
-    tokens = torch.randint(
-        model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device
-    )
-    return (tokens,)
-
-
-def _mixed_precision_dtype(
-    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
-) -> torch.dtype:
-    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
-    mp_arg = job_config.training.mixed_precision_param
-    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
-
-
 def pipeline_llama_manual_split(
     whole_model: nn.Module,
     pp_mesh: DeviceMesh,
@@ -108,39 +91,11 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
             model.norm = None
             model.output = None
 
-        # Note: these tensors are only here as metadata hints, so pipelining runtime knows what size buffer to allocate.
-        # these tensors should be on meta device, and the model should also. It will be allocated on device after
-        # applying all other parallelisms.
-
-        # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can avoid specifying input/output shapes
-        mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
-        batch_size = job_config.training.batch_size
-        local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
-        layers_io_shape = (batch_size, local_seq_len, model_config.dim)
-        output_layer_shape = (
-            batch_size,
-            job_config.training.seq_len,
-            model_config.vocab_size,
-        )
-        if is_first:
-            (input,) = _llama_trace_input(job_config, model_config, device="meta")
-        else:
-            # later layers (assume all start w/ a transformer layer)
-            input = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")
-
-        if is_last:
-            output = torch.rand(output_layer_shape, dtype=torch.float32, device="meta")
-        else:
-            # earlier layers (assume all end in a transformer layer)
-            output = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")
-
         stage = PipelineStage(
             model,
             stage_idx,
             num_stages,
             device,
-            input_args=input.chunk(microbatches)[0],
-            output_args=output.chunk(microbatches)[0],
             group=pp_mesh.get_group("pp"),
         )
         return stage, model
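
With the manual hints gone, stage construction no longer depends on batch size, sequence length, TP degree, or the mixed-precision dtype; the pipelining runtime infers its buffer shapes instead. Below is a minimal sketch of the simplified pattern, using a toy stage module and a hypothetical build_stage helper rather than torchtitan's real model code, and assuming torch.distributed is already initialized:

import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

# Toy stand-in for one split of the model; illustrative only.
class ToyStage(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

def build_stage(stage_idx: int, num_stages: int, device: torch.device) -> PipelineStage:
    # Construct on the meta device, as in the diff above; parameters are
    # materialized on the real device after other parallelisms are applied.
    with torch.device("meta"):
        model = ToyStage()
    # No input_args/output_args: after this change the runtime performs
    # shape inference rather than requiring meta-tensor hints up front.
    return PipelineStage(model, stage_idx, num_stages, device)

The deleted helpers existed only to reproduce those shapes by hand: (batch_size, seq_len // tp, dim) between stages and (batch_size, seq_len, vocab_size) at the output. Shape inference also removes the risk of those hints drifting out of sync with the actual activations.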