From fd683e61ec40fda40d9bad4a80dee5cc1c8d17e9 Mon Sep 17 00:00:00 2001
From: mori360
Date: Fri, 31 Jan 2025 13:37:32 -0800
Subject: [PATCH] update comments

---
 torchtitan/optimizer.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py
index d91b064c..034bc11f 100644
--- a/torchtitan/optimizer.py
+++ b/torchtitan/optimizer.py
@@ -180,16 +180,20 @@ def step(self) -> None:
             scheduler.step()
 
     def state_dict(self) -> Dict[str, Any]:
-        # We have lr_scheduler with the same state_dict for all optimizers, so can just save one.
+        # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
+        # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
+        # Therefore, we only save the first one and later load it for all.
         assert (
             len(self.schedulers) > 0
         ), "Must have at least one scheduler to save state_dict"
         return self.schedulers[0].state_dict()
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # Load the same state_dict for all schedulers
+        # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
+        # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
+        # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
         for scheduler in self.schedulers:
-            scheduler.load_state_dict(state_dict)
+            scheduler.load_state_dict(state_dict.copy())
 
 
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: