diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 3558fe3cadc66..8f5062315e805 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added `best_model_metrics` attribute to `ModelCheckpoint` callback to store all logged metrics associated with the best model checkpoint ([#21355](https://github.com/Lightning-AI/pytorch-lightning/pull/21355))
 
 ### Changed
 
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 34dca4232a475..61eefa73b53d8 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -291,6 +291,7 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = {}
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
@@ -550,10 +551,13 @@ def state_dict(self) -> dict[str, Any]:
             "kth_best_model_path": self.kth_best_model_path,
             "kth_value": self.kth_value,
             "last_model_path": self.last_model_path,
+            "best_model_metrics": self.best_model_metrics,
         }
 
     @override
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.best_model_metrics = state_dict.get("best_model_metrics", {})
+
         dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)
 
         if self.dirpath == dirpath_from_ckpt:
@@ -974,6 +978,9 @@ def _update_best_and_save(
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
 
+        if self.best_model_path == filepath:
+            self.best_model_metrics = dict(monitor_candidates)
+
         if self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 449484da970a8..b8c49d6a19ffa 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1637,6 +1637,127 @@ def training_step(self, *args):
         assert model_checkpoint.current_score == expected
 
 
+def test_best_model_metrics(tmp_path):
+    """Ensure ModelCheckpoint correctly tracks and restores best_model_metrics."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            self.log("train_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            loss = output["x"]
+            self.log("val_loss", loss)
+            self.log("val_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+    checkpoint = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=3,
+        monitor="val_metric",
+        mode="min",
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert hasattr(checkpoint, "best_model_metrics")
+    assert isinstance(checkpoint.best_model_metrics, dict)
+    assert "val_metric" in checkpoint.best_model_metrics
+    assert checkpoint.best_model_metrics["val_metric"] == 0.1  # lowest value
+    assert "val_loss" in checkpoint.best_model_metrics
+    assert "train_loss" in checkpoint.best_model_metrics
+    assert "train_metric" in checkpoint.best_model_metrics
+
+    best_ckpt_path = checkpoint.best_model_path
+    assert best_ckpt_path
+    assert os.path.exists(best_ckpt_path)
+
+    loaded = torch.load(best_ckpt_path, weights_only=False)
+
+    callbacks_state = loaded.get("callbacks", {})
+    assert callbacks_state  # ensure not empty
+
+    ckpt_key = next(
+        (k for k in callbacks_state if k.startswith("ModelCheckpoint")),
+        None,
+    )
+
+    assert ckpt_key is not None
+
+    loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]
+
+    assert isinstance(loaded_metrics, dict)
+    assert loaded_metrics == checkpoint.best_model_metrics
+    assert loaded_metrics["val_metric"] == 0.1
+
+
+@pytest.mark.parametrize("mode", ["min", "max"])
+def test_best_model_metrics_mode(tmp_path, mode: str):
+    """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter and is restored correctly."""
+
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            metric_value = (self.current_epoch + 1) / 10
+            self.log("val_metric", metric_value)
+            return output["x"]
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert checkpoint.best_model_metrics is not None
+    assert "val_metric" in checkpoint.best_model_metrics
+
+    expected_value = 0.1 if mode == "min" else 0.3
+    assert checkpoint.best_model_metrics["val_metric"] == expected_value
+
+    # load the checkpoint and verify metrics are restored
+    best_ckpt_path = checkpoint.best_model_path
+    assert best_ckpt_path
+    assert os.path.exists(best_ckpt_path)
+
+    loaded = torch.load(best_ckpt_path, weights_only=False)
+    callbacks_state = loaded.get("callbacks", {})
+    assert callbacks_state
+
+    ckpt_key = next(
+        (k for k in callbacks_state if k.startswith("ModelCheckpoint")),
+        None,
+    )
+    assert ckpt_key is not None
+
+    loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]
+
+    assert isinstance(loaded_metrics, dict)
+    assert loaded_metrics["val_metric"] == expected_value
+
+
 @pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
 def test_hparams_type(tmp_path, use_omegaconf):
     class TestModel(BoringModel):