Add Dynamic Model Import and ModelSpec Definition #814

Merged (20 commits) on Feb 12, 2025
124 changes: 124 additions & 0 deletions tests/unit_tests/test_model_spec.py
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import (
apply_to_model_specs,
BaseModelArgs,
get_model_spec,
ModelProtocol,
ModelSpec,
register_model_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
build_optimizers,
OptimizersContainer,
)


class FakeModel(ModelProtocol):
@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module:
return nn.Linear(8, 8)


def fake_build_optimizers(
model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
optimizer_kwargs = {
"lr": 0.1,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": True,
"foreach": False,
}
return OptimizersContainer(
model_parts=model_parts,
optimizer_kwargs=optimizer_kwargs,
name="Adam",
)


class TestModelSpec:
def test_register_model_spec(self):
fake_config = {"fake": None}
spec = ModelSpec(
name="fake",
cls=FakeModel,
config=fake_config,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_model_spec(spec)
new_spec = get_model_spec("fake")
assert new_spec == spec

with pytest.raises(ValueError):
new_spec = get_model_spec("fake2")

def test_optim_hook(self):
fake_config = {"fake": None}
spec = ModelSpec(
name="fake2",
cls=FakeModel,
config=fake_config,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_model_spec(spec)
new_spec = get_model_spec("fake2")

# Demonstrate how to register an optimizer hook for all model specs.
hook_called = False

def my_hook(
optimizer: torch.optim.Optimizer,
args,
kwargs,
model_parts: list[nn.Module],
) -> None:
nonlocal hook_called
hook_called = True

def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec:
# Create a closure to capture the original spec.build_optimizers_fn
original_build_optimizers_fn = spec.build_optimizers_fn

def my_build_optimizer_fn(
model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
optimizers = original_build_optimizers_fn(model_parts, job_config)
optimizers.register_step_post_hook(
partial(my_hook, model_parts=model_parts)
)
return optimizers

spec.build_optimizers_fn = my_build_optimizer_fn

apply_to_model_specs(register_optimizer_hook_to_spec)

model = new_spec.cls.from_model_args(BaseModelArgs())
model_parts = [model]
optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
assert optimizers.optimizers[0].__class__.__name__ == "Adam"
batch = torch.randn(8, 8)
model(batch).sum().backward()
assert not hook_called
optimizers.step()
assert hook_called
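
The assertions above depend on the optimizer hook firing only when optimizers.step() runs. As a minimal sketch of the same mechanism on a plain PyTorch optimizer, assuming OptimizersContainer.register_step_post_hook ultimately delegates to torch.optim.Optimizer.register_step_post_hook (the toy model and hook below are illustrative, not part of this PR):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

hook_called = False

def post_step_hook(optimizer, args, kwargs) -> None:
    # Invoked after every optimizer.step() call.
    global hook_called
    hook_called = True

optimizer.register_step_post_hook(post_step_hook)

model(torch.randn(8, 8)).sum().backward()
assert not hook_called  # backward alone does not trigger the hook
optimizer.step()
assert hook_called      # step() fires the registered post hook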
12 changes: 12 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models

4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
@@ -28,7 +28,7 @@
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +140,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module, import the TorchTitan training loop, and execute it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --custom_model_path option allows users to specify a custom path to a model
module that is not natively implemented within TorchTitan.
Acceptable values are the file system path to the module (e.g., my_models/model_x)
or the dotted import path of the module (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
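
The --experimental.custom_model_path help above accepts either a file-system path or a dotted module path. The loader itself does not appear in this diff, so the sketch below is only an assumption of how such a module could be imported dynamically so that its register_model_spec() call runs as an import side effect; the helper name and layout are hypothetical:

import importlib
import importlib.util
import os


def _import_custom_model_module(path: str) -> None:
    # Hypothetical helper: import the user's model package so that its
    # register_model_spec() call executes on import.
    if not path:
        return
    if os.path.exists(path):
        # File-system form, e.g. my_models/model_x -> execute model_x/__init__.py
        init_file = os.path.join(path, "__init__.py")
        module_name = os.path.basename(os.path.normpath(path))
        spec = importlib.util.spec_from_file_location(module_name, init_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        # Dotted-module form, e.g. some_package.model_x
        importlib.import_module(path)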
92 changes: 92 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
"""All ModelArgs should inherit from this class.

This class is currently used only for type checking, but it allows us to extend
common arguments to all models in the future.
"""

_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
"""Defines the interface for a model class.

This protocol is used to enforce that all model classes implement the methods
required by the TorchTitan trainer.
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module: ...


OptimizersBuilder: TypeAlias = Callable[
[List[nn.Module], JobConfig], OptimizersContainer
]
OptimizerBuilderWrapper: TypeAlias = Callable[
[List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer
]
LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]


@dataclass
class ModelSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# TODO: Add a ``build_dataloader_fn``
# For now, this is a string, so it has to be built into the
# TorchTitan library. A better way would be to have a dataloader class
# and a ``build_dataloader`` function that take job_config to consume
# the different dataloader and tokenizer configs.
tokenizer: str
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
build_optimizers_fn: OptimizersBuilder
build_lr_schedulers_fn: LRSchedulersBuilder

# TODO: Add an FQN conversion fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.


_model_specs = {}


def register_model_spec(model_spec: ModelSpec) -> None:
global _model_specs
if model_spec.name in _model_specs:
raise ValueError(f"Model {model_spec.name} is already registered.")

_model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
global _model_specs
if name not in _model_specs:
raise ValueError(f"Model {name} is not registered.")
return _model_specs[name]


def apply_to_model_specs(func: Callable[[ModelSpec], ModelSpec]) -> None:
global _model_specs
for name, model_spec in _model_specs.items():
_model_specs[name] = func(model_spec)
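
Taken together, the registry above is meant to be driven by user code: register a spec once at import time, then look it up by name. A minimal usage sketch, where MyModel, my_model_configs, the parallelize/pipeline helpers, and the "debug" config flavor are placeholders rather than anything defined in this PR:

from torchtitan.model_spec import ModelSpec, get_model_spec, register_model_spec
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

# Hypothetical user-provided pieces; any callables with the expected
# signatures would work here.
from my_package.model import MyModel, my_model_configs
from my_package.parallelize import parallelize_my_model
from my_package.pipeline import pipeline_my_model

register_model_spec(
    ModelSpec(
        name="my_model",
        cls=MyModel,
        config=my_model_configs,  # Dict[str, BaseModelArgs]
        tokenizer="tiktoken",
        parallelize_fn=parallelize_my_model,
        pipelining_fn=pipeline_my_model,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)

# Later, e.g. in the trainer, the spec is looked up by name and the model is
# built from one of its registered configs (the "debug" flavor is made up here).
spec = get_model_spec("my_model")
model = spec.cls.from_model_args(my_model_configs["debug"])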
13 changes: 3 additions & 10 deletions torchtitan/models/__init__.py
@@ -4,14 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa
22 changes: 21 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,15 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +46,17 @@
rope_theta=500000,
),
}


register_model_spec(
ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
)
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -331,7 +332,7 @@ def init_weights(self):
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
class Transformer(nn.Module, ModelProtocol):
"""
Transformer Module

@@ -33,7 +33,7 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms import ParallelDims


def parallelize_llama(