Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added `best_model_metrics` attribute to `ModelCheckpoint` callback to store all logged metrics associated with the best model checkpoint ([#21355](https://github.com/Lightning-AI/pytorch-lightning/pull/21355))

### Changed

Expand Down
7 changes: 7 additions & 0 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def __init__(
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score: Optional[Tensor] = None
self.best_model_metrics: Optional[dict[str, Tensor]] = {}
self.best_k_models: dict[str, Tensor] = {}
self.kth_best_model_path = ""
self.best_model_score: Optional[Tensor] = None
Expand Down Expand Up @@ -550,10 +551,13 @@ def state_dict(self) -> dict[str, Any]:
"kth_best_model_path": self.kth_best_model_path,
"kth_value": self.kth_value,
"last_model_path": self.last_model_path,
"best_model_metrics": self.best_model_metrics,
}

@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.best_model_metrics = state_dict.get("best_model_metrics", {})

dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)

if self.dirpath == dirpath_from_ckpt:
Expand Down Expand Up @@ -974,6 +978,9 @@ def _update_best_and_save(
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]

if self.best_model_path == filepath:
self.best_model_metrics = dict(monitor_candidates)

if self.verbose:
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
Expand Down
121 changes: 121 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,127 @@ def training_step(self, *args):
assert model_checkpoint.current_score == expected


def test_best_model_metrics(tmp_path):
"""Ensure ModelCheckpoint correctly tracks and restores best_model_metrics."""

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)
self.log("train_loss", loss["loss"])
self.log("train_metric", (self.current_epoch + 1) / 10)
return loss

def validation_step(self, batch, batch_idx):
output = super().validation_step(batch, batch_idx)
loss = output["x"]
self.log("val_loss", loss)
self.log("val_metric", (self.current_epoch + 1) / 10)
return loss

checkpoint = ModelCheckpoint(
dirpath=tmp_path,
save_top_k=3,
monitor="val_metric",
mode="min",
)

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=3,
limit_train_batches=1,
limit_val_batches=1,
callbacks=[checkpoint],
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
)

trainer.fit(TestModel())

assert hasattr(checkpoint, "best_model_metrics")
assert isinstance(checkpoint.best_model_metrics, dict)
assert "val_metric" in checkpoint.best_model_metrics
assert checkpoint.best_model_metrics["val_metric"] == 0.1 # lowest value
assert "val_loss" in checkpoint.best_model_metrics
assert "train_loss" in checkpoint.best_model_metrics
assert "train_metric" in checkpoint.best_model_metrics

best_ckpt_path = checkpoint.best_model_path
assert best_ckpt_path
assert os.path.exists(best_ckpt_path)

loaded = torch.load(best_ckpt_path, weights_only=False)

callbacks_state = loaded.get("callbacks", {})
assert callbacks_state # ensure not empty

ckpt_key = next(
(k for k in callbacks_state if k.startswith("ModelCheckpoint")),
None,
)

assert ckpt_key is not None

loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]

assert isinstance(loaded_metrics, dict)
assert loaded_metrics == checkpoint.best_model_metrics
assert loaded_metrics["val_metric"] == 0.1


@pytest.mark.parametrize("mode", ["min", "max"])
def test_best_model_metrics_mode(tmp_path, mode: str):
"""Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter and is restored correctly."""

class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
output = super().validation_step(batch, batch_idx)
metric_value = (self.current_epoch + 1) / 10
self.log("val_metric", metric_value)
return output["x"]

checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=3,
limit_train_batches=1,
limit_val_batches=1,
callbacks=[checkpoint],
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
)

trainer.fit(TestModel())

assert checkpoint.best_model_metrics is not None
assert "val_metric" in checkpoint.best_model_metrics

expected_value = 0.1 if mode == "min" else 0.3
assert checkpoint.best_model_metrics["val_metric"] == expected_value

# load the checkpoint and verify metrics are restored
best_ckpt_path = checkpoint.best_model_path
assert best_ckpt_path
assert os.path.exists(best_ckpt_path)

loaded = torch.load(best_ckpt_path, weights_only=False)
callbacks_state = loaded.get("callbacks", {})
assert callbacks_state

ckpt_key = next(
(k for k in callbacks_state if k.startswith("ModelCheckpoint")),
None,
)
assert ckpt_key is not None

loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]

assert isinstance(loaded_metrics, dict)
assert loaded_metrics["val_metric"] == expected_value


@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
def test_hparams_type(tmp_path, use_omegaconf):
class TestModel(BoringModel):
Expand Down