diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 81ea10b8d..8ec614d33 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -19,9 +19,10 @@ from torchtitan.datasets import build_tokenizer from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.models import model_name_to_tokenizer from torchtitan.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ParallelDims +from torchtitan.train_spec import get_train_spec def estimate_memory(job_config: JobConfig): @@ -74,6 +75,8 @@ def estimate_memory(job_config: JobConfig): "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store ) + train_spec = get_train_spec(job_config.model.name) + # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -95,8 +98,8 @@ def loss_fn(pred, labels): ) # build model (using meta init) - model_cls = model_name_to_cls[model_name] - model_config = models_config[model_name][job_config.model.flavor] + model_cls = train_spec.cls + model_config = train_spec.config[job_config.model.flavor] # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. vocab size from tokenizer @@ -112,7 +115,7 @@ def loss_fn(pred, labels): ): logger.info( - f"Building {model_name} {job_config.model.flavor} with {model_config}" + f"Building {train_spec.name} {job_config.model.flavor} with {model_config}" ) with torch.device("meta"): model = model_cls.from_model_args(model_config) @@ -123,7 +126,7 @@ def loss_fn(pred, labels): float8_handler.convert_to_float8_training(model) # apply PT-D DP/TP parallelisms and activation checkpointing - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device="cuda") if not active_fake_mode(): diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py new file mode 100644 index 000000000..4c01d74bd --- /dev/null +++ b/tests/unit_tests/test_train_spec.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
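+
+# These tests exercise the TrainSpec registry introduced in torchtitan/train_spec.py:
+# registering a spec, retrieving it with get_train_spec() (unknown names raise
+# ValueError), and using apply_to_train_specs() to wrap a spec's build_optimizers_fn
+# so that an optimizer step post-hook is attached to the resulting OptimizersContainer.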
+ +from functools import partial + +import pytest +import torch +import torch.nn as nn +from torchtitan.config_manager import JobConfig +from torchtitan.models.llama import parallelize_llama, pipeline_llama +from torchtitan.optimizer import ( + build_lr_schedulers, + build_optimizers, + OptimizersContainer, +) +from torchtitan.train_spec import ( + apply_to_train_specs, + BaseModelArgs, + get_train_spec, + ModelProtocol, + register_train_spec, + TrainSpec, +) + + +class FakeModel(ModelProtocol): + @staticmethod + def from_model_args(args: BaseModelArgs) -> nn.Module: + return nn.Linear(8, 8) + + +def fake_build_optimizers( + model_parts: list[nn.Module], job_config: JobConfig +) -> OptimizersContainer: + optimizer_kwargs = { + "lr": 0.1, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": True, + "foreach": False, + } + return OptimizersContainer( + model_parts=model_parts, + optimizer_kwargs=optimizer_kwargs, + name="Adam", + ) + + +class TestTrainSpec: + def test_register_train_spec(self): + fake_config = {"fake": None} + spec = TrainSpec( + name="fake", + cls=FakeModel, + config=fake_config, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + ) + register_train_spec(spec) + new_spec = get_train_spec("fake") + assert new_spec == spec + + with pytest.raises(ValueError): + new_spec = get_train_spec("fake2") + + def test_optim_hook(self): + fake_config = {"fake": None} + spec = TrainSpec( + name="fake2", + cls=FakeModel, + config=fake_config, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=fake_build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + ) + register_train_spec(spec) + new_spec = get_train_spec("fake2") + + # Demonstrate how to register a optimizer hook for all model specs + hook_called = False + + def my_hook( + optimizer: torch.optim.Optimizer, + args, + kwargs, + model_parts: list[nn.Module], + ) -> None: + nonlocal hook_called + hook_called = True + + def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec: + # Create a closure to capture the original spec.build_optimizers_fn + original_build_optimizers_fn = spec.build_optimizers_fn + + def my_build_optimizer_fn( + model_parts: list[nn.Module], job_config: JobConfig + ) -> OptimizersContainer: + optimizers = original_build_optimizers_fn(model_parts, job_config) + optimizers.register_step_post_hook( + partial(my_hook, model_parts=model_parts) + ) + return optimizers + + spec.build_optimizers_fn = my_build_optimizer_fn + + apply_to_train_specs(register_optimizer_hook_to_spec) + + model = new_spec.cls.from_model_args(BaseModelArgs()) + model_parts = [model] + optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig()) + assert optimizers.optimizers[0].__class__.__name__ == "Adam" + batch = torch.randn(8, 8) + model(batch).sum().backward() + assert not hook_called + optimizers.step() + assert hook_called diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py new file mode 100644 index 000000000..be0d95f3c --- /dev/null +++ b/torchtitan/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +# Import the built-in models here so that the corresponding register_model_spec() +# will be called. 
+import torchtitan.models # noqa: F401 diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 7d1433830..c75111578 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -26,9 +26,10 @@ ) from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader + from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import init_logger, logger -from torchtitan.optimizer import OptimizersContainer, SchedulersContainer +from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer class IntervalType(enum.Enum): @@ -140,7 +141,7 @@ def __init__( dataloader: DataLoader, model_parts: List[nn.Module], optimizers: OptimizersContainer, - lr_schedulers: SchedulersContainer, + lr_schedulers: LRSchedulersContainer, states: Dict[str, Any], job_config: JobConfig, ) -> None: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2d3024912..0cda86b44 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -393,6 +393,23 @@ def __init__(self): The default value is 'allgather'. """, ) + # I'm not particularly fond of this. Users can choose to write their own wrapper + # module and import TorchTitan training loop and execute it, which look cleaner. + # One reason to provide this option is to allow users to use the existing run script. + # While the script is pretty trivial now, we may add more logic when integrating + # with TorchFT. + # This option is subject to change and may be deleted in the future. + self.parser.add_argument( + "--experimental.custom_model_path", + type=str, + default="", + help=""" + The --custom_model_path option allows to specify a custom path to a model module + that is not natively implemented within TorchTitan. + Acceptable values are the file system path to the module (e.g., my_models/model_x) + dotted import module (e.g., some_package.model_x). + """, + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index c666b0655..16d940d22 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -4,14 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtitan.models.llama import llama3_configs, Transformer -models_config = { - "llama3": llama3_configs, -} +# Import the built-in models here so that the corresponding register_model_spec() +# will be called. +import torchtitan.models.llama # noqa: F401 -model_name_to_cls = {"llama3": Transformer} -model_name_to_tokenizer = { - "llama3": "tiktoken", -} +model_name_to_tokenizer = {"llama3": "tiktoken"} diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index 3bb430d2c..5cdedb083 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -6,13 +6,27 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
-from torchtitan.models.llama.model import ModelArgs, Transformer +from torchtitan.models.llama.model import Transformer, TransformerModelArgs +from torchtitan.optimizer import build_lr_schedulers, build_optimizers +from torchtitan.train_spec import register_train_spec, TrainSpec + +from .parallelize_llama import parallelize_llama +from .pipeline_llama import pipeline_llama + +__all__ = [ + "parallelize_llama", + "pipeline_llama", + "TransformerModelArgs", + "Transformer", + "llama3_configs", +] -__all__ = ["Transformer"] llama3_configs = { - "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), - "8B": ModelArgs( + "debugmodel": TransformerModelArgs( + dim=256, n_layers=8, n_heads=16, rope_theta=500000 + ), + "8B": TransformerModelArgs( dim=4096, n_layers=32, n_heads=32, @@ -21,7 +35,7 @@ multiple_of=1024, rope_theta=500000, ), - "70B": ModelArgs( + "70B": TransformerModelArgs( dim=8192, n_layers=80, n_heads=64, @@ -30,7 +44,7 @@ multiple_of=4096, rope_theta=500000, ), - "405B": ModelArgs( + "405B": TransformerModelArgs( dim=16384, n_layers=126, n_heads=128, @@ -40,3 +54,16 @@ rope_theta=500000, ), } + + +register_train_spec( + TrainSpec( + name="llama3", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + ) +) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 641ef6de9..0a9644511 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -13,11 +13,13 @@ import torch import torch.nn.functional as F from torch import nn + from torchtitan.models.norms import build_norm +from torchtitan.train_spec import BaseModelArgs, ModelProtocol @dataclass -class ModelArgs: +class TransformerModelArgs(BaseModelArgs): dim: int = 4096 n_layers: int = 32 n_heads: int = 32 @@ -130,7 +132,7 @@ class Attention(nn.Module): Multi-head attention module. Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: n_kv_heads (int): Number of key and value heads. @@ -144,7 +146,7 @@ class Attention(nn.Module): """ - def __init__(self, model_args: ModelArgs): + def __init__(self, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads self.n_kv_heads = ( @@ -264,7 +266,7 @@ class TransformerBlock(nn.Module): Args: layer_id (int): Identifier for the layer. - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: n_heads (int): Number of attention heads. @@ -278,7 +280,7 @@ class TransformerBlock(nn.Module): """ - def __init__(self, layer_id: int, model_args: ModelArgs): + def __init__(self, layer_id: int, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads self.dim = model_args.dim @@ -331,15 +333,15 @@ def init_weights(self): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module): +class Transformer(nn.Module, ModelProtocol): """ Transformer Module Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. vocab_size (int): Vocabulary size. n_layers (int): Number of layers in the model. 
tok_embeddings (ParallelEmbedding): Token embeddings. @@ -350,7 +352,7 @@ class Transformer(nn.Module): """ - def __init__(self, model_args: ModelArgs): + def __init__(self, model_args: TransformerModelArgs): super().__init__() self.model_args = model_args self.vocab_size = model_args.vocab_size @@ -446,12 +448,12 @@ def forward(self, tokens: torch.Tensor): return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer": """ - Initialize a Transformer model from a ModelArgs object. + Initialize a Transformer model from a TransformerModelArgs object. Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Returns: Transformer: Transformer model. diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/models/llama/parallelize_llama.py similarity index 99% rename from torchtitan/parallelisms/parallelize_llama.py rename to torchtitan/models/llama/parallelize_llama.py index fda12b53f..27c89feb0 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/models/llama/parallelize_llama.py @@ -33,7 +33,7 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger -from torchtitan.parallelisms.parallel_dims import ParallelDims +from torchtitan.parallelisms import ParallelDims def parallelize_llama( diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py similarity index 83% rename from torchtitan/parallelisms/pipeline_llama.py rename to torchtitan/models/llama/pipeline_llama.py index 6605a57d6..1ede183f9 100644 --- a/torchtitan/parallelisms/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -7,23 +7,26 @@ # This file applies the PT-D pipeline parallelism to the Llama model. import copy -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import _PipelineSchedule + from torchtitan.config_manager import JobConfig from torchtitan.logging import logger -from torchtitan.models.llama.model import ModelArgs -from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.pipelining_utils import ( +from torchtitan.parallelisms import ParallelDims +from torchtitan.parallelisms.pipeline import ( build_pipeline_schedule, generate_split_points, stage_ids_this_rank, ) +from .model import TransformerModelArgs + DeviceType = Union[int, str, torch.device] @@ -34,9 +37,9 @@ def pipeline_llama( parallel_dims: ParallelDims, job_config: JobConfig, device: DeviceType, - model_config: ModelArgs, + model_config: TransformerModelArgs, loss_fn: Callable[..., torch.Tensor], -): +) -> tuple[_PipelineSchedule, list[nn.Module]]: stages, models = pipeline_llama_manual_split( model, pp_mesh, parallel_dims, job_config, device, model_config ) @@ -52,8 +55,8 @@ def pipeline_llama_manual_split( parallel_dims: ParallelDims, job_config: JobConfig, device: DeviceType, - model_config: ModelArgs, -): + model_config: TransformerModelArgs, +) -> tuple[list[PipelineStage], list[nn.Module]]: """ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. 
@@ -67,10 +70,16 @@ def pipeline_llama_manual_split( splits = ( job_config.experimental.pipeline_parallel_split_points - or generate_split_points(job_config, parallel_dims.pp, model_config) + or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers) ) - def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False): + def _build_stage( + stage_idx: int, + start_layer: Optional[str], + stop_layer: Optional[str], + is_first: bool = False, + is_last: bool = False, + ) -> tuple[PipelineStage, nn.Module]: model = copy.deepcopy(whole_model) if not is_first: model.tok_embeddings = None diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 1b724b7a1..e351fd132 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import functools -from typing import Any, Dict, List +from typing import Any, Callable, Dict, Iterable, List import torch import torch.nn as nn @@ -15,35 +16,78 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful -from torch.optim.lr_scheduler import LambdaLR +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler + from torchtitan.config_manager import JobConfig -class OptimizersContainer(Stateful): - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages - and saving/loading optimizer state_dict at checkpoint. +__all__ = [ + "OptimizersContainer", + "LRSchedulersContainer", + "build_optimizers", + "build_lr_schedulers", +] + + +def _create_optimizer( + parameters: Iterable[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str +) -> Optimizer: + if name == "Adam": + return torch.optim.Adam(parameters, **optimizer_kwargs) + elif name == "AdamW": + return torch.optim.AdamW(parameters, **optimizer_kwargs) + else: + raise NotImplementedError(f"Optimizer {name} not added.") + + +class OptimizersContainer(Optimizer): + """A container for multiple optimizers. + + This class is used to wrap multiple optimizers into a single object that can be + used to reduce the complexity of the training loop. This mimics the behavior of + ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``. + + **Note** + Users who want to customize the optimizer behavior can inherit from this class and + extend the functionality as needed. The following methods must follow the same signature + as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``, + ``load_state_dict()``. + + **Limitations** + This class assumes that all the optimizers are the same type and have the same + configurations. With this assumption, TorchTitan can support lr scheduler resharding + (e.g., loading a checkpoint with a different number of GPUs and/or different + parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the + resharding for the optimizer state but not for the lr scheduler state, hence the limitation. + + Args: + model_parts (List[nn.Module]): List of model parts to be optimized. + optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers. + name (str): Name of the optimizers. 
""" + optimizers: List[Optimizer] + model_parts: List[nn.Module] + def __init__( self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str ) -> None: - self.optimizers = [] + all_params = [] + self.optimizers: List[Optimizer] = [] self.model_parts = model_parts for model in self.model_parts: - if name == "Adam": - # TODO: make the optimizer options configurable by toml/cmd args - optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) - elif name == "AdamW": - optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) - else: - raise NotImplementedError(f"Optimizer {name} not added.") - self.optimizers.append(optimizer) + params = [p for p in model.parameters() if p.requires_grad] + self.optimizers.append(_create_optimizer(params, optimizer_kwargs, name)) + all_params.extend(params) self._validate_length(len(self.model_parts)) + self._post_init(all_params, optimizer_kwargs) - def _validate_length(self, expected_length) -> None: - assert expected_length == len( - self.optimizers - ), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer" + def __iter__(self) -> Optimizer: + return iter(self.optimizers) + + def __len__(self) -> int: + return len(self.optimizers) def step(self) -> None: for optimizer in self.optimizers: @@ -72,34 +116,40 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ) list(map(func, self.model_parts, self.optimizers)) + def _validate_length(self, expected_length: int) -> None: + assert expected_length == len( + self.optimizers + ), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer" + + def _post_init( + self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any] + ) -> None: + # We need to call Optimizer.__init__() to initialize some necessary optimizer + # functionality such as hooks. + Optimizer.__init__(self, all_params, optimizer_kwargs) + class OptimizersInBackwardContainer(OptimizersContainer): - """Optimiers in backward to skip .step() and .zero_grad()""" + """OptimizersContainer for executing ``optim.step()`` in backward pass. + + This class extend ``OptimizersContainer`` to support optimizer step in + backward pass. ``step()`` and ``zero_grad()`` are no-op in this class. + Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to + execute these methods when the gradient is accumulated. 
+ """ def __init__( self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str ) -> None: - self.optimizers = [] + all_params = [] self.model_parts = model_parts + optim_dict = {} for model in self.model_parts: - if name == "Adam": - # TODO: make the optimizer options configurable by toml/cmd args - optim_dict.update( - { - param: torch.optim.Adam([param], **optimizer_kwargs) - for param in model.parameters() - } - ) - elif name == "AdamW": - optim_dict.update( - { - param: torch.optim.AdamW([param], **optimizer_kwargs) - for param in model.parameters() - } - ) - else: - raise NotImplementedError(f"Optimizer {name} not added.") + for p in model.parameters(): + if p.requires_grad: + optim_dict[p] = _create_optimizer([p], optimizer_kwargs, name) + all_params.append(p) def optim_hook(param) -> None: optim_dict[param].step() @@ -110,7 +160,7 @@ def optim_hook(param) -> None: if param.requires_grad: param.register_post_accumulate_grad_hook(optim_hook) - self.optimizers.extend([optim_dict[param] for param in model.parameters()]) + self.optimizers = list(optim_dict.values()) self._validate_length( sum( @@ -118,6 +168,7 @@ def optim_hook(param) -> None: for model in self.model_parts ) ) + self._post_init(all_params, optimizer_kwargs) def step(self) -> None: pass @@ -126,12 +177,25 @@ def zero_grad(self) -> None: pass -# consider split between PP and non-PP def build_optimizers( model_parts: List[nn.Module], job_config: JobConfig ) -> OptimizersContainer: - """Wrap one optimizer per model part in an OptimizersContainer which provides a single - step() and zero_grad() method for all the child optimizers. + """Create a OptimizersContainer for the given model parts and job config. + + This function creates a ``OptimizersContainer`` for the given model parts. + ``job_config`` should define the correct optimizer name and parameters. + This function currently supports creating ``OptimizersContainer`` and + ``OptimizersInBackwardContainer``. + + **Note** + Users who want to customize the optimizer behavior can create their own + ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the + customized ``build_optimizers`` to ``TrainSpec`` will create the customized + ``OptimizersContainer``. + + Args: + model_parts (List[nn.Module]): List of model parts to be optimized. + job_config (JobConfig): Job config containing the optimizer name and parameters. """ optim_in_bwd = job_config.optimizer.early_step_in_backward if optim_in_bwd and job_config.experimental.pipeline_parallel_degree > 1: @@ -156,60 +220,108 @@ def build_optimizers( ) -def linear_warmup_linear_decay( - warmup_steps: int, decay_steps: int, current_step: int -) -> float: - """Computes linear warmup followed by linear decay. - Per LambdaLR requirement, this is accomplished by returning - a multiplicative factor to adjust the learning rate to - create the desired schedule. +class LRSchedulersContainer(Stateful): + """Container for multiple learning rate schedulers. + + This class is used to wrap multiple LRSchedulers into a single object that can be + used to reduce the complexity of the training loop. This mimics the behavior of + ``torch.optim.lr_scheduler.LRScheduler``. The design concept is the same as + ``OptimizersContainer``. This class currently only supports ``LambdaLR``. + + **Note** + Users who want to customize the lr_scheduler behavior can inherit from this class and + extend the functionality as needed. 
The following methods must follow the same + signature as ``torch.optim.lr_scheduler.LRScheduler`` class: ``step()``, ``state_dict()``, + ``load_state_dict()``. + + **Limitations** + This class assumes all the lr schedulers are the same. There is no easy way to support + resharding for multiple different LRSchedulers because LRScheduler.state_dict() is not + resharding friendly. Therefore, the limitation is used to allow TorchTitan to support + lr scheduler resharding. + + Args: + optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers. """ - if current_step < warmup_steps: - # linear warmup - # 0-indexed step, hence + 1 adjustments - current_step += 1 - curr_adjustment = float(current_step / (warmup_steps + 1)) - else: - # linear decay - normalized_step = decay_steps - (current_step - warmup_steps) - curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps + schedulers: List[LRScheduler] - return curr_adjustment + def __init__(self, optimizers: OptimizersContainer, lr_lambda: Callable) -> None: + assert ( + len(optimizers) > 0 + ), "Must have at least one optimizer to create LRScheduler" + self.schedulers = [LambdaLR(optimizer, lr_lambda) for optimizer in optimizers] -class SchedulersContainer(Stateful): - """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" + def __iter__(self) -> LRScheduler: + return iter(self.schedulers) - def __init__(self, optimizers, lr_lambda) -> None: - self.schedulers = [] - for optimizer in optimizers: - self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) + def __len__(self) -> int: + return len(self.schedulers) def step(self) -> None: for scheduler in self.schedulers: scheduler.step() def state_dict(self) -> Dict[str, Any]: - # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward, - # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all. - # Therefore, we only save the first one and later load it for all. - assert ( - len(self.schedulers) > 0 - ), "Must have at least one scheduler to save state_dict" + # While there may be multiple schedulers, we only save the first one because + # the state_dict is the same for all. See the limitations section in the + # docstring. return self.schedulers[0].state_dict() def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`, - # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain - # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety. + # Load the same state_dict for all schedulers. The key value we're concerned + # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer + # that is immutable. As long as ``training.steps`` and ``training.warmup_steps`` + # in ``job_config`` remain unchanged when resuming from a checkpoint, this + # approach is safe. We call ``copy()`` here to ensure extra safety. for scheduler in self.schedulers: - scheduler.load_state_dict(state_dict.copy()) + scheduler.load_state_dict(copy.deepcopy(state_dict)) + +def build_lr_schedulers( + optimizers: OptimizersContainer, job_config: JobConfig +) -> LRSchedulersContainer: + """Create a LRSchedulerContainer for the given optimizers and job config. 
-def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: + This function creates a ``LRSchedulersContainer`` for the given optimizers. + ``job_config`` should define the correct lr scheduler parameters. + + **Note** + Users who want to customize the lr scheduler behavior can create their own + ``LRSchedulersContainer`` subclass and ``build_lr_scheduler``. Passing the + customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized + ``LRSchedulersContainer``. + + + Args: + optimizers (OptimizersContainer): The corresponding optimizers for the + lr_schedulers. + """ warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) - lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) - return SchedulersContainer(optimizers, lr_lambda) + def linear_warmup_linear_decay( + warmup_steps: int, decay_steps: int, current_step: int + ) -> float: + """Computes linear warmup followed by linear decay. + + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor to adjust the learning rate to + create the desired schedule. + """ + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + + else: + # linear decay + normalized_step = decay_steps - (current_step - warmup_steps) + curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps + + return curr_adjustment + + lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) + return LRSchedulersContainer(optimizers, lr_lambda) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index f1f1d1fba..1a187282e 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -6,19 +6,6 @@ from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.parallelize_llama import parallelize_llama -from torchtitan.parallelisms.pipeline_llama import pipeline_llama -__all__ = [ - "models_parallelize_fns", - "models_pipelining_fns", - "ParallelDims", -] - -models_parallelize_fns = { - "llama3": parallelize_llama, -} -models_pipelining_fns = { - "llama3": pipeline_llama, -} +__all__ = ["ParallelDims"] diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 13d066a84..f5e6a0e4c 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -8,9 +8,13 @@ from functools import cached_property from torch.distributed.device_mesh import init_device_mesh + from torchtitan.logging import logger +__all__ = ["ParallelDims"] + + @dataclass class ParallelDims: dp_replicate: int diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipeline.py similarity index 74% rename from torchtitan/parallelisms/pipelining_utils.py rename to torchtitan/parallelisms/pipeline.py index 7b2994f80..aa47189c7 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipeline.py @@ -4,25 +4,40 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import os -from typing import List, Tuple +from typing import Callable from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, _PipelineScheduleRuntime, get_schedule_class, PipelineScheduleMulti, PipelineScheduleSingle, ) +from torch.distributed.pipelining.stage import PipelineStage + from torchtitan.config_manager import JobConfig from torchtitan.logging import logger -from torchtitan.models.llama.model import ModelArgs +__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] + + +# TODO: It's unclear if this API is general enough to be used by other models. +# If not, we should move it to a Transformer-specific directory. def generate_split_points( - job_config: JobConfig, pp_dim: int, model_config: ModelArgs -) -> List[str]: + job_config: JobConfig, pp_dim: int, num_layers: int +) -> list[str]: """ Generate a default split point based on the number of layers and pipeline parallel dimension. + + Args: + job_config (JobConfig): The job configuration. + pp_dim (int): The pipeline parallel dimension. + num_layers (int): The number of layers in the model. + + Returns: + list[str]: A list of split point FQNs. """ schedule_class = get_schedule_class( @@ -39,7 +54,6 @@ def generate_split_points( f"Unsupported pipeline schedule: {job_config.experimental.pipeline_parallel_schedule}" ) total_stages = pp_dim * num_stages_per_rank - num_layers = model_config.n_layers if total_stages > num_layers: raise ValueError("Total stages cannot be greater than the number of layers") @@ -60,13 +74,25 @@ def generate_split_points( current_layer += base_interval splits.append("layers." + str(current_layer)) logger.info( - f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} \ -This may be sub-optimal as the number of layers per stage may be unbalanced." + f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} " + "This may be sub-optimal as the number of layers per stage may be unbalanced." ) return splits -def build_pipeline_schedule(job_config, stages, loss_fn): +def build_pipeline_schedule( + job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable +) -> _PipelineSchedule: + """Builds a pipeline schedule for the given job configuration and stages. + + Args: + job_config (JobConfig): The job configuration. + stages (list[PipelineStage]): The stages to be scheduled. + loss_fn (Callable): The loss function. + + Returns: + _PipelineSchedule: The pipeline schedule for the given stages. + """ pp_schedule_csv = job_config.experimental.pipeline_parallel_schedule_csv # Validate that pp_schedule_csv is a valid path @@ -89,8 +115,8 @@ def build_pipeline_schedule(job_config, stages, loss_fn): n_microbatches = num_total_stages elif n_microbatches < num_total_stages: logger.warning( - f"Number of microbatches ({n_microbatches}) is less than the total number \ -of stages ({num_total_stages}) which may result in a bubble in the pipeline." + f"Number of microbatches ({n_microbatches}) is less than the total number " + f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." ) # validate that the batch size is divisible by the number of microbatches otherwise we'll hang or error during training @@ -106,8 +132,8 @@ def build_pipeline_schedule(job_config, stages, loss_fn): loss_fn=loss_fn, ) logger.info( - f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \ -with {n_microbatches} microbatches and {num_total_stages} stages." 
+ f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} " + f"with {n_microbatches} microbatches and {num_total_stages} stages." ) if pp_schedule_csv: @@ -115,8 +141,10 @@ def build_pipeline_schedule(job_config, stages, loss_fn): PipelineScheduleSingle, PipelineScheduleMulti, _PipelineScheduleRuntime, - ], "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), \ - and _PipelineScheduleRuntime support csv schedules" + ], ( + "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " + "and _PipelineScheduleRuntime support csv schedules" + ) schedule._load_csv(pp_schedule_csv) return schedule @@ -125,7 +153,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn): # TODO(whc) should this be a utility inside torch.pipelining? def stage_ids_this_rank( pp_rank: int, pp_size: int, num_stages: int, style: str = "loop" -) -> Tuple[int]: +) -> tuple[int]: """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule""" assert ( num_stages % pp_size == 0 diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py deleted file mode 100644 index a84af7981..000000000 --- a/torchtitan/parallelisms/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import Optional - -import torch -from torchtitan.logging import logger - - -def check_if_feature_in_pytorch( - feature_name: str, - pull_request: str, - min_nightly_version: Optional[str] = None, -) -> None: - if "git" in torch.__version__: # pytorch is built from source - # notify users to check if the pull request is included in their pytorch - logger.warning( - "detected that the pytorch is built from source. Please make sure the PR " - f"({pull_request_link}) is included in pytorch for correct {feature_name}." - ) - elif min_nightly_version is not None and torch.__version__ < min_nightly_version: - logger.warning( - f"detected that the pytorch version {torch.__version__} is older than " - f"{min_nightly_version}. Please upgrade a newer version to include the " - f"change in ({pull_request_link}) for correct {feature_name}." - ) diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py new file mode 100644 index 000000000..f199d2e03 --- /dev/null +++ b/torchtitan/train_spec.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass +from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias + +import torch.nn as nn +from torch.distributed.pipelining.schedules import _PipelineSchedule + +from torchtitan.config_manager import JobConfig +from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer + + +@dataclass +class BaseModelArgs: + """All ModelArgs should inherit from this class. + + The only usage of this class is type checking but allows us to extend common + arguments to all models in the future. + """ + + _enforced: str = "This field is used to enforce all fields have defaults." + + +class ModelProtocol(Protocol): + """Defines the interface for a model class. 
+ + This is used to enforce that all model classes have some methods that are + required by the TorchTitan trainer. + """ + + @staticmethod + def from_model_args(args: BaseModelArgs) -> nn.Module: + ... + + +OptimizersBuilder: TypeAlias = Callable[ + [List[nn.Module], JobConfig], OptimizersContainer +] +OptimizerBuilderWrapper: TypeAlias = Callable[ + [List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer +] +LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer] + + +@dataclass +class TrainSpec: + name: str + cls: Type[nn.Module] + config: Dict[str, BaseModelArgs] + parallelize_fn: Callable[[nn.Module], None] + pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]] + build_optimizers_fn: OptimizersBuilder + build_lr_schedulers_fn: LRSchedulersBuilder + + # TODO: Add a ``build_dataloader_fn`` + + # TODO: Add a FQN convert fn to allow users to load checkpoints from + # HuggingFace or other sources that have different FQN conventions. + + +_train_specs = {} + + +def register_train_spec(train_spec: TrainSpec) -> None: + global _train_specs + if train_spec.name in _train_specs: + raise ValueError(f"Model {train_spec.name} is already registered.") + + _train_specs[train_spec.name] = train_spec + + +def get_train_spec(name: str) -> TrainSpec: + global _train_specs + if name not in _train_specs: + raise ValueError(f"Model {name} is not registered.") + return _train_specs[name] + + +def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None: + global _train_specs + for name, train_spec in _train_specs.items(): + _train_specs[name] = func(train_spec) diff --git a/torchtitan/utils.py b/torchtitan/utils.py index c9dcf2fac..122a406f6 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -6,9 +6,11 @@ import contextlib import gc +import importlib import math import os import subprocess +import sys from dataclasses import dataclass from datetime import timedelta from typing import Generator, Iterable, List, Optional, Set, Union @@ -20,6 +22,7 @@ from torch._utils import _get_available_device_type, _get_device_module from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor + from torchtitan.logging import logger @@ -411,3 +414,55 @@ def clip_grad_norm_( torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm + + +def check_if_feature_in_pytorch( + feature_name: str, + pull_request: str, + min_nightly_version: Optional[str] = None, +) -> None: + if "git" in torch.__version__: # pytorch is built from source + # notify users to check if the pull request is included in their pytorch + logger.warning( + "detected that the pytorch is built from source. Please make sure the PR " + f"({pull_request_link}) is included in pytorch for correct {feature_name}." + ) + elif min_nightly_version is not None and torch.__version__ < min_nightly_version: + logger.warning( + f"detected that the pytorch version {torch.__version__} is older than " + f"{min_nightly_version}. Please upgrade a newer version to include the " + f"change in ({pull_request_link}) for correct {feature_name}." + ) + + +def import_module_from_path(path: str): + path = os.path.expanduser(path) + + # 1. Check if path is an existing file or directory path. 
+ if os.path.exists(path): + if not os.path.isdir(path): + raise ImportError(f"Path '{path}' is not a directory.") + init_file = os.path.join(path, "__init__.py") + if os.path.isfile(init_file): + return _import_module_from_init(path) + + raise ImportError( + f"Directory '{path}' is not a Python package because it does not " + "contain an __init__.py file." + ) + + # 2. If not a valid path, assume it's a dotted module name. + return importlib.import_module(path) + + +def _import_module_from_init(path: str): + init_file = os.path.join(path, "__init__.py") + module_name = os.path.basename(path) + spec = importlib.util.spec_from_file_location(module_name, init_file) + if spec is None: + raise ImportError(f"Could not create spec from '{init_file}'") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module diff --git a/train.py b/train.py index 761393f7f..a278a1435 100644 --- a/train.py +++ b/train.py @@ -19,15 +19,11 @@ from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.parallelisms import ( - models_parallelize_fns, - models_pipelining_fns, - ParallelDims, -) +from torchtitan.models import model_name_to_tokenizer +from torchtitan.parallelisms import ParallelDims from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -from torchtitan.utils import device_module, device_type +from torchtitan.train_spec import get_train_spec +from torchtitan.utils import device_module, device_type, import_module_from_path # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -36,6 +32,9 @@ def main(job_config: JobConfig): init_logger() logger.info(f"Starting job: {job_config.job.description}") + if job_config.experimental.custom_model_path: + import_module_from_path(job_config.experimental.custom_model_path) + if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") @@ -79,10 +78,10 @@ def main(job_config: JobConfig): utils.set_determinism( world_mesh, device, job_config.training.seed, job_config.training.deterministic ) - model_name = job_config.model.name + train_spec = get_train_spec(job_config.model.name) # build tokenizer - tokenizer_type = model_name_to_tokenizer[model_name] + tokenizer_type = model_name_to_tokenizer[train_spec.name] tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -96,8 +95,8 @@ def main(job_config: JobConfig): ) # build model (using meta init) - model_cls = model_name_to_cls[model_name] - model_config = models_config[model_name][job_config.model.flavor] + model_cls = train_spec.cls + model_config = train_spec.config[job_config.model.flavor] # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. 
vocab size from tokenizer @@ -106,7 +105,9 @@ def main(job_config: JobConfig): model_config.vocab_size = tokenizer.n_words model_config.max_seq_len = job_config.training.seq_len - logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}") + logger.info( + f"Building {train_spec.name} {job_config.model.flavor} with {model_config}" + ) with torch.device("meta"): model = model_cls.from_model_args(model_config) @@ -123,7 +124,7 @@ def main(job_config: JobConfig): job_config.training.seq_len, ) logger.info( - f"{color.blue}Model {model_name} {job_config.model.flavor} " + f"{color.blue}Model {train_spec.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) @@ -151,7 +152,7 @@ def loss_fn(pred, labels): # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel - pp_schedule, model_parts = models_pipelining_fns[model_name]( + pp_schedule, model_parts = train_spec.pipelining_fn( model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead @@ -162,14 +163,14 @@ def loss_fn(pred, labels): # optimizer, and checkpointing for m in model_parts: # apply SPMD-style PT-D techniques - models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.to_empty(device=init_device) with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) @@ -185,8 +186,8 @@ def loss_fn(pred, labels): ) # build optimizer after applying parallelisms to the model - optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) + optimizers = train_spec.build_optimizers_fn(model_parts, job_config) + lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config) train_state = TrainState()
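
For illustration, a minimal sketch of how an out-of-tree model could plug into the TrainSpec
registry added by this diff. The package layout and the names MyTransformer,
MyTransformerModelArgs, parallelize_my_transformer, and pipeline_my_transformer are
hypothetical; TrainSpec, register_train_spec, build_optimizers, and build_lr_schedulers are
the APIs introduced or reused above. Placed in my_models/model_x/__init__.py, the
registration runs when the package is imported via --experimental.custom_model_path:

    # my_models/model_x/__init__.py (hypothetical out-of-tree package)
    from torchtitan.optimizer import build_lr_schedulers, build_optimizers
    from torchtitan.train_spec import register_train_spec, TrainSpec

    from .model import MyTransformer, MyTransformerModelArgs  # must implement from_model_args()
    from .parallelize import parallelize_my_transformer       # hypothetical parallelize_fn
    from .pipeline import pipeline_my_transformer             # hypothetical pipelining_fn

    my_transformer_configs = {
        "debugmodel": MyTransformerModelArgs(dim=256, n_layers=8, n_heads=16),
    }

    register_train_spec(
        TrainSpec(
            name="model_x",
            cls=MyTransformer,
            config=my_transformer_configs,
            parallelize_fn=parallelize_my_transformer,
            pipelining_fn=pipeline_my_transformer,
            # Reuse the default builders; a custom spec may substitute its own.
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
    )

Launching train.py with --model.name model_x --experimental.custom_model_path my_models/model_x
would import the package through import_module_from_path() before get_train_spec() runs, so the
spec above becomes visible to the trainer. Note that model_name_to_tokenizer in
torchtitan/models/__init__.py still only maps "llama3", so an out-of-tree model would also need
a tokenizer entry (or equivalent handling) to run end to end.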