
Commit

update comments
mori360 committed Jan 31, 2025
1 parent 12a0bb2 commit fd683e6
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchtitan/optimizer.py
@@ -180,16 +180,20 @@ def step(self) -> None:
             scheduler.step()
 
     def state_dict(self) -> Dict[str, Any]:
-        # We have lr_scheduler with the same state_dict for all optimizers, so can just save one.
+        # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
+        # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
+        # Therefore, we only save the first one and later load it for all.
         assert (
             len(self.schedulers) > 0
         ), "Must have at least one scheduler to save state_dict"
         return self.schedulers[0].state_dict()
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # Load the same state_dict for all schedulers
+        # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
+        # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
+        # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
         for scheduler in self.schedulers:
-            scheduler.load_state_dict(state_dict)
+            scheduler.load_state_dict(state_dict.copy())
 
 
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
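For context, below is a minimal, self-contained sketch of the save-one / load-into-all pattern the updated comments describe. It assumes a simplified SchedulersContainer holding one LambdaLR per optimizer; the class body, the warmup lambda, and the driver code are illustrative rather than the exact torchtitan implementation.

# Minimal sketch (not the exact torchtitan code) of the scheduler-container
# pattern discussed in the diff above: save one scheduler's state_dict,
# then load a copy of it into every scheduler.
from typing import Any, Dict, List

import torch
from torch.optim.lr_scheduler import LambdaLR


class SchedulersContainer:
    def __init__(self, schedulers: List[LambdaLR]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers follow the same schedule, so saving the first is enough.
        assert len(self.schedulers) > 0, "Must have at least one scheduler"
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Give each scheduler its own copy so none of them mutate a shared dict.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict.copy())


if __name__ == "__main__":
    # Two optimizers (e.g. one per pipeline stage), each with its own scheduler.
    params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
    optimizers = [torch.optim.AdamW([p], lr=1e-3) for p in params]
    warmup = lambda step: min(1.0, (step + 1) / 10)  # illustrative linear warmup
    container = SchedulersContainer(
        [LambdaLR(opt, lr_lambda=warmup) for opt in optimizers]
    )

    for _ in range(5):
        for opt in optimizers:
            opt.step()
        container.step()

    saved = container.state_dict()    # state of the first scheduler only
    container.load_state_dict(saved)  # restored into every scheduler
    print(saved["last_epoch"])        # 5 -- the counter the comments refer to

As the commit comment notes, the `.copy()` is extra safety: `load_state_dict` may modify the dict it receives, so each scheduler gets its own copy instead of sharing the single checkpointed dict.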
