
[BE] Lr schduler flatten #794

Open · wants to merge 3 commits into base: main
Conversation

@mori360 (Contributor) commented on Jan 16, 2025:

Currently, lr_scheduler is stored differently from the optimizer, model, and dataloader: each scheduler is kept in the state under its own key ("lr_scheduler_0", "lr_scheduler_1", ...).
This PR flattens lr_scheduler so that all schedulers are stored as a list under self.state['lr_scheduler'], consistent with optimizer, model, and dataloader.
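
As a rough illustration of the intended layout, here is a minimal sketch with toy modules and optimizers (illustrative only, not the actual torchtitan code):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

# Toy stand-ins for torchtitan's per-model-part modules and optimizers.
model_parts = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]
optimizers = [torch.optim.AdamW(m.parameters(), lr=3e-4) for m in model_parts]
lr_schedulers = [LambdaLR(opt, lr_lambda=lambda step: 1.0) for opt in optimizers]

# Before this PR: one top-level checkpoint key per scheduler.
state_before = {
    f"lr_scheduler_{i}": s.state_dict() for i, s in enumerate(lr_schedulers)
}

# After this PR: a single flat "lr_scheduler" entry holding all scheduler
# states, consistent with how model/optimizer/dataloader are registered.
state_after = {"lr_scheduler": [s.state_dict() for s in lr_schedulers]}

print(sorted(state_before))  # ['lr_scheduler_0', 'lr_scheduler_1']
print(sorted(state_after))   # ['lr_scheduler']
```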

The PR is tested in two parts:

  1. Before and after this PR, the lr_scheduler values are the same.

  2. Memory trace:
    Before the flatten, rerun llama3_8b.toml from step 5 to step 10:

[Screenshot: memory trace before the flatten, 2025-01-16 2:40:03 PM]

After the flatten, rerun llama3_8b.toml from step 5 to step 10:
[Screenshot: memory trace after the flatten, 2025-01-16 2:40:21 PM]

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jan 16, 2025
@mori360 changed the title from "[do not review] Lr schduler flatten" to "[BE] Lr schduler flatten" on Jan 17, 2025
@mori360 marked this pull request as ready for review on January 17, 2025 22:39
@mori360 marked this pull request as draft on January 17, 2025 22:39
@mori360 marked this pull request as ready for review on January 17, 2025 23:22
@mori360 requested a review from fegin on January 17, 2025 23:22
@@ -183,9 +183,9 @@ def __init__(
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
Contributor:

I think it won't be this simple. Both OptimizersContainer and ModelWrapper define state_dict and load_state_dict to handle flattening and unflattening. Since we don't have things like get_model_state_dict and set_model_state_dict for lr scheduler in torch.distributed.checkpoint.state_dict, we likely will need to manually write something for the LambdaLR we are using. See #738 (comment)

Let's work with @fegin on this.
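
For illustration, one possible shape of such a manual wrapper. This is a hedged sketch only; the class name, the Stateful base, and the per-index flattening scheme are assumptions, not the actual torchtitan implementation:

```python
from typing import Any, Dict, List

from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR


class LRSchedulersContainer(Stateful):
    """Hypothetical container that flattens/unflattens LambdaLR states by hand,
    since torch.distributed.checkpoint.state_dict has no scheduler helpers."""

    def __init__(self, schedulers: List[LambdaLR]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # Collect every scheduler's state into one flat mapping, so the
        # checkpoint sees a single "lr_scheduler" entry instead of
        # lr_scheduler_0, lr_scheduler_1, ...
        return {str(i): s.state_dict() for i, s in enumerate(self.schedulers)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore each scheduler from its slot in the flattened mapping.
        for i, scheduler in enumerate(self.schedulers):
            scheduler.load_state_dict(state_dict[str(i)])
```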

Contributor Author (@mori360):

Compared lr_schedulers before and after the flattening, with and without checkpointing: the lr_scheduler values are consistent with the changes here.

Contributor:

Does it support DCP resharding, e.g. changing the PP degree from 2 to 4 across two jobs?

Contributor:

I think this PR doesn't address the resharding issue, hence the [BE] prefix. Supporting lr_scheduler resharding deserves a separate PR.

@tianyu-l added this to the torchtitan v1.0.0 release milestone on Jan 23, 2025