add linear lr warmup and lr decay scheduler (pytorch#23)
This PR adds a linear LR scheduler and includes some automation based on
current best practices:
a - Takes the user-provided lr from args as lr_max and computes the final
min_lr for the decay schedule as lr / 10, per the Chinchilla paper (i.e.,
total decay spans one order of magnitude).
b - Computes an automated linear warmup schedule covering 10% of total
iterations, with a minimum warmup of 2 steps.
c - Computes a linear decay schedule after warmup, declining from lr_max
to lr_min between the end of warmup and the end of training (per Aaron's
latest paper, linear is the preferred schedule); see the sketch below.
d - Updates the default learning rate to 8e-4 to give the user more
visible per-iteration results when running the debugModel.
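
For concreteness, here is a minimal, self-contained sketch of the warmup/decay arithmetic from (b) and (c) above, using hypothetical values (1000 total steps, 10% warmup, lr of 8e-4). It mirrors the multiplicative-factor math in the new torchtrain/lr_scheduling.py below rather than importing it, so treat it as illustration only:

```python
# Hypothetical values for illustration: 1000 steps, 10% warmup, lr_max = 8e-4.
total_steps = 1000
warmup_steps = max(int(total_steps * 0.10), 2)    # 100 steps, never fewer than 2
decay_steps = max(1, total_steps - warmup_steps)  # 900 steps of linear decay

def lr_factor(step: int) -> float:
    """Multiplicative factor applied to the base lr at a 0-indexed step."""
    if step < warmup_steps:
        # linear warmup: ramps up toward 1.0
        return (step + 1) / (warmup_steps + 1)
    # linear decay: falls off linearly over the remaining steps
    return 1 - (step - warmup_steps) / decay_steps

lr_max = 8e-4
for step in (0, 50, 99, 100, 550, 999):
    print(f"step {step:4d}: lr = {lr_max * lr_factor(step):.2e}")
```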

LR scheduling produces a much improved loss curve:

<img width="1052" alt="Screenshot 2024-01-28 at 6 39 34 PM"
src="https://github.com/pytorch-labs/torchtrain/assets/46302957/667e8520-809f-419e-bfdd-c3bb8f82ff95">

I added two log prints: the warmup schedule as a single line at startup, and
the step plus current lr at each iteration.
Both can be disabled if they produce too much output.
lessw2020 authored Feb 1, 2024
1 parent 83ee9f7 commit bd5176c
Showing 2 changed files with 57 additions and 5 deletions.
39 changes: 39 additions & 0 deletions torchtrain/lr_scheduling.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+from torch.optim.lr_scheduler import LambdaLR
+
+# global states for scheduling
+# these are needed as LambdaLR does not support argument passing
+_warmup_steps = 2
+_decay_steps = 0
+
+
+def linear_warmup_linear_decay(current_step: int) -> float:
+    """Computes linear warmup followed by linear decay.
+    Per LambdaLR requirement, this is accomplished by returning
+    a multiplicative factor to adjust the learning rate to
+    create the desired schedule.
+    """
+    if current_step < _warmup_steps:
+        # linear warmup
+        # 0-indexed step, hence + 1 adjustments
+        current_step += 1
+        curr_adjustment = float(current_step / (_warmup_steps + 1))
+
+    else:
+        # linear decay
+        normalized_step = _decay_steps - (current_step - _warmup_steps)
+        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+
+    return curr_adjustment
+
+
+def get_lr_scheduler(optimizer, args):
+    """Build a linear warmup and linear decay scheduler"""
+    global _warmup_steps, _decay_steps
+    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
+    _decay_steps = float(max(1, args.steps - _warmup_steps))
+
+    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+    return warmup_scheduler
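
As a quick sanity check of the new module, here is a hedged usage sketch. It assumes torchtrain is importable from the repo root and uses a SimpleNamespace carrying only the two fields the scheduler reads (steps, warmup_pct) in place of the real parsed args:

```python
import torch
from types import SimpleNamespace

from torchtrain.lr_scheduling import get_lr_scheduler

# Throwaway parameter/optimizer pair, just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=8e-4)

# Stand-in for the parsed CLI args; only the fields get_lr_scheduler reads.
args = SimpleNamespace(steps=20, warmup_pct=0.10)
scheduler = get_lr_scheduler(optimizer, args)

for step in range(args.steps):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()
    print(f"step {step:2d}: lr = {scheduler.get_last_lr()[0]:.2e}")
```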
23 changes: 18 additions & 5 deletions train.py
@@ -21,6 +21,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.lr_scheduling import get_lr_scheduler


 @dataclass
@@ -46,7 +47,7 @@ def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
     if model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
-        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+        rank0_log("Enabling gradient scaling for mixed precision training.")
     else:
         enable_grad_scaling = False
         rank0_log("Gradient scaling not enabled.")
@@ -85,8 +86,8 @@ def main(args):
     assert isinstance(model, FSDP)

     # build optimizer after apply parallelisms to the model
-    # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
+    scheduler = get_lr_scheduler(optimizer, args)

     scaler = build_grad_scaler(model)

@@ -144,7 +145,10 @@ def main(args):
         train_state.current_loss = loss.item()
         train_state.losses.append(train_state.current_loss)

-        rank0_log(f"current loss: {train_state.current_loss}")
+        rank0_log(
+            f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+        )
+        scheduler.step()


 if __name__ == "__main__":
@@ -171,9 +175,18 @@ def main(args):
     parser.add_argument(
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
-    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
+    parser.add_argument(
+        "--warmup_pct",
+        type=float,
+        default=0.10,
+        help="percentage of total training steps to use for warmup",
+    )
     parser.add_argument(
-        "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+        "--max_norm",
+        type=Union[float, int],
+        default=1.0,
+        help="max norm for gradient clipping",
     )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
