
[Reland] Add Dynamic Model Import and ModelSpec Definition #837

Merged: 24 commits, Feb 12, 2025
15 changes: 9 additions & 6 deletions scripts/estimate/estimation.py
@@ -19,9 +19,10 @@
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import model_name_to_tokenizer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec


def estimate_memory(job_config: JobConfig):
@@ -74,6 +75,8 @@ def estimate_memory(job_config: JobConfig):
"fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
)

train_spec = get_train_spec(job_config.model.name)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

@@ -95,8 +98,8 @@ def loss_fn(pred, labels):
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_cls = train_spec.cls
model_config = train_spec.config[job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
@@ -112,7 +115,7 @@ def loss_fn(pred, labels):
):

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)
@@ -123,7 +126,7 @@ def loss_fn(pred, labels):
float8_handler.convert_to_float8_training(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)

model.to_empty(device="cuda")
if not active_fake_mode():
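The scripts in this diff only consume fields of the returned train spec; the definition itself lives in torchtitan/train_spec.py, which is not part of this excerpt. Below is a minimal sketch of the structure implied by its usage here. The field names are taken from the code in this PR; the type annotations are assumptions.

```python
# Sketch only: reconstructed from how TrainSpec is used in this diff; the real
# definition in torchtitan/train_spec.py may differ in detail.
from dataclasses import dataclass
from typing import Callable, Dict, Type

import torch.nn as nn


@dataclass
class TrainSpec:
    name: str                         # registry key, e.g. "llama3"
    cls: Type[nn.Module]              # model class exposing from_model_args()
    config: Dict[str, object]         # flavor name -> model args, e.g. {"8B": ...}
    parallelize_fn: Callable          # applies DP/TP and activation checkpointing
    pipelining_fn: Callable           # splits the model into pipeline stages
    build_optimizers_fn: Callable     # returns an OptimizersContainer
    build_lr_schedulers_fn: Callable  # returns an LRSchedulersContainer
```

Per the unit tests added below, get_train_spec(name) looks the spec up in a registry populated by register_train_spec and raises ValueError for names that were never registered.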
13 changes: 7 additions & 6 deletions scripts/generate/test_generate.py
@@ -26,13 +26,14 @@
)

from torchtitan import utils

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import model_name_to_tokenizer
from torchtitan.parallelisms import ParallelDims

from torchtitan.train_spec import get_train_spec
from torchtitan.utils import device_module, device_type

# support running w/o installing as package
@@ -102,21 +103,21 @@ def test_generate(
device_module.set_device(device)
device_memory_monitor = build_device_memory_monitor()

model_name = config.model.name
train_spec = get_train_spec(config.model.name)

logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

# Tokenizer setup
tokenizer = build_tokenizer(
model_name_to_tokenizer[model_name], config.model.tokenizer_path
model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
)

model_config = models_config[model_name][config.model.flavor]
model_config = train_spec.config[config.model.flavor]
model_config.norm_type = config.model.norm_type
model_config.max_seq_len = config.training.seq_len
model_config.vocab_size = tokenizer.n_words

model_cls = model_name_to_cls[model_name]
model_cls = train_spec.cls
init_device = "meta" if world_size > 1 else device
with torch.device(init_device):
logger.info(f"Init model on init_device: {init_device}")
122 changes: 122 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
build_optimizers,
OptimizersContainer,
)
from torchtitan.train_spec import (
apply_to_train_specs,
BaseModelArgs,
get_train_spec,
ModelProtocol,
register_train_spec,
TrainSpec,
)


class FakeModel(ModelProtocol):
@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module:
return nn.Linear(8, 8)


def fake_build_optimizers(
model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
optimizer_kwargs = {
"lr": 0.1,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": True,
"foreach": False,
}
return OptimizersContainer(
model_parts=model_parts,
optimizer_kwargs=optimizer_kwargs,
name="Adam",
)


class TestTrainSpec:
def test_register_train_spec(self):
fake_config = {"fake": None}
spec = TrainSpec(
name="fake",
cls=FakeModel,
config=fake_config,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_train_spec(spec)
new_spec = get_train_spec("fake")
assert new_spec == spec

with pytest.raises(ValueError):
new_spec = get_train_spec("fake2")

def test_optim_hook(self):
fake_config = {"fake": None}
spec = TrainSpec(
name="fake2",
cls=FakeModel,
config=fake_config,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_train_spec(spec)
new_spec = get_train_spec("fake2")

# Demonstrate how to register an optimizer hook for all model specs
hook_called = False

def my_hook(
optimizer: torch.optim.Optimizer,
args,
kwargs,
model_parts: list[nn.Module],
) -> None:
nonlocal hook_called
hook_called = True

def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
# Create a closure to capture the original spec.build_optimizers_fn
original_build_optimizers_fn = spec.build_optimizers_fn

def my_build_optimizer_fn(
model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
optimizers = original_build_optimizers_fn(model_parts, job_config)
optimizers.register_step_post_hook(
partial(my_hook, model_parts=model_parts)
)
return optimizers

spec.build_optimizers_fn = my_build_optimizer_fn

apply_to_train_specs(register_optimizer_hook_to_spec)

model = new_spec.cls.from_model_args(BaseModelArgs())
model_parts = [model]
optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
assert optimizers.optimizers[0].__class__.__name__ == "Adam"
batch = torch.randn(8, 8)
model(batch).sum().backward()
assert not hook_called
optimizers.step()
assert hook_called
11 changes: 11 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_train_spec()
# will be called.
import torchtitan.models # noqa: F401
5 changes: 3 additions & 2 deletions torchtitan/checkpoint.py
@@ -26,9 +26,10 @@
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +141,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module that imports the TorchTitan training loop and executes it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --custom_model_path option allows users to specify a custom path to a model module
that is not natively implemented within TorchTitan.
Acceptable values are either the file system path to the module (e.g., my_models/model_x)
or the dotted import path of the module (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
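The loader that acts on --experimental.custom_model_path is not part of this excerpt, so the following is only a sketch of how such a path could be resolved and imported so that the module's register_train_spec() call runs at import time; the function name and module name used here are hypothetical.

```python
# Hypothetical sketch: the real loader for --experimental.custom_model_path is
# not shown in this diff. It resolves either a file system path (my_models/model_x)
# or a dotted module path (some_package.model_x) and imports it, which triggers
# any register_train_spec() calls at module import time.
import importlib
import importlib.util
import os


def import_custom_model_module(path: str) -> None:
    if not path:
        return
    if os.path.exists(path):
        # File system path: import the package's __init__.py (or the file itself).
        location = os.path.join(path, "__init__.py") if os.path.isdir(path) else path
        spec = importlib.util.spec_from_file_location("custom_model", location)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        # Dotted import path already importable from sys.path.
        importlib.import_module(path)
```

A training entry point would call something like this once, before looking up get_train_spec(job_config.model.name).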
12 changes: 4 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}
# Import the built-in models here so that the corresponding register_train_spec()
# will be called.
import torchtitan.models.llama # noqa: F401

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
model_name_to_tokenizer = {"llama3": "tiktoken"}
39 changes: 33 additions & 6 deletions torchtitan/models/llama/__init__.py
@@ -6,13 +6,27 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
"debugmodel": TransformerModelArgs(
dim=256, n_layers=8, n_heads=16, rope_theta=500000
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
@@ -21,7 +35,7 @@
multiple_of=1024,
rope_theta=500000,
),
"70B": ModelArgs(
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
@@ -30,7 +44,7 @@
multiple_of=4096,
rope_theta=500000,
),
"405B": ModelArgs(
"405B": TransformerModelArgs(
dim=16384,
n_layers=126,
n_heads=128,
@@ -40,3 +54,16 @@
rope_theta=500000,
),
}


register_train_spec(
TrainSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
)
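Following the built-in llama3 registration above, a third-party model can register its own spec the same way. The example below is hypothetical: MyTransformer, my_configs, and the parallelize/pipeline helpers are placeholders, while TrainSpec, register_train_spec, get_train_spec, build_optimizers, and build_lr_schedulers come from the code in this PR.

```python
# Hypothetical example of registering a user-defined model spec.
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import get_train_spec, register_train_spec, TrainSpec

# Placeholder imports: substitute your own model, configs, and parallelism helpers.
from my_models.model_x import (
    my_configs,
    MyTransformer,
    parallelize_my_model,
    pipeline_my_model,
)

register_train_spec(
    TrainSpec(
        name="my_model",
        cls=MyTransformer,
        config=my_configs,
        parallelize_fn=parallelize_my_model,
        pipelining_fn=pipeline_my_model,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)

# The training scripts then pick it up via --model.name my_model:
train_spec = get_train_spec("my_model")
```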