diff --git a/.gitignore b/.gitignore
index 00037067..cf5f06e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,3 @@ torchtitan/datasets/**/*.model
 *.log
 error.json
 _remote_module_non_scriptable.py
-
-# torch compile debug related
-torch_compile_debug/*
diff --git a/benchmark.py b/benchmark.py
deleted file mode 100644
index 16a706f9..00000000
--- a/benchmark.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import time
-from datetime import timedelta
-
-import torch
-from torch.distributed.elastic.multiprocessing.errors import record
-
-from torchbenchmark.util.experiment.instantiator import (
-    load_model,
-    TorchBenchModelConfig,
-)
-from torchbenchmark.util.experiment.metrics import get_model_flops
-from torchbenchmark.util.input import input_cast
-
-from torchtitan import utils
-from torchtitan.checkpoint import TrainState
-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.logging import init_logger, logger
-from torchtitan.metrics import build_gpu_memory_monitor
-from torchtitan.parallelisms import ParallelDims
-from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
-from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
-
-
-# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
-@record
-def main(job_config: JobConfig):
-    init_logger()
-    logger.info(f"Starting job: {job_config.job.description}")
-
-    # used for colorful printing
-    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor
-
-    # take control of garbage collection to avoid stragglers
-    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
-
-    # init distributed
-    world_size = int(os.environ["WORLD_SIZE"])
-    parallel_dims = ParallelDims(
-        dp=job_config.training.data_parallel_degree,
-        tp=job_config.training.tensor_parallel_degree,
-        pp=job_config.experimental.pipeline_parallel_degree,
-        world_size=world_size,
-        enable_loss_parallel=job_config.training.enable_loss_parallel,
-        dp_type=job_config.training.data_parallel_type,
-    )
-    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-    torch.cuda.set_device(device)
-    utils.init_distributed(job_config)
-    # initialize GPU memory monitor and get peak flops for MFU calculation
-    gpu_memory_monitor = build_gpu_memory_monitor()
-    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)
-
-    # build meshes
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
-    if parallel_dims.dp_enabled:
-        dp_mesh = world_mesh["dp"]
-        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
-    else:
-        dp_degree, dp_rank = 1, 0
-
-    if parallel_dims.pp_enabled:
-        pp_mesh = world_mesh["pp"]
-
-    model_name = job_config.model.name
-
-    # initiate model from torchbench
-    config = TorchBenchModelConfig(
-        name=model_name,
-        test="train",
-        device="cuda",
-        batch_size=job_config.training.batch_size,
-        extra_args=[],
-    )
-    model_flops = get_model_flops(config)
-    benchmark_model = load_model(config)
-    model, _ = benchmark_model.get_module()
-
-    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
-    # cast input dtype if needed
-    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
-    input_cond = lambda x: x.dtype == torch.float32
-    input_action = lambda x: x.to(param_dtype)
-    if hasattr(benchmark_model, "example_inputs"):
-        benchmark_model.example_inputs = input_cast(
-            input_cond, input_action, benchmark_model.example_inputs
-        )
-    else:
-        logger.warning(
-            f"{model_name} example inputs haven't been cast to {action} yet!"
-        )
-
-    # log model size
-    model_param_count = utils.get_num_params(model)
-    logger.info(
-        f"{color.blue}Model {model_name} "
-        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
-    )
-
-    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
-
-    # update model and optimizer after applying parallelisms
-    benchmark_model.set_module(model)
-    optimizer = benchmark_model.get_optimizer()
-    optimizer.add_param_group({"params": model.parameters()})
-
-    model.train()
-
-    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"GPU memory usage for model: "
-        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
-        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
-    )
-
-    train_state = TrainState()
-
-    # variables used to keep info for metrics logging
-    losses_since_last_log = []
-    gpu_memory_monitor.reset_peak_stats()
-
-    # train loop
-    logger.info(
-        f"Training starts at step {train_state.step + 1}, "
-        f"with local batch size {job_config.training.batch_size}, "
-        f"global batch size {job_config.training.batch_size * dp_degree}, "
-        f"total steps {job_config.training.steps}"
-    )
-    with maybe_enable_profiling(
-        job_config, global_step=train_state.step
-    ) as torch_profiler, maybe_enable_memory_snapshot(
-        job_config, global_step=train_state.step
-    ) as memory_profiler:
-        while train_state.step < job_config.training.steps:
-            train_state.step += 1
-            gc_handler.run(train_state.step)
-
-            torch.cuda.synchronize()
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-
-            # Collect time_ns() instead of time() which does not provide better precision than 1
-            # second according to https://docs.python.org/3/library/time.html#time.time.
-            t0 = time.time_ns()
-            start_event.record()
-
-            is_staged = (
-                hasattr(benchmark_model, "forward")
-                and hasattr(benchmark_model, "backward")
-                and hasattr(benchmark_model, "optimizer_step")
-            )
-            if is_staged and (getattr(benchmark_model, "train", None) is None):
-                if optimizer is not None:
-                    optimizer.zero_grad()
-                loss = benchmark_model.forward()
-                benchmark_model.backward(loss)
-                if optimizer is not None:
-                    benchmark_model.optimizer_step()
-            else:
-                loss = benchmark_model.train()
-
-            end_event.record()
-            torch.cuda.synchronize()
-            t1 = time.time_ns()
-            time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000
-
-            # log metrics
-            losses_since_last_log.append(loss)
-            if (
-                train_state.step == 1
-                or train_state.step % job_config.metrics.log_freq == 0
-            ):
-                losses = [
-                    loss.item() if isinstance(loss, torch.Tensor) else loss
-                    for loss in losses_since_last_log
-                ]
-                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
-                if parallel_dims.dp_enabled:
-                    global_avg_loss, global_max_loss = (
-                        utils.dist_mean(avg_loss, dp_mesh),
-                        utils.dist_max(max_loss, dp_mesh),
-                    )
-                else:
-                    global_avg_loss, global_max_loss = avg_loss, max_loss
-
-                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
-
-                logger.info(
-                    f"{color.cyan}step: {train_state.step:2} "
-                    f"{color.green}loss: {global_avg_loss:7.4f} "
-                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
-                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
-                    f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
-                    f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
-                )
-
-                losses_since_last_log.clear()
-                gpu_memory_monitor.reset_peak_stats()
-
-            # signal the profiler that the next profiling step has started
-            if torch_profiler:
-                torch_profiler.step()
-            if memory_profiler:
-                memory_profiler.step()
-
-            # reduce timeout after first train step for faster signal
-            # (assuming lazy init and compilation are finished)
-            if train_state.step == 1:
-                utils.set_pg_timeouts(
-                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
-                    world_mesh=world_mesh,
-                )
-
-    if torch.distributed.get_rank() == 0:
-        logger.info("Sleeping 2 seconds for other ranks to complete")
-        time.sleep(2)
-
-    logger.info("Training completed")
-
-
-if __name__ == "__main__":
-    config = JobConfig()
-    config.parse_args()
-    main(config)
-    torch.distributed.destroy_process_group()
diff --git a/run_benchmark_train.sh b/run_benchmark_train.sh
deleted file mode 100755
index 022e74c5..00000000
--- a/run_benchmark_train.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/bash
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-set -ex
-
-# use envs as local overrides for convenience
-# e.g.
-# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
-NGPU=${NGPU:-"8"}
-LOG_RANK=${LOG_RANK:-0}
-CONFIG_FILE=${CONFIG_FILE:-"./train_configs/benchmark_model.toml"}
-
-overrides=""
-if [ $# -ne 0 ]; then
-    overrides="$*"
-fi
-
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-benchmark.py --job.config_file ${CONFIG_FILE} $overrides
diff --git a/run_llama_train.sh b/run_llama_train.sh
index 296da519..a4107806 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -19,8 +19,6 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
-# TORCH_TRACE="./outputs/trace" \
-TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d6ead31c..3ba1d102 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -241,14 +241,6 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
-
-        # experimental configs
-        self.parser.add_argument(
-            "--experimental.torch_spmd",
-            default=False,
-            action="store_true",
-            help="Whether to use the experimental torch_spmd style parallelism",
-        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index ef5f1fc7..aa07f25f 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -29,55 +29,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-
-
-# NOTE(lty): experimental for the PT-D 24 research internship project
-def torch_spmd_parallelize(
-    model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-):
-    torch._inductor.config.simplefsdp.enable_reorder = True
-    torch._inductor.config.simplefsdp.enable_bucket = True
-
-    if parallel_dims.tp_enabled:
-        apply_tp(
-            model,
-            world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
-            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
-        )
-
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode != "none":
-        apply_ac(model, ac_config)
-        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-
-    if parallel_dims.dp_enabled:
-        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy
-
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-        )
-        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-
-        model = data_parallel(
-            model,
-            dp_mesh,
-            mode="fully_shard",
-            ac_mode=ac_config.mode,
-            mp_policy=mp_policy,
-        )
-        logger.info("Applied Simple FSDP to the model")
-
-    if job_config.training.compile:
-        model = torch.compile(model, fullgraph=True)
-        logger.info("Compiling with torch.compile")
-
-    return model
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled
 
 
 def parallelize_llama(
@@ -93,9 +45,6 @@
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # NOTE(lty): experimental for the PT-D 24 research internship project
-    if job_config.experimental.torch_spmd:
-        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
 
     if parallel_dims.tp_enabled:
         if (
@@ -351,12 +300,11 @@
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
-    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
-    # if tp_enabled:
-    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    #     check_strided_sharding_enabled()
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
diff --git a/train.py b/train.py
index 4803aa8f..ffea00a9 100644
--- a/train.py
+++ b/train.py
@@ -12,9 +12,6 @@
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
 
-# context needed by meta-init with torch_spmd
-from torch_spmd.data_parallel import disable_data_parallel
-
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
@@ -157,20 +154,16 @@ def loss_fn(pred, labels):
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
             m.to_empty(device="cuda")
-            with disable_data_parallel() if job_config.experimental.torch_spmd else contextlib.nullcontext():
-                m.init_weights()
+            m.init_weights()
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-        model = models_parallelize_fns[model_name](
-            model, world_mesh, parallel_dims, job_config
-        )
+        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        with disable_data_parallel() if job_config.experimental.torch_spmd else contextlib.nullcontext():
-            model.init_weights()
+        model.init_weights()
         model.train()
 
     model_parts = [model]
diff --git a/train_configs/benchmark_model.toml b/train_configs/benchmark_model.toml
deleted file mode 100644
index c2f37d04..00000000
--- a/train_configs/benchmark_model.toml
+++ /dev/null
@@ -1,39 +0,0 @@
-# torchtitan Config.toml
-
-[job]
-dump_folder = "./outputs"
-description = "torchbenchmark training"
-
-[profiling]
-enable_profiling = false
-save_traces_folder = "profile_trace"
-profile_freq = 10
-enable_memory_snapshot = false
-save_memory_snapshot_folder = "memory_snapshot"
-
-[metrics]
-log_freq = 1
-enable_color_printing = true
-enable_tensorboard = false
-save_tb_folder = "tb"
-
-[model]
-# name = "resnet50"
-name = "hf_GPT2"
-
-[training]
-batch_size = 8
-max_norm = 1.0 # grad norm clipping
-steps = 10
-data_parallel_degree = -1
-compile = true
-mixed_precision_param = "bfloat16"
-mixed_precision_reduce = "bfloat16"
-# mixed_precision_param = "float32"
-# mixed_precision_reduce = "float32"
-
-[experimental]
-torch_spmd = true
-
-[activation_checkpoint]
-mode = 'none' # ['none', 'selective', 'full']
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index f401b791..af547214 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -6,7 +6,7 @@ description = "Llama 3 debug training"
 use_for_integration_test = true
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
@@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot"
 [metrics]
 log_freq = 1
 enable_color_printing = true
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
@@ -37,14 +37,12 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-compile = true
+compile = false
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [experimental]
-# pipeline_parallel_degree = 2
-# pipeline_parallel_split_points = ["layers.4"]
+pipeline_parallel_degree = 1
 enable_async_tensor_parallel = false
-torch_spmd = true
 
 [checkpoint]
 enable_checkpoint = false
@@ -56,8 +54,8 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = 'none' # ['none', 'selective', 'full']
-selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
+mode = 'selective' # ['none', 'selective', 'full']
+selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
 enable_float8_linear = false
diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml
index 1cceb14b..1a83301f 100644
--- a/train_configs/llama3_405b.toml
+++ b/train_configs/llama3_405b.toml
@@ -38,7 +38,7 @@ dataset = "c4"
 
 [experimental]
 pipeline_parallel_degree = 1
-enable_async_tensor_parallel = false
+enable_async_tensor_parallel = true
 
 [checkpoint]
 enable_checkpoint = false
@@ -53,6 +53,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 mode = 'full' # ['none', 'selective', 'full']
 
 [float8]
-enable_float8_linear = false
-enable_fsdp_float8_all_gather = false
-precompute_float8_dynamic_scale_for_fsdp = false
+enable_float8_linear = true
+enable_fsdp_float8_all_gather = true
+precompute_float8_dynamic_scale_for_fsdp = true
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index c96278b7..3d0c5160 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -6,14 +6,13 @@ dump_folder = "./outputs"
 description = "Llama 3 8B training"
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
 
 [metrics]
 log_freq = 10
-enable_color_printing = false
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
@@ -34,12 +33,11 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-compile = true
+compile = false
 dataset = "c4"
 
 [experimental]
 pipeline_parallel_degree = 1
-torch_spmd = true
 
 [checkpoint]
 enable_checkpoint = false
@@ -51,7 +49,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = 'none' # ['none', 'selective', 'full']
+mode = 'selective' # ['none', 'selective', 'full']
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]