forked from pytorch/torchtitan
add linear lr warmup and lr decay scheduler (pytorch#23)
This PR adds a linear lr scheduler and includes some automation based on current best practices:

a. Takes the user lr provided in args as lr_max, and computes the final min_lr for the decay schedule as lr / 10, per the Chinchilla paper (i.e. total decay spans one order of magnitude).
b. Computes an automated linear warmup schedule of 10% of total iters, with a minimum warmup of 2 steps.
c. Computes a linear decay schedule after warmup, declining from lr_max to lr_min between the end of warmup and the end of training (per Aaron's latest paper, linear is the preferred schedule).
d. Updates the learning rate to 8e-4 in order to provide more visible per-iter results to the user, assuming debugModel.

LR scheduling produces a much improved loss curve:

<img width="1052" alt="Screenshot 2024-01-28 at 6 39 34 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/667e8520-809f-419e-bfdd-c3bb8f82ff95">

I added two log prints: the warmup schedule as one line, and the step and current lr at each iter. Both could be disabled if they are too much info.
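For reference, a small sketch (not part of this commit) of how the schedule parameters described above fall out of the args read by the new scheduler file below; the numeric values for `steps` are illustrative only:

```python
# Illustrative numbers only -- not taken from the repo's configs.
lr_max = 8e-4                                     # user-provided lr (updated in this PR)
lr_min = lr_max / 10                              # Chinchilla-style: decay one order of magnitude
steps = 1000                                      # args.steps (example value)
warmup_pct = 0.10                                 # 10% of total iters used for warmup
warmup_steps = max(int(steps * warmup_pct), 2)    # minimum warmup of 2 steps
decay_steps = steps - warmup_steps                # linear decay over the remaining iters

print(warmup_steps, decay_steps)                  # 100 900
```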
Showing 2 changed files with 57 additions and 5 deletions.
@@ -0,0 +1,39 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from torch.optim.lr_scheduler import LambdaLR

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
_warmup_steps = 2
_decay_steps = 0


def linear_warmup_linear_decay(current_step: int) -> float:
    """Computes linear warmup followed by linear decay.
    Per LambdaLR requirement, this is accomplished by returning
    a multiplicative factor to adjust the learning rate to
    create the desired schedule.
    """
    if current_step < _warmup_steps:
        # linear warmup
        # 0-indexed step, hence + 1 adjustments
        current_step += 1
        curr_adjustment = float(current_step / (_warmup_steps + 1))

    else:
        # linear decay
        normalized_step = _decay_steps - (current_step - _warmup_steps)
        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps

    return curr_adjustment


def get_lr_scheduler(optimizer, args):
    """Build a linear warmup and linear decay scheduler"""
    global _warmup_steps, _decay_steps
    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
    _decay_steps = float(max(1, args.steps - _warmup_steps))

    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
    return warmup_scheduler
```
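To see the schedule in action, here is a hedged usage sketch. `get_lr_scheduler` is the function added above, but the module path of the new file is not shown in this diff, so the import is omitted; the model, optimizer, and `args` setup below are placeholders rather than the repo's actual training script:

```python
import argparse

import torch
from torch.optim import AdamW

# get_lr_scheduler is defined in the new file added by this commit;
# its module path is not visible in this diff, so it is assumed to be in scope.

# Placeholder args exposing only the fields the scheduler reads.
args = argparse.Namespace(steps=20, warmup_pct=0.10, lr=8e-4)

model = torch.nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=args.lr)
scheduler = get_lr_scheduler(optimizer, args)

for step in range(args.steps):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LambdaLR multiplier once per training iter
    print(f"step {step:02d}  lr {scheduler.get_last_lr()[0]:.2e}")
```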