
Add Dynamic Model Import and ModelSpec Definition #814

Merged (20 commits, Feb 12, 2025)

Conversation

Contributor

@fegin fegin commented Jan 31, 2025

Stack from ghstack (oldest at bottom):

What does this PR do?

  1. This PR introduces ModelSpec to describe a model and how to parallelize a model.
    • All the models should call register_model_spec().
    • Users can also use --experimental.custom_model_path to dynamically import a model that is not implemented by TorchTitan. The module should also call register_model_spec().
  2. This PR also refactors OptimizersContainer and LRSchedulersContainers
    • Fixes an issue where optimizers would accept parameters whose requires_grad is False.
    • Improves typing and docstrings.
    • Improves function and class reusability.
    • OptimizersContainer now inherits from torch.optim.Optimizer.
  3. This PR also moves parallelize_llama and pipelining_llama to the llama folder.

Why do we need this PR?
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

Next steps

  1. Dataloader customization is not included yet.
  2. Checkpoint customization is not included yet.
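The registration flow in point 1 can be sketched roughly as follows. This is a simplified, hypothetical stand-in for the TorchTitan registry: only the names `register_model_spec()` and `--experimental.custom_model_path` come from this PR; the `ModelSpec` fields, `get_model_spec` helper, and exact signatures are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

# Hypothetical, simplified stand-ins for TorchTitan's model-spec registry.
# The real ModelSpec carries more fields (configs, tokenizer, ...).
@dataclass
class ModelSpec:
    name: str
    model_cls: type
    parallelize_fn: Optional[Callable] = None
    pipelining_fn: Optional[Callable] = None

_model_specs: Dict[str, ModelSpec] = {}

def register_model_spec(spec: ModelSpec) -> None:
    """Called at import time by each model module (built-in or loaded
    via --experimental.custom_model_path)."""
    if spec.name in _model_specs:
        raise ValueError(f"ModelSpec {spec.name!r} is already registered")
    _model_specs[spec.name] = spec

def get_model_spec(name: str) -> ModelSpec:
    if name not in _model_specs:
        raise ValueError(f"Unknown model spec {name!r}")
    return _model_specs[name]

# A custom model module simply registers itself when imported:
class MyModel: ...
register_model_spec(ModelSpec(name="my_model", model_cls=MyModel))

print(get_model_spec("my_model").model_cls.__name__)  # MyModel
```

With this shape, dynamic import reduces to an `importlib.import_module` call on the user-supplied path; the imported module's top-level `register_model_spec()` call does the rest.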

fegin added a commit that referenced this pull request Jan 31, 2025
**What does this PR do?**
1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the model module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using --experimental.model_module_path.

**Why do we need this PR?**
This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.

**Next steps**
1. This PR includes only the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may want to extend ModelSpec to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and pipelining_fn import ModelArgs, which can lead to circular imports. This issue needs to be addressed.

ghstack-source-id: f0847f5efebfdf8c6619f58c1b0131a233502eaf
Pull Request resolved: #814
@facebook-github-bot added the CLA Signed label Jan 31, 2025 (label managed by the Meta Open Source bot)
@fegin fegin requested review from tianyu-l, wconstab and fduwjj January 31, 2025 18:40
@fegin fegin changed the title Allow users to use the customized model Add Dynamic Model Import and ModelSpec Definition Jan 31, 2025
Contributor

@tianyu-l tianyu-l left a comment

Initial pass looks great. Had some suggestions on restructuring.

Resolved review threads on: torchtitan/models/llama/model.py, torchtitan/models/__init__.py, torchtitan/config_manager.py, torchtitan/models/llama/__init__.py
register_model_spec(spec)
new_spec = get_model_spec("fake2")

# Demonstrate how to register an optimizer hook for all model specs
Contributor Author

Demonstrates how to register an optimizer hook for all model specs.

Contributor

@tianyu-l tianyu-l left a comment


LGTM!

return _train_specs[name]


def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
Contributor

This applies obliviously to every TrainSpec in _train_specs. Now we only have one spec (llama3) and one converter (float8); down the road we may have some converters that don't work with all specs.

Practically it won't fail since it only defines new functions but won't actually run them. Conceptually it can be a little confusing. But allowing too much complexity is also confusing, so I think it is OK for now.



OptimizersBuilder: TypeAlias = Callable[
    [List[nn.Module], JobConfig], OptimizersContainer
]
Contributor

I can imagine that later there could be cases where a TrainSpec's pipeline_fn is None, so we don't need an optimizer supporting multiple model_parts. If we hit those scenarios, we can consider relaxing this List[nn.Module].

Contributor Author

I don't think we should relax this. Our trainer currently uses List[nn.Module] even when the pipeline degree is 1. This simplifies the code logic.

Contributor

This is because of our specific way of composing PP and SPMD parallelisms. Later on we may need to revisit such protocols, e.g. how do we do DualPipe from DeepSeek V3 which likely will "fuse" PP and other parallelisms? One way is to put pipeline_fn and parallelize_fn into a bigger customizable concept.

Contributor Author

When you say relaxing List[nn.Module], what specific type do you have in mind?

Contributor

Suppose a "community" TrainSpec has a customized pipeline_fn that doesn't produce multiple model_parts, e.g. something like 1F1B / GPipe; then its corresponding build_optimizer_fn only needs an input of type nn.Module, not List[nn.Module].

Contributor Author

Yes, we previously used Union[nn.Module, List[nn.Module]]. But since we changed the trainer to always use model_parts[0] when there is only one nn.Module, we now use List[nn.Module]. We can change it back if that feels more friendly, but that would also complicate the typing, as users may have to cast.
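The shape under discussion — a container that always takes List[nn.Module] and, per the fix in this PR, skips frozen parameters — can be illustrated with a much-reduced sketch. The class and method names follow the PR, but the body here is an assumption for illustration, not the actual implementation.

```python
from typing import List

import torch
from torch import nn

class OptimizersContainer(torch.optim.Optimizer):
    """Reduced sketch: one optimizer per model part, driven uniformly
    even when there is a single part (pipeline degree 1)."""

    def __init__(self, model_parts: List[nn.Module], lr: float = 1e-3) -> None:
        self.optimizers = [
            torch.optim.AdamW(
                # The fix in this PR: skip parameters that are frozen.
                [p for p in part.parameters() if p.requires_grad],
                lr=lr,
            )
            for part in model_parts
        ]
        all_params = [
            p for part in model_parts for p in part.parameters()
            if p.requires_grad
        ]
        # Inheriting from torch.optim.Optimizer lets the container be
        # passed anywhere a plain optimizer is expected.
        super().__init__(all_params, {"lr": lr})

    def step(self, closure=None) -> None:
        for opt in self.optimizers:
            opt.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)

# Even with no pipelining, the trainer passes a one-element list:
model = nn.Linear(4, 4)
model.bias.requires_grad_(False)  # a frozen parameter should be excluded
container = OptimizersContainer([model])
n_params = sum(len(g["params"]) for g in container.optimizers[0].param_groups)
print(n_params)  # 1 (only the weight)
```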

@fegin fegin merged commit 2f4d1ce into gh/fegin/8/base Feb 12, 2025
6 checks passed
fegin added a commit that referenced this pull request Feb 12, 2025
`ghstack` didn't land #814 correctly; this PR was opened to do so. For the detailed discussion, please refer to #814.

garrett361 pushed a commit to garrett361/torchtitan that referenced this pull request Feb 12, 2025