Allow users to use a customized model
**What does this PR do?**
1. This PR introduces ModelSpec to describe a model and how to parallelize it.
2. All models should define `build_model_spec()` or `model_spec` to be
   imported by the `model` module.
3. `build_model_specs()` is called in the trainer to get the `model_specs`, and the result is used to look up the corresponding model spec.
4. Users can also use `--experimental.model_module_path` to dynamically import a model that is not implemented by TorchTitan.

**Why do we need this PR?**
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code. A sketch of what that can look like follows.
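As a rough illustration (not part of this PR's diff), a user-defined module could register its own spec with the same APIs this commit adds. The module path and the `ModelX`/`ModelXArgs` names are hypothetical; `register_train_spec`, `TrainSpec`, `ModelProtocol`, `BaseModelArgs`, and the llama parallelize/pipelining and optimizer helpers are taken from the files changed below.

```python
# my_models/model_x.py -- hypothetical user module; only the torchtitan imports
# are real, everything named "ModelX" is illustrative.
from dataclasses import dataclass

import torch
import torch.nn as nn

from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import (
    BaseModelArgs,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


@dataclass
class ModelXArgs(BaseModelArgs):
    dim: int = 256


class ModelX(nn.Module, ModelProtocol):
    def __init__(self, args: ModelXArgs):
        super().__init__()
        self.proj = nn.Linear(args.dim, args.dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(tokens)

    @classmethod
    def from_model_args(cls, args: ModelXArgs) -> "ModelX":
        return cls(args)


# Registering at import time is what makes the model discoverable by name;
# the llama parallelize/pipelining functions are reused here purely for brevity.
register_train_spec(
    TrainSpec(
        name="model_x",
        cls=ModelX,
        config={"debugmodel": ModelXArgs()},
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```

Pointing the run script at such a module via the experimental flag (added as `--experimental.custom_model_path` in `config_manager.py` below) together with `--model.name model_x` should then let the trainer find the registered spec; the trainer-side lookup lives in `train.py`, which is collapsed in this view.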

**Next steps**
1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and
   pipelining_fn. We may also want to extend ModelSpec to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize_fn and pipelining_fn import ModelArgs, which can cause circular imports.
   We should fix this issue.

ghstack-source-id: b1d6d90da98b92c601de46dba5d3b3b98a34e687
Pull Request resolved: #814
fegin committed Feb 11, 2025
1 parent 5940dde commit 57a8ba9
Showing 17 changed files with 624 additions and 195 deletions.
122 changes: 122 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.train_spec import (
    apply_to_train_specs,
    BaseModelArgs,
    get_train_spec,
    ModelProtocol,
    TrainSpec,
    register_train_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestTrainSpec:
    def test_register_train_spec(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake")
        assert new_spec == spec

        with pytest.raises(ValueError):
            new_spec = get_train_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
            # Create a closure to capture the original spec.build_optimizers_fn
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn

        apply_to_train_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        assert not hook_called
        optimizers.step()
        assert hook_called
11 changes: 11 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models # noqa: F401
5 changes: 3 additions & 2 deletions torchtitan/checkpoint.py
@@ -26,9 +26,10 @@
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +141,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module and import the TorchTitan training loop and execute it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --custom_model_path option allows you to specify a custom path to a model module
that is not natively implemented within TorchTitan.
Acceptable values are the file system path to the module (e.g., my_models/model_x)
or the dotted import module path (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
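The code that consumes `--experimental.custom_model_path` is not visible in this view (`train.py` is among the files that did not load below). The following is only a minimal sketch of how such a path could be resolved with the standard library; the helper name and the branching between a file path and a dotted module are assumptions, not the implementation in this PR.

```python
import importlib
import importlib.util
import os
import sys


def import_custom_model_module(path: str) -> None:
    """Illustrative helper (not TorchTitan's actual code): import a user module
    so that its module-level register_train_spec() call runs before the trainer
    looks up the spec by name."""
    if not path:
        return
    candidate = path if path.endswith(".py") else path + ".py"
    if os.path.isfile(candidate):
        # File-system path, e.g. "my_models/model_x"
        name = os.path.splitext(os.path.basename(candidate))[0]
        spec = importlib.util.spec_from_file_location(name, candidate)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        spec.loader.exec_module(module)
    else:
        # Dotted module path, e.g. "some_package.model_x"
        importlib.import_module(path)
```

The key point is that importing the module is enough: registration happens as a side effect of the module-level `register_train_spec(...)` call, exactly as in the llama and test modules in this diff.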
12 changes: 4 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
model_name_to_tokenizer = {"llama3": "tiktoken"}
39 changes: 33 additions & 6 deletions torchtitan/models/llama/__init__.py
@@ -6,13 +6,27 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
"debugmodel": TransformerModelArgs(
dim=256, n_layers=8, n_heads=16, rope_theta=500000
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
@@ -21,7 +35,7 @@
multiple_of=1024,
rope_theta=500000,
),
"70B": ModelArgs(
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
@@ -30,7 +44,7 @@
multiple_of=4096,
rope_theta=500000,
),
"405B": ModelArgs(
"405B": TransformerModelArgs(
dim=16384,
n_layers=126,
n_heads=128,
@@ -40,3 +54,16 @@
rope_theta=500000,
),
}


register_train_spec(
TrainSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
)
25 changes: 13 additions & 12 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class TransformerModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -130,7 +131,7 @@ class Attention(nn.Module):
Multi-head attention module.
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
n_kv_heads (int): Number of key and value heads.
@@ -144,7 +145,7 @@ class Attention(nn.Module):
"""

def __init__(self, model_args: ModelArgs):
def __init__(self, model_args: TransformerModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.n_kv_heads = (
@@ -264,7 +265,7 @@ class TransformerBlock(nn.Module):
Args:
layer_id (int): Identifier for the layer.
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
n_heads (int): Number of attention heads.
@@ -278,7 +279,7 @@ class TransformerBlock(nn.Module):
"""

def __init__(self, layer_id: int, model_args: ModelArgs):
def __init__(self, layer_id: int, model_args: TransformerModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.dim = model_args.dim
@@ -331,15 +332,15 @@ def init_weights(self):
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
class Transformer(nn.Module, ModelProtocol):
"""
Transformer Module
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
vocab_size (int): Vocabulary size.
n_layers (int): Number of layers in the model.
tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -350,7 +351,7 @@ class Transformer(nn.Module):
"""

def __init__(self, model_args: ModelArgs):
def __init__(self, model_args: TransformerModelArgs):
super().__init__()
self.model_args = model_args
self.vocab_size = model_args.vocab_size
@@ -446,12 +447,12 @@ def forward(self, tokens: torch.Tensor):
return output

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
"""
Initialize a Transformer model from a ModelArgs object.
Initialize a Transformer model from a TransformerModelArgs object.
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Returns:
Transformer: Transformer model.
(Next changed file; the file header was not captured in this view.)
@@ -33,7 +33,7 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms import ParallelDims


def parallelize_llama(
(The remaining changed files were not loaded in this view.)
