Allow users to use customized models
**What does this PR do?**

1. This PR introduces ModelSpec to describe a model and how to parallelize it.
2. Every model should define `build_model_spec()` or `model_spec` so that it can be imported by the `model` module.
3. The trainer calls `build_model_specs()` to collect the `model_specs`, and the result is used to look up the corresponding model spec.
4. Users can also pass `--experimental.model_module_path` to dynamically import a model that is not implemented by TorchTitan.

**Why do we need this PR?**

This allows users to run TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**

1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may also want to extend ModelSpec to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize_fn and pipelining_fn import `ModelArgs`, which can cause circular imports. We should fix this issue.

ghstack-source-id: 340302190896c19ec913e2f390d792c476f17f2f
Pull Request resolved: #814
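For illustration, here is a minimal sketch of what a user-provided module loaded via `--experimental.model_module_path` could look like, using the `TrainSpec` API exercised in the test below. The module path, `MyModel`, the `nn.Linear` stand-in, and the reuse of llama's parallelize/pipelining functions are all hypothetical placeholders, not part of this PR:

```python
# my_models/spec.py -- hypothetical user module, loaded with
#   --experimental.model_module_path my_models.spec
import torch.nn as nn

from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import (
    BaseModelArgs,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


class MyModel(ModelProtocol):
    """Stand-in model; a real model would build itself from `args`."""

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(128, 128)


# Registration runs at import time, so loading this module is enough
# for get_train_spec("my_model") to resolve the spec afterwards.
register_train_spec(
    TrainSpec(
        name="my_model",
        cls=MyModel,
        config={"base": None},  # flavor -> model args, mirroring the test below
        tokenizer="tiktoken",
        parallelize_fn=parallelize_llama,  # placeholder: reuse llama's parallelization
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```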
Showing 17 changed files with 629 additions and 197 deletions.
@@ -0,0 +1,124 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.train_spec import (
    apply_to_train_specs,
    BaseModelArgs,
    get_train_spec,
    ModelProtocol,
    TrainSpec,
    register_train_spec,
)
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestTrainSpec:
    def test_register_train_spec(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake")
        assert new_spec == spec

        # Looking up an unregistered name should fail loudly.
        with pytest.raises(ValueError):
            new_spec = get_train_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            tokenizer="tiktoken",
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake2")

        # Demonstrate how to register an optimizer hook for all train specs.
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
            # Create a closure to capture the original spec.build_optimizers_fn.
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn
            return spec

        apply_to_train_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        assert not hook_called
        optimizers.step()
        assert hook_called
```
@@ -0,0 +1,11 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models  # noqa: F401
```
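For context, a minimal sketch of the lookup side, assuming the built-in specs have been registered by the import above. The `"llama3"` name is illustrative, and constructing the model args and `JobConfig()` with defaults mirrors the test rather than the real trainer path:

```python
from torchtitan.config_manager import JobConfig
from torchtitan.train_spec import BaseModelArgs, get_train_spec

# Assumes built-in specs were registered at import time (see above).
spec = get_train_spec("llama3")  # illustrative model name

# The trainer builds the model from the spec, then hands the model parts
# to the spec's parallelize/pipelining/optimizer hooks.
model = spec.cls.from_model_args(BaseModelArgs())
model_parts = [model]
optimizers = spec.build_optimizers_fn(model_parts, JobConfig())
```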