
Commit 53a1234

feat: add best_model_metrics to ModelCheckpoint callback
1 parent 79ffe50 commit 53a1234

2 files changed (+78, -0 lines)


src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 4 additions & 0 deletions

@@ -291,6 +291,7 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = {}
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None

@@ -974,6 +975,9 @@ def _update_best_and_save(
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]

+        if self.best_model_path == filepath:
+            self.best_model_metrics = dict(monitor_candidates)
+
         if self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
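Taken together, the two hunks snapshot the full monitor_candidates dictionary into best_model_metrics whenever the checkpoint just written becomes the new best model. A minimal usage sketch of reading the new attribute after training (not part of the commit; the model and the "val_loss" metric name are placeholders):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # MyModel stands in for any LightningModule that logs "val_loss" via self.log().
    checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    trainer = Trainer(max_epochs=5, callbacks=[checkpoint_cb])
    trainer.fit(MyModel())

    # Alongside the existing best_model_path / best_model_score, the metrics logged
    # at the point the best checkpoint was saved are now available as a dict of Tensors:
    print(checkpoint_cb.best_model_path)
    print(checkpoint_cb.best_model_score)
    print(checkpoint_cb.best_model_metrics)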

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 74 additions & 0 deletions

@@ -1637,6 +1637,80 @@ def training_step(self, *args):
     assert model_checkpoint.current_score == expected


+def test_best_model_metrics(tmp_path):
+    """Ensure ModelCheckpoint correctly tracks best_model_metrics."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            self.log("train_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            loss = output["x"]
+            self.log("val_loss", loss)
+            self.log("val_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=3, monitor="val_metric", mode="min")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert hasattr(checkpoint, "best_model_metrics")
+    assert isinstance(checkpoint.best_model_metrics, dict)
+    assert "val_metric" in checkpoint.best_model_metrics
+    assert checkpoint.best_model_metrics["val_metric"] == 0.1  # best (lowest) value
+    assert "val_loss" in checkpoint.best_model_metrics
+    assert "train_loss" in checkpoint.best_model_metrics
+    assert "train_metric" in checkpoint.best_model_metrics
+
+
+@pytest.mark.parametrize("mode", ["min", "max"])
+def test_best_model_metrics_mode(tmp_path, mode: str):
+    """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""
+
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            metric_value = (self.current_epoch + 1) / 10
+            self.log("val_metric", metric_value)
+            return output["x"]
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert checkpoint.best_model_metrics is not None
+    assert "val_metric" in checkpoint.best_model_metrics
+
+    expected_value = 0.1 if mode == "min" else 0.3
+    assert checkpoint.best_model_metrics["val_metric"] == expected_value
+
+
 @pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
 def test_hparams_type(tmp_path, use_omegaconf):
     class TestModel(BoringModel):
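Because the snapshot stores the same tensors that _update_best_and_save reads "epoch" and "step" from, its values are Tensors rather than plain numbers. A small, illustrative conversion for export, assuming checkpoint is the ModelCheckpoint instance from the tests above and the output file name is arbitrary:

    import json

    # Illustrative only: turn the tensor snapshot into plain Python numbers for serialization.
    plain_metrics = {name: value.item() for name, value in checkpoint.best_model_metrics.items()}
    with open(tmp_path / "best_model_metrics.json", "w") as f:
        json.dump(plain_metrics, f, indent=2)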
