[Reland] Add Dynamic Model Import and ModelSpec Definition (#837)
`ghstack` didn't land #814 correctly, so this PR relands it. For the detailed discussion, please refer to #814.

**What does this PR do?**
1. This PR introduces `ModelSpec` to describe a model and how to parallelize it.
    * All models should call `register_model_spec()`.
    * Users can also use `--experimental.custom_model_path` to dynamically import a model that is not implemented by TorchTitan. That module should also call `register_model_spec()`. (A sketch of such a custom module follows this list.)
2. This PR also refactors `OptimizersContainer` and `LRSchedulersContainer`:
    * Fixes an issue where optimizers would accept parameters whose `requires_grad` is False.
    * Improves typing and docstrings.
    * Improves function and class reusability.
    * `OptimizersContainer` now inherits from `torch.optim.Optimizer`.
3. This PR also moves `parallelize_llama` and `pipeline_llama` to the `llama` folder.
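For illustration, here is a minimal sketch of a custom model module that could be loaded via `--experimental.custom_model_path`. It mirrors the llama3 `TrainSpec` registration added in this commit; the module path, class, and flavor names (`my_models/model_x`, `ModelX`, `debugmodel`) are hypothetical, and it simply reuses the llama parallelization and optimizer builders for brevity.

```python
# my_models/model_x/__init__.py -- hypothetical custom model module.
# Importing it (via --experimental.custom_model_path) runs register_train_spec().
import torch.nn as nn

from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import (
    BaseModelArgs,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


class ModelX(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        # Placeholder architecture; a real model would consume `args`.
        return nn.Linear(8, 8)


register_train_spec(
    TrainSpec(
        name="model_x",
        cls=ModelX,
        config={"debugmodel": BaseModelArgs()},
        parallelize_fn=parallelize_llama,  # reuse llama's parallelization for the sketch
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```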

**Why do we need this PR?**
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**
1. Dataloader customization is not included yet.
2. Checkpoint customization is not included yet.
fegin authored Feb 12, 2025
1 parent 3996b63 commit fb0a942
Showing 19 changed files with 643 additions and 204 deletions.
15 changes: 9 additions & 6 deletions scripts/estimate/estimation.py
@@ -19,9 +19,10 @@
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import model_name_to_tokenizer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec


def estimate_memory(job_config: JobConfig):
@@ -74,6 +75,8 @@ def estimate_memory(job_config: JobConfig):
"fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
)

train_spec = get_train_spec(job_config.model.name)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

@@ -95,8 +98,8 @@ def loss_fn(pred, labels):
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_cls = train_spec.cls
model_config = train_spec.config[job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
@@ -112,7 +115,7 @@ def loss_fn(pred, labels):
):

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)
@@ -123,7 +126,7 @@ def loss_fn(pred, labels):
float8_handler.convert_to_float8_training(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)

model.to_empty(device="cuda")
if not active_fake_mode():
13 changes: 7 additions & 6 deletions scripts/generate/test_generate.py
@@ -26,13 +26,14 @@
)

from torchtitan import utils

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import model_name_to_tokenizer
from torchtitan.parallelisms import ParallelDims

from torchtitan.train_spec import get_train_spec
from torchtitan.utils import device_module, device_type

# support running w/o installing as package
@@ -102,21 +103,21 @@ def test_generate(
device_module.set_device(device)
device_memory_monitor = build_device_memory_monitor()

model_name = config.model.name
train_spec = get_train_spec(config.model.name)

logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

# Tokenizer setup
tokenizer = build_tokenizer(
model_name_to_tokenizer[model_name], config.model.tokenizer_path
model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
)

model_config = models_config[model_name][config.model.flavor]
model_config = train_spec.config[config.model.flavor]
model_config.norm_type = config.model.norm_type
model_config.max_seq_len = config.training.seq_len
model_config.vocab_size = tokenizer.n_words

model_cls = model_name_to_cls[model_name]
model_cls = train_spec.cls
init_device = "meta" if world_size > 1 else device
with torch.device(init_device):
logger.info(f"Init model on init_device: {init_device}")
122 changes: 122 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)
from torchtitan.train_spec import (
    apply_to_train_specs,
    BaseModelArgs,
    get_train_spec,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestTrainSpec:
    def test_register_train_spec(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake")
        assert new_spec == spec

        with pytest.raises(ValueError):
            new_spec = get_train_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
            # Create a closure to capture the original spec.build_optimizers_fn
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn

        apply_to_train_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        assert not hook_called
        optimizers.step()
        assert hook_called
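Since `OptimizersContainer` now inherits from `torch.optim.Optimizer` (point 2 of the PR description), the container can be driven like a single optimizer across all model parts. A minimal sketch, reusing the constructor arguments from `fake_build_optimizers` above (with `fused` turned off here to keep the sketch device-agnostic):

```python
import torch
import torch.nn as nn

from torchtitan.optimizer import OptimizersContainer

model_parts = [nn.Linear(8, 8)]
optimizers = OptimizersContainer(
    model_parts=model_parts,
    optimizer_kwargs={
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": False,  # disabled in this sketch so it runs anywhere
        "foreach": False,
    },
    name="Adam",
)

model_parts[0](torch.randn(4, 8)).sum().backward()
optimizers.step()       # steps the wrapped optimizer for every model part
optimizers.zero_grad()  # zeroes gradients across all model parts
```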
11 changes: 11 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models # noqa: F401
5 changes: 3 additions & 2 deletions torchtitan/checkpoint.py
@@ -26,9 +26,10 @@
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +141,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module that imports the TorchTitan training loop and executes it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --experimental.custom_model_path option allows users to specify a custom path to a model
module that is not natively implemented within TorchTitan.
Acceptable values are either a file system path to the module (e.g., my_models/model_x)
or a dotted import path (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
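The hunk above only adds the flag. As a hedged illustration (not necessarily the exact logic used by the training script), importing the custom module, whether it is given as a file-system path or as a dotted module path, is all that is needed to trigger its registration call:

```python
import importlib
import importlib.util
import os


def maybe_import_custom_model(path: str) -> None:
    """Illustrative helper (not TorchTitan's actual implementation): import a
    custom model module so that its register_model_spec() call runs."""
    if not path:
        return
    if os.path.exists(path):
        # File-system path, e.g. my_models/model_x
        init_file = os.path.join(path, "__init__.py")
        spec = importlib.util.spec_from_file_location("custom_model", init_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        # Dotted import path, e.g. some_package.model_x
        importlib.import_module(path)
```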
12 changes: 4 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
model_name_to_tokenizer = {"llama3": "tiktoken"}
39 changes: 33 additions & 6 deletions torchtitan/models/llama/__init__.py
@@ -6,13 +6,27 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
"debugmodel": TransformerModelArgs(
dim=256, n_layers=8, n_heads=16, rope_theta=500000
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
@@ -21,7 +35,7 @@
multiple_of=1024,
rope_theta=500000,
),
"70B": ModelArgs(
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
@@ -30,7 +44,7 @@
multiple_of=4096,
rope_theta=500000,
),
"405B": ModelArgs(
"405B": TransformerModelArgs(
dim=16384,
n_layers=126,
n_heads=128,
@@ -40,3 +54,16 @@
rope_theta=500000,
),
}


register_train_spec(
    TrainSpec(
        name="llama3",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
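Once registered, the spec is retrieved by name, which is how the updated `scripts/estimate/estimation.py` and `scripts/generate/test_generate.py` above consume it. A minimal sketch using the `debugmodel` flavor defined above (the vocab size here is an illustrative stand-in for the value normally taken from the tokenizer):

```python
import torch

import torchtitan  # importing the package registers the built-in llama3 spec
from torchtitan.train_spec import get_train_spec

train_spec = get_train_spec("llama3")
model_config = train_spec.config["debugmodel"]
model_config.vocab_size = 2048  # normally tokenizer.n_words, as in the scripts above
with torch.device("meta"):
    model = train_spec.cls.from_model_args(model_config)
```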