From 82f7387f8c7af6b4285869f5daf2f52e523a2774 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Sun, 12 Jan 2025 15:04:33 -0800
Subject: [PATCH] minor fix

ghstack-source-id: 0dd35232e76d80a4542a7e91b2d25fea663938b6
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/788
---
 docs/checkpoint.md                          | 2 +-
 torchtitan/parallelisms/pipelining_utils.py | 2 +-
 train.py                                    | 5 ++++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/docs/checkpoint.md b/docs/checkpoint.md
index 3f66e5acd..05ef6f4d1 100644
--- a/docs/checkpoint.md
+++ b/docs/checkpoint.md
@@ -75,5 +75,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG= ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_shard_degree 1
+NGPU=1 CONFIG= ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_replicate_degree 1 --training.data_parallel_shard_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1 --experimental.context_parallel_degree 1
 ```
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index 4322a0315..7b2994f80 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -107,7 +107,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
     )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
-with {n_microbatches} and {num_total_stages} stages."
+with {n_microbatches} microbatches and {num_total_stages} stages."
     )

     if pp_schedule_csv:
diff --git a/train.py b/train.py
index 2874d0d5a..21dd9f8b8 100644
--- a/train.py
+++ b/train.py
@@ -201,7 +201,10 @@ def loss_fn(pred, labels):
     if job_config.checkpoint.create_seed_checkpoint:
         assert (
             world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
+        ), "Must create seed checkpoint using a single device, to disable sharding"
+        assert (
+            job_config.checkpoint.enable_checkpoint
+        ), "Must enable checkpointing when creating a seed checkpoint"
         checkpoint.save(curr_step=0, force=True)
         logger.info("Created seed checkpoint")
         return