
Commit

Update
[ghstack-poisoned]
fegin committed Feb 11, 2025
1 parent 210707a commit 6fb1d74
Showing 11 changed files with 78 additions and 78 deletions.
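In short, this commit renames the spec registry: torchtitan.model_spec becomes torchtitan.train_spec, ModelSpec / register_model_spec / get_model_spec / apply_to_model_specs become TrainSpec / register_train_spec / get_train_spec / apply_to_train_specs, and the Llama ModelArgs dataclass becomes TransformerModelArgs. A minimal sketch of how an import site changes, based only on the diff below:

```python
# Before this commit
# from torchtitan.model_spec import ModelSpec, get_model_spec, register_model_spec

# After this commit
from torchtitan.train_spec import TrainSpec, get_train_spec, register_train_spec
```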
@@ -10,13 +10,13 @@
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import (
apply_to_model_specs,
from torchtitan.train_spec import (
apply_to_train_specs,
BaseModelArgs,
get_model_spec,
get_train_spec,
ModelProtocol,
ModelSpec,
register_model_spec,
TrainSpec,
register_train_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
@@ -49,10 +49,10 @@ def fake_build_optimizers(
)


class TestModelSpec:
def test_register_model_spec(self):
class TestTrainSpec:
def test_register_train_spec(self):
fake_config = {"fake": None}
spec = ModelSpec(
spec = TrainSpec(
name="fake",
cls=FakeModel,
config=fake_config,
@@ -62,16 +62,16 @@ def test_register_model_spec(self):
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_model_spec(spec)
new_spec = get_model_spec("fake")
register_train_spec(spec)
new_spec = get_train_spec("fake")
assert new_spec == spec

with pytest.raises(ValueError):
new_spec = get_model_spec("fake2")
new_spec = get_train_spec("fake2")

def test_optim_hook(self):
fake_config = {"fake": None}
spec = ModelSpec(
spec = TrainSpec(
name="fake2",
cls=FakeModel,
config=fake_config,
@@ -81,8 +81,8 @@ def test_optim_hook(self):
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
register_model_spec(spec)
new_spec = get_model_spec("fake2")
register_train_spec(spec)
new_spec = get_train_spec("fake2")

# Demonstrate how to register an optimizer hook for all model specs
hook_called = False
@@ -96,7 +96,7 @@ def my_hook(
nonlocal hook_called
hook_called = True

def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec:
def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
# Create a closure to capture the original spec.build_optimizers_fn
original_build_optimizers_fn = spec.build_optimizers_fn

@@ -111,7 +111,7 @@ def my_build_optimizer_fn(

spec.build_optimizers_fn = my_build_optimizer_fn

apply_to_model_specs(register_optimizer_hook_to_spec)
apply_to_train_specs(register_optimizer_hook_to_spec)

model = new_spec.cls.from_model_args(BaseModelArgs())
model_parts = [model]
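The test above exercises the renamed registry end to end: build and register a TrainSpec, look it up with get_train_spec, then decorate the build_optimizers_fn of every registered spec via apply_to_train_specs. A condensed sketch of that wrapping pattern, assuming apply_to_train_specs takes a callable that receives each registered TrainSpec (as the test suggests); the hook body is left as a placeholder:

```python
from torchtitan.train_spec import TrainSpec, apply_to_train_specs

def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
    # Capture the original builder in a closure, then swap in a wrapper
    # that builds the optimizers and attaches extra behavior to them.
    original_build_optimizers_fn = spec.build_optimizers_fn

    def my_build_optimizer_fn(*args, **kwargs):
        optimizers = original_build_optimizers_fn(*args, **kwargs)
        # ... register step hooks or logging on ``optimizers`` here ...
        return optimizers

    spec.build_optimizers_fn = my_build_optimizer_fn
    return spec

apply_to_train_specs(register_optimizer_hook_to_spec)
```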
1 change: 1 addition & 0 deletions torchtitan/checkpoint.py
@@ -26,6 +26,7 @@
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
26 changes: 17 additions & 9 deletions torchtitan/models/llama/__init__.py
@@ -6,19 +6,27 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.train_spec import TrainSpec, register_train_spec
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]
__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]


llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
"debugmodel": TransformerModelArgs(
dim=256, n_layers=8, n_heads=16, rope_theta=500000
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
@@ -27,7 +35,7 @@
multiple_of=1024,
rope_theta=500000,
),
"70B": ModelArgs(
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
@@ -36,7 +44,7 @@
multiple_of=4096,
rope_theta=500000,
),
"405B": ModelArgs(
"405B": TransformerModelArgs(
dim=16384,
n_layers=126,
n_heads=128,
@@ -48,8 +56,8 @@
}


register_model_spec(
ModelSpec(
register_train_spec(
TrainSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
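For orientation, a small sketch of how the registration above is consumed, assuming the spec fields shown in the test file (cls, config) and the from_model_args classmethod defined in model.py; the "debugmodel" key comes from llama3_configs:

```python
from torchtitan.train_spec import get_train_spec

spec = get_train_spec("llama3")               # registered above
model_args = spec.config["debugmodel"]        # a TransformerModelArgs instance
model = spec.cls.from_model_args(model_args)  # Transformer.from_model_args(...)
```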
24 changes: 12 additions & 12 deletions torchtitan/models/llama/model.py
@@ -13,12 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs(BaseModelArgs):
class TransformerModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -131,7 +131,7 @@ class Attention(nn.Module):
Multi-head attention module.
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
n_kv_heads (int): Number of key and value heads.
@@ -145,7 +145,7 @@ class Attention(nn.Module):
"""

def __init__(self, model_args: ModelArgs):
def __init__(self, model_args: TransformerModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.n_kv_heads = (
@@ -265,7 +265,7 @@ class TransformerBlock(nn.Module):
Args:
layer_id (int): Identifier for the layer.
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
n_heads (int): Number of attention heads.
@@ -279,7 +279,7 @@ class TransformerBlock(nn.Module):
"""

def __init__(self, layer_id: int, model_args: ModelArgs):
def __init__(self, layer_id: int, model_args: TransformerModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.dim = model_args.dim
@@ -337,10 +337,10 @@ class Transformer(nn.Module, ModelProtocol):
Transformer Module
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Attributes:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
vocab_size (int): Vocabulary size.
n_layers (int): Number of layers in the model.
tok_embeddings (ParallelEmbedding): Token embeddings.
@@ -351,7 +351,7 @@ class Transformer(nn.Module, ModelProtocol):
"""

def __init__(self, model_args: ModelArgs):
def __init__(self, model_args: TransformerModelArgs):
super().__init__()
self.model_args = model_args
self.vocab_size = model_args.vocab_size
@@ -447,12 +447,12 @@ def forward(self, tokens: torch.Tensor):
return output

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
"""
Initialize a Transformer model from a ModelArgs object.
Initialize a Transformer model from a TransformerModelArgs object.
Args:
model_args (ModelArgs): Model configuration arguments.
model_args (TransformerModelArgs): Model configuration arguments.
Returns:
Transformer: Transformer model.
16 changes: 7 additions & 9 deletions torchtitan/models/llama/pipeline_llama.py
@@ -7,7 +7,7 @@
# This file applies the PT-D pipeline parallelism to the Llama model.

import copy
from typing import Callable, Union
from typing import Callable, Union, Optional

import torch
import torch.nn as nn
@@ -18,14 +18,12 @@

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.parallelisms import (
build_pipeline_schedule,
generate_split_points,
ParallelDims,
stage_ids_this_rank,
from torchtitan.parallelisms.pipeline import (
build_pipeline_schedule, generate_split_points, stage_ids_this_rank,
)
from torchtitan.parallelisms import ParallelDims

from .model import ModelArgs
from .model import TransformerModelArgs


DeviceType = Union[int, str, torch.device]
@@ -37,7 +35,7 @@ def pipeline_llama(
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
model_config: TransformerModelArgs,
loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module]]:
stages, models = pipeline_llama_manual_split(
@@ -55,7 +53,7 @@ def pipeline_llama_manual_split(
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API extracts the torch.nn.Module objects for the part of the model configured to run inside this stage.
13 changes: 7 additions & 6 deletions torchtitan/optimizer.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
from typing import Any, Callable, Dict, Iterable, List

@@ -17,6 +18,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from torchtitan.config_manager import JobConfig


@@ -71,8 +73,6 @@ class OptimizersContainer(Optimizer):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
# We need to call super().__init__() to initialize some necessary optimizer
# functionality such as hooks.
all_params = []
self.optimizers: List[Optimizer] = []
self.model_parts = model_parts
@@ -124,6 +124,8 @@ def _post_init(
def _post_init(
self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
) -> None:
# We need to call Optimizer.__init__() to initialize some necessary optimizer
# functionality such as hooks.
Optimizer.__init__(self, all_params, optimizer_kwargs)
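The relocated comment marks why the container goes through Optimizer.__init__: that base initialization is what gives the container PyTorch's optimizer hook machinery, which the TestTrainSpec hook test relies on. A standalone illustration with a plain torch.optim optimizer (not torchtitan code):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(2))]
opt = torch.optim.AdamW(params, lr=1e-3)

def log_step(optimizer, args, kwargs):
    # Post-step hooks are part of the torch.optim.Optimizer base class.
    print("step finished")

handle = opt.register_step_post_hook(log_step)
opt.step()
handle.remove()
```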


@@ -188,7 +190,7 @@ def build_optimizers(
**Note**
Users who want to customize the optimizer behavior can create their own
``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
customized ``build_optimizers`` to ``ModelSpec`` will create the customized
customized ``build_optimizers`` to ``TrainSpec`` will create the customized
``OptimizersContainer``.
Args:
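The note above is the extension point this rename touches: a custom builder is now wired in through TrainSpec rather than ModelSpec. A hedged sketch of what that could look like; the builder signature and the JobConfig fields used here are assumptions, and the TrainSpec call is abbreviated:

```python
class MyOptimizersContainer(OptimizersContainer):
    # Customize behavior here (e.g. parameter-group handling); details elided.
    pass

def build_my_optimizers(model_parts, job_config) -> OptimizersContainer:
    optimizer_kwargs = {"lr": job_config.optimizer.lr}  # assumed JobConfig field
    return MyOptimizersContainer(model_parts, optimizer_kwargs, job_config.optimizer.name)

# Passed to the registry as: TrainSpec(..., build_optimizers_fn=build_my_optimizers)
```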
@@ -273,9 +275,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# that is immutable. As long as ``training.steps`` and ``training.warmup_steps``
# in ``job_config`` remain unchanged when resuming from a checkpoint, this
# approach is safe. We call ``copy()`` here to ensure extra safety.
# TODO: Should we deepcopy the state_dict?
for scheduler in self.schedulers:
scheduler.load_state_dict(state_dict.copy())
scheduler.load_state_dict(copy.deepcopy(state_dict))
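The change above replaces dict.copy() with copy.deepcopy(): dict.copy() is shallow, so if the same state dict were loaded into several schedulers, any nested mutable entries would still be shared between them. A standalone illustration (not torchtitan code):

```python
import copy

state = {"last_epoch": 10, "inner": {"step": 3}}

shallow = state.copy()
shallow["inner"]["step"] = 99       # mutates the shared nested dict
assert state["inner"]["step"] == 99

state["inner"]["step"] = 3
deep = copy.deepcopy(state)
deep["inner"]["step"] = 99          # original stays untouched
assert state["inner"]["step"] == 3
```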


def build_lr_schedulers(
@@ -289,7 +290,7 @@ def build_lr_schedulers(
**Note**
Users who want to customize the lr scheduler behavior can create their own
``LRSchedulersContainer`` subclass and ``build_lr_scheduler``. Passing the
customized ``build_lr_schedulers`` to ``ModelSpec`` will create the customized
customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized
``LRSchedulersContainer``.
12 changes: 1 addition & 11 deletions torchtitan/parallelisms/__init__.py
@@ -6,16 +6,6 @@


from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
build_pipeline_schedule,
generate_split_points,
stage_ids_this_rank,
)


__all__ = [
"ParallelDims",
"build_pipeline_schedule",
"generate_split_points",
"stage_ids_this_rank",
]
__all__ = ["ParallelDims"]
1 change: 1 addition & 0 deletions torchtitan/parallelisms/parallel_dims.py
@@ -8,6 +8,7 @@
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh

from torchtitan.logging import logger


File renamed without changes.