forked from pytorch/torchtitan
[Reland] Add Dynamic Model Import and ModelSpec Definition (pytorch#837)
`ghstack` didn't land pytorch#814 correctly; this PR is opened to do so. For the detailed discussion, please refer to pytorch#814.

**What does this PR do?**

1. This PR introduces `ModelSpec` to describe a model and how to parallelize it.
   * All the built-in models should call `register_model_spec()`.
   * Users can also pass `--experimental.custom_model_path` to dynamically import a model that is not implemented by TorchTitan. That module should also call `register_model_spec()`.
2. This PR also refactors `OptimizersContainer` and `LRSchedulersContainer`:
   * Fixes an issue where optimizers would accept parameters whose `requires_grad` is False.
   * Improves typing and docstrings.
   * Improves function and class reusability.
   * `OptimizersContainer` now inherits from `torch.optim.Optimizer`.
3. This PR also moves `parallelize_llama` and `pipeline_llama` to the `llama` folder.

**Why do we need this PR?**

This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**

1. Dataloader customization is not included yet.
2. Checkpoint customization is not included yet.
1 parent d858118 · commit ba4730c · 19 changed files with 643 additions and 204 deletions.
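To make the workflow from the commit message concrete, here is a minimal sketch of what an out-of-tree model module might look like, using the `register_train_spec`/`TrainSpec` API exercised by the new test file below. The module path `my_models/spec.py`, the class `MyTinyModel`, and the spec name `"my_tiny"` are hypothetical, and reusing the Llama parallelize/pipeline/optimizer builders is an assumption for illustration, not part of this commit.

```python
# my_models/spec.py -- hypothetical out-of-tree module that could be
# pointed to via --experimental.custom_model_path. Importing it runs
# register_train_spec() as a side effect, mirroring the test file below.
import torch.nn as nn

from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import (
    BaseModelArgs,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


class MyTinyModel(ModelProtocol):
    """Stand-in model; a real model would build itself from its args."""

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(64, 64)


register_train_spec(
    TrainSpec(
        name="my_tiny",  # hypothetical name, later looked up via get_train_spec()
        cls=MyTinyModel,
        config={"base": BaseModelArgs()},
        # Reusing the Llama stage functions here is illustrative only.
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```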
@@ -0,0 +1,122 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)
from torchtitan.train_spec import (
    apply_to_train_specs,
    BaseModelArgs,
    get_train_spec,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)


class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)


def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )


class TestTrainSpec:
    def test_register_train_spec(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake")
        assert new_spec == spec

        # Looking up a name that was never registered should raise.
        with pytest.raises(ValueError):
            new_spec = get_train_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs.
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
            # Create a closure to capture the original spec.build_optimizers_fn.
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn

        apply_to_train_specs(register_optimizer_hook_to_spec)

        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        # The post-step hook must fire only after optimizers.step() runs.
        assert not hook_called
        optimizers.step()
        assert hook_called
```
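The hook plumbing in `test_optim_hook` works because, per the commit message, `OptimizersContainer` now inherits from `torch.optim.Optimizer`, so it picks up PyTorch's standard step-hook API. A plain-PyTorch sketch of that underlying mechanism, independent of TorchTitan:

```python
# Standalone illustration of torch.optim.Optimizer.register_step_post_hook,
# the API the test above relies on.
import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([param], lr=0.1)


def log_step(optimizer: torch.optim.Optimizer, args, kwargs) -> None:
    # Called after every optimizer.step(); the signature is fixed by PyTorch.
    print("step finished")


handle = opt.register_step_post_hook(log_step)
param.grad = torch.ones_like(param)
opt.step()       # prints "step finished"
handle.remove()  # the hook can be detached when no longer needed
```

Binding extra context such as `model_parts` via `functools.partial`, as the test does, keeps the bound hook compatible with this fixed `(optimizer, args, kwargs)` signature.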
@@ -0,0 +1,11 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models  # noqa: F401
```
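This import exists purely for its side effect: loading the built-in model packages runs their `register_model_spec()` calls. The `--experimental.custom_model_path` flag presumably applies the same pattern to user modules; the actual loader is not visible in this diff, so the following is an assumed sketch (hypothetical function name and argument), not TorchTitan's implementation:

```python
# Illustrative only: one way --experimental.custom_model_path could be
# honored, assuming import-time registration as shown above.
import importlib


def maybe_import_custom_model(custom_model_path: str) -> None:
    """Import a user module so its register_train_spec() calls run."""
    if custom_model_path:
        # The module's import-time side effect is spec registration,
        # mirroring how `import torchtitan.models` works for built-ins.
        importlib.import_module(custom_model_path)
```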