
Commit

Update
[ghstack-poisoned]
fegin committed Feb 11, 2025
1 parent bab9bf5 commit 210707a
Showing 2 changed files with 145 additions and 6 deletions.
tests/unit_tests/test_model_spec.py (124 additions, 0 deletions)
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import (
    apply_to_model_specs,
    BaseModelArgs,
    get_model_spec,
    ModelProtocol,
    ModelSpec,
    register_model_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestModelSpec:
    def test_register_model_spec(self):
        fake_config = {"fake": None}
        spec = ModelSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_model_spec(spec)
        new_spec = get_model_spec("fake")
        assert new_spec == spec

        with pytest.raises(ValueError):
            new_spec = get_model_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = ModelSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_model_spec(spec)
        new_spec = get_model_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs.
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec:
            # Create a closure to capture the original spec.build_optimizers_fn.
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn
            # Return the mutated spec so apply_to_model_specs re-registers it
            # rather than overwriting the registry entry with None.
            return spec

        apply_to_model_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        assert not hook_called
        optimizers.step()
        assert hook_called
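
To exercise the new tests locally, something like the following should work (assuming pytest is installed and torchtitan is importable from the repository root; the -v flag is only for verbose output):

    pytest tests/unit_tests/test_model_spec.py -v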
torchtitan/model_spec.py (21 additions, 6 deletions)
@@ -8,10 +8,11 @@


from dataclasses import dataclass
-from typing import Callable, Dict, List, Protocol, Tuple, Type
+from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
+
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer

@@ -35,7 +36,16 @@ class ModelProtocol(Protocol):
    """

    @staticmethod
-    def from_model_args(self, args: BaseModelArgs) -> nn.Module: ...
+    def from_model_args(args: BaseModelArgs) -> nn.Module: ...
+
+
+OptimizersBuilder: TypeAlias = Callable[
+    [List[nn.Module], JobConfig], OptimizersContainer
+]
+OptimizerBuilderWrapper: TypeAlias = Callable[
+    [List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer
+]
+LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]


@dataclass
@@ -51,10 +61,8 @@ class ModelSpec:
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
-    build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
-    build_lr_schedulers_fn: Callable[
-        [List[nn.Module], JobConfig], LRSchedulersContainer
-    ]
+    build_optimizers_fn: OptimizersBuilder
+    build_lr_schedulers_fn: LRSchedulersBuilder

    # TODO: Add a FQN convert fn to allow users to load checkpoints from
    # HuggingFace or other sources that have different FQN conventions.
@@ -67,6 +75,7 @@ def register_model_spec(model_spec: ModelSpec) -> None:
    global _model_specs
    if model_spec.name in _model_specs:
        raise ValueError(f"Model {model_spec.name} is already registered.")
+
    _model_specs[model_spec.name] = model_spec


@@ -75,3 +84,9 @@ def get_model_spec(name: str) -> ModelSpec:
    if name not in _model_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _model_specs[name]
+
+
+def apply_to_model_specs(func: Callable[[ModelSpec], ModelSpec]) -> None:
+    global _model_specs
+    for name, model_spec in _model_specs.items():
+        _model_specs[name] = func(model_spec)
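
OptimizerBuilderWrapper is declared above but nothing in this commit uses it yet. As a rough sketch of its intended shape (the wrapper and all names below are hypothetical, not part of the commit), a function matching that alias takes the already-built container alongside the builder inputs and returns it, possibly augmented; apply_to_model_specs can then splice it into every registered spec:

from typing import List

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import apply_to_model_specs, ModelSpec
from torchtitan.optimizer import OptimizersContainer


def log_optimizer_build(
    model_parts: List[nn.Module],
    job_config: JobConfig,
    optimizers: OptimizersContainer,
) -> OptimizersContainer:
    # Hypothetical OptimizerBuilderWrapper: inspect (or mutate) the freshly
    # built container, then hand it back.
    print(f"built {len(optimizers.optimizers)} optimizer(s)")
    return optimizers


def wrap_spec(spec: ModelSpec) -> ModelSpec:
    original_build_optimizers_fn = spec.build_optimizers_fn

    def build(
        model_parts: List[nn.Module], job_config: JobConfig
    ) -> OptimizersContainer:
        return log_optimizer_build(
            model_parts,
            job_config,
            original_build_optimizers_fn(model_parts, job_config),
        )

    spec.build_optimizers_fn = build
    return spec  # apply_to_model_specs stores the returned spec


apply_to_model_specs(wrap_spec)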
