
Commit

Update
[ghstack-poisoned]
fegin committed Feb 12, 2025
2 parents caf5b97 + 5467f2b commit c131309
Showing 8 changed files with 180 additions and 28 deletions.
9 changes: 9 additions & 0 deletions docs/checkpoint.md
@@ -64,6 +64,15 @@ Finally, once you have obtained the last checkpoint, you can use the following command
python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
```

7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
In some cases, you may want to partially load a previously trained checkpoint while changing certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
This parameter takes a comma-separated list of keys to exclude.
```
[checkpoint]
enable_checkpoint = true
exclude_from_loading = "data_loader,lr_scheduler"
```
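As a rough sketch (assuming the `JobConfig` parser defined in `torchtitan/config_manager.py` and the debug config used by the unit tests), the same comma-separated value can also be supplied as a command-line flag and is parsed into a list of keys:
```
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--job.config_file",
        "./train_configs/debug_model.toml",
        "--checkpoint.exclude_from_loading",
        "data_loader,lr_scheduler",
    ]
)
# Whitespace around each key is stripped and empty entries are dropped.
print(config.checkpoint.exclude_from_loading)  # ['data_loader', 'lr_scheduler']
```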

That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.


30 changes: 30 additions & 0 deletions tests/integration_tests.py
@@ -139,6 +139,18 @@ def build_test_list():
"pp_looped_zero_bubble",
ngpu=4,
),
OverrideDefinitions(
[
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_schedule ZBVZeroBubble",
"--experimental.pipeline_parallel_microbatches 8",
],
],
"PP zero bubble test (v shaped)",
"pp_zbv",
ngpu=2,
),
OverrideDefinitions(
[
[
@@ -418,6 +430,24 @@ def build_test_list():
"fsdp_reshard_always",
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--training.steps 10",
],
# Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
# excluded during loading to avoid errors caused by mismatched dp_degree.
[
"--checkpoint.enable_checkpoint",
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--training.tensor_parallel_degree 2",
"--training.steps 20",
],
],
"Optional checkpoint",
"optional_checkpoint",
),
]
return integration_tests_flavors

71 changes: 71 additions & 0 deletions tests/unit_tests/test_job_config.py
@@ -116,6 +116,77 @@ def test_parse_pp_split_points(self):
config.experimental.pipeline_parallel_split_points == cmdline_splits
), config.experimental.pipeline_parallel_split_points

def test_parse_exclude_from_loading(self):

toml_splits = ["optimizer", "dataloader"]
toml_split_str = ",".join(toml_splits)
cmdline_splits = ["optimizer", "lr_scheduler"]
cmdline_split_str = ",".join(cmdline_splits)
# no exclude keys specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
]
)
assert config.checkpoint.exclude_from_loading == []

# toml has no exclude keys, but cmdline keys are specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
"--checkpoint.exclude_from_loading",
f"{cmdline_split_str}",
]
)
assert (
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

# toml has exclude keys, cmdline does not
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"checkpoint": {
"exclude_from_loading": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(["--job.config_file", fp.name])
assert (
config.checkpoint.exclude_from_loading == toml_splits
), config.checkpoint.exclude_from_loading

# toml has exclude keys, cmdline overrides them
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"checkpoint": {
"exclude_from_loading": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(
[
"--job.config_file",
fp.name,
"--checkpoint.exclude_from_loading",
f"{cmdline_split_str}",
]
)
assert (
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

def test_print_help(self):
config = JobConfig()
parser = config.parser
17 changes: 11 additions & 6 deletions torchtitan/checkpoint.py
@@ -171,11 +171,8 @@ def __init__(
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that
all lr_schedulers have the same state_dict.
"""
self.states = states
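A minimal sketch of the flattening described in point 3, assuming all lr_schedulers really do share the same state_dict; the wrapper class below is illustrative only, not the container actually used in torchtitan:
```
class FlattenedLRSchedulers:
    """Illustrative wrapper exposing one state_dict for a list of identical schedulers."""

    def __init__(self, lr_schedulers):
        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        # Every scheduler is assumed identical, so the first one is representative.
        return self.lr_schedulers[0].state_dict()

    def load_state_dict(self, state_dict):
        # Broadcast the single saved state back into every scheduler.
        for lr_scheduler in self.lr_schedulers:
            lr_scheduler.load_state_dict(state_dict)
```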

Expand Down Expand Up @@ -204,6 +201,7 @@ def __init__(

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading

self.mp = None
if async_mode == AsyncMode.DISABLED:
@@ -436,10 +434,17 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
states_to_load = {
k: v for k, v in states.items() if k not in self.exclude_from_loading
}
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
dcp.load(
states,
states_to_load,
checkpoint_id=self._create_checkpoint_id(step),
)
states.update(states_to_load)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
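For illustration, with hypothetical keys, the filter keeps every entry except the excluded ones, so an excluded component simply retains whatever state it was initialized with:
```
states = {"model": ..., "optimizer": ..., "dataloader": ...}
exclude_from_loading = ["dataloader"]

states_to_load = {k: v for k, v in states.items() if k not in exclude_from_loading}
assert sorted(states_to_load) == ["model", "optimizer"]
```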
23 changes: 22 additions & 1 deletion torchtitan/config_manager.py
@@ -26,7 +26,7 @@


def string_list(raw_arg):
return raw_arg.split(",")
return [s.strip() for s in raw_arg.split(",") if s.strip()]
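For reference, the stricter parsing trims whitespace around each item and drops empty entries, for example:
```
assert string_list("optimizer, lr_scheduler,") == ["optimizer", "lr_scheduler"]
assert string_list("dataloader") == ["dataloader"]
```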


class JobConfig:
@@ -546,6 +546,17 @@ def __init__(self):
default=-1,
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
"--checkpoint.exclude_from_loading",
type=string_list,
nargs="*",
default=[],
help="""
Exclude specific keys from being loaded from the checkpoint.
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
Keys that are not listed here will still be loaded from the checkpoint as usual.
""",
)
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
@@ -653,6 +664,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)
if (
"checkpoint" in args_dict
and "exclude_from_loading" in args_dict["checkpoint"]
and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
):
ckpt = args_dict["checkpoint"]
ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -700,6 +718,9 @@ def parse_args_from_command_line(
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
elif arg == "checkpoint.exclude_from_loading":
# similar to the case above
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

24 changes: 19 additions & 5 deletions torchtitan/models/llama/pipeline_llama.py
@@ -13,8 +13,7 @@
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import _PipelineSchedule
from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class, ScheduleZBVZeroBubble

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
@@ -39,14 +38,23 @@ def pipeline_llama(
device: DeviceType,
model_config: TransformerModelArgs,
loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module]]:
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
stages, models = pipeline_llama_manual_split(
model, pp_mesh, parallel_dims, job_config, device, model_config
)

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

return pp_schedule, models
# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, models, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
@@ -112,7 +120,13 @@ def _build_stage(

stages = []
models = []
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):

schedule_class = get_schedule_class(
job_config.experimental.pipeline_parallel_schedule
)
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
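As a rough illustration of the two styles, take a hypothetical run with pp_size=2 and num_stages=4 (two stages per rank); the mapping below is an assumption about how `stage_ids_this_rank` assigns stages, not taken from this diff:
```
# "loop" style assigns stages round-robin across pipeline ranks.
loop_assignment = {0: [0, 2], 1: [1, 3]}
# "v" style (ZBVZeroBubble) pairs one stage from each end on the same rank,
# so the first and last stages both live on rank 0.
v_assignment = {0: [0, 3], 1: [1, 2]}
```
Because a single rank can now hold both the first and the last stage, the train loop relies on the has_first_stage/has_last_stage flags returned by pipeline_llama rather than on the pipeline rank.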
13 changes: 7 additions & 6 deletions torchtitan/train_spec.py
@@ -8,7 +8,7 @@


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias
from typing import Callable, Dict, Protocol, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -36,15 +36,14 @@ class ModelProtocol(Protocol):
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module:
...
def from_model_args(args: BaseModelArgs) -> nn.Module: ...


OptimizersBuilder: TypeAlias = Callable[
[List[nn.Module], JobConfig], OptimizersContainer
[list[nn.Module], JobConfig], OptimizersContainer
]
OptimizerBuilderWrapper: TypeAlias = Callable[
[List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer
[list[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer
]
LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]

@@ -55,7 +54,9 @@ class TrainSpec:
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
pipelining_fn: Callable[
[nn.Module], tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
build_optimizers_fn: OptimizersBuilder
build_lr_schedulers_fn: LRSchedulersBuilder

21 changes: 11 additions & 10 deletions train.py
@@ -152,7 +152,12 @@ def loss_fn(pred, labels):
# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
pp_schedule, model_parts = train_spec.pipelining_fn(
(
pp_schedule,
model_parts,
has_first_stage,
has_last_stage,
) = train_spec.pipelining_fn(
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -286,22 +291,18 @@ def loss_fn(pred, labels):

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context(optional_context_parallel_ctx):
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
targets, losses = (labels, []) if has_last_stage else (None, None)
if has_first_stage:
pp_schedule.step(input_ids, target=targets, losses=losses)
else:
pp_schedule.step()
pp_schedule.step(target=targets, losses=losses)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(device)
if is_last_stage
if has_last_stage
else torch.tensor([-1.0], device=device)
)
else:
