Add Dynamic Model Import and ModelSpec Definition #814
Merged
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import (
    apply_to_model_specs,
    BaseModelArgs,
    get_model_spec,
    ModelProtocol,
    ModelSpec,
    register_model_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestModelSpec:
    def test_register_model_spec(self):
        fake_config = {"fake": None}
        spec = ModelSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_model_spec(spec)
        new_spec = get_model_spec("fake")
        assert new_spec == spec

        # Looking up a name that was never registered must raise.
        with pytest.raises(ValueError):
            new_spec = get_model_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = ModelSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_model_spec(spec)
        new_spec = get_model_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs.
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec:
            # Create a closure to capture the original spec.build_optimizers_fn.
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn
            # apply_to_model_specs() stores the return value, so the modified
            # spec must be returned.
            return spec

        apply_to_model_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        # The post-step hook must fire exactly on optimizer.step().
        assert not hook_called
        optimizers.step()
        assert hook_called
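Once a spec is registered, a trainer can resolve everything it needs from the model name alone. The sketch below is illustrative and not part of this PR: the ``flavor`` config key is hypothetical, and the calls follow the builder signatures declared in model_spec.py further down.

from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import get_model_spec


def build_from_registry(model_name: str, flavor: str, job_config: JobConfig):
    # Resolve the registered spec; get_model_spec raises ValueError for
    # unknown names, so typos fail fast.
    spec = get_model_spec(model_name)
    model = spec.cls.from_model_args(spec.config[flavor])
    optimizers = spec.build_optimizers_fn([model], job_config)
    lr_schedulers = spec.build_lr_schedulers_fn(optimizers)
    return model, optimizers, lr_schedulers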
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models
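The same registration-on-import pattern extends to models defined outside the repository: importing a module whose top level calls register_model_spec() is enough. A minimal sketch, assuming a hypothetical user package ``my_models.custom_transformer``:

import importlib


def import_model_module(module_name: str) -> None:
    # Importing runs the module's top-level register_model_spec() call,
    # after which get_model_spec() can resolve the new model by name.
    importlib.import_module(module_name)


# Hypothetical third-party module that registers a ModelSpec on import.
import_model_module("my_models.custom_transformer")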
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    This class is currently used only for type checking, but it lets us extend
    common arguments to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes have the methods required
    by the TorchTitan trainer.
    """

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module: ...


OptimizersBuilder: TypeAlias = Callable[
    [List[nn.Module], JobConfig], OptimizersContainer
]
OptimizerBuilderWrapper: TypeAlias = Callable[
    [List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer
]
LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]


@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # TODO: Add a ``build_dataloader_fn``.
    # For now, the tokenizer is a string, so it has to be built into the
    # TorchTitan library. A better way would be to have a dataloader class
    # and a ``build_dataloader`` function that takes job_config to consume
    # the different dataloader and tokenizer configs.
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
    build_optimizers_fn: OptimizersBuilder
    build_lr_schedulers_fn: LRSchedulersBuilder

    # TODO: Add an FQN convert fn to allow users to load checkpoints from
    # HuggingFace or other sources that have different FQN conventions.


_model_specs: Dict[str, ModelSpec] = {}


def register_model_spec(model_spec: ModelSpec) -> None:
    global _model_specs
    if model_spec.name in _model_specs:
        raise ValueError(f"Model {model_spec.name} is already registered.")

    _model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
    global _model_specs
    if name not in _model_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _model_specs[name]


def apply_to_model_specs(func: Callable[[ModelSpec], ModelSpec]) -> None:
    """Apply ``func`` to every registered spec.

    ``func`` must return the (possibly updated) spec, which replaces the
    stored one.
    """
    global _model_specs
    for name, model_spec in _model_specs.items():
        _model_specs[name] = func(model_spec)
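For reference, registering a new model against this API looks like the unit test above, minus the test scaffolding. A minimal sketch, assuming a hypothetical ``MyMLP`` model and a hypothetical ``"debugmodel"`` flavor key, and reusing the llama parallelize/pipeline functions and default builders:

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelSpec, register_model_spec
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import build_lr_schedulers, build_optimizers


class MyMLP(nn.Module):
    # Satisfies ModelProtocol: construct the module from its args object.
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))


register_model_spec(
    ModelSpec(
        name="my_mlp",  # must be unique; duplicate names raise ValueError
        cls=MyMLP,
        config={"debugmodel": BaseModelArgs()},  # hypothetical flavor key
        tokenizer="tiktoken",
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)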