@@ -1637,6 +1637,80 @@ def training_step(self, *args):
16371637 assert model_checkpoint .current_score == expected
16381638
16391639
def test_best_model_metrics(tmp_path):
    """Ensure ModelCheckpoint correctly tracks best_model_metrics."""

    class MetricLoggingModel(BoringModel):
        def training_step(self, batch, batch_idx):
            step_out = super().training_step(batch, batch_idx)
            self.log("train_loss", step_out["loss"])
            self.log("train_metric", (self.current_epoch + 1) / 10)
            return step_out

        def validation_step(self, batch, batch_idx):
            val_out = super().validation_step(batch, batch_idx)
            val_loss = val_out["x"]
            self.log("val_loss", val_loss)
            self.log("val_metric", (self.current_epoch + 1) / 10)
            return val_loss

    # Monitor the epoch-indexed metric with mode="min": the first epoch (0.1) is best.
    ckpt_cb = ModelCheckpoint(dirpath=tmp_path, save_top_k=3, monitor="val_metric", mode="min")

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[ckpt_cb],
        logger=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(MetricLoggingModel())

    # The callback must expose the full metrics dict captured at the best checkpoint.
    assert hasattr(ckpt_cb, "best_model_metrics")
    assert isinstance(ckpt_cb.best_model_metrics, dict)
    assert "val_metric" in ckpt_cb.best_model_metrics
    assert ckpt_cb.best_model_metrics["val_metric"] == 0.1  # best (lowest) value
    assert "val_loss" in ckpt_cb.best_model_metrics
    assert "train_loss" in ckpt_cb.best_model_metrics
    assert "train_metric" in ckpt_cb.best_model_metrics
1680+
@pytest.mark.parametrize("mode", ["min", "max"])
def test_best_model_metrics_mode(tmp_path, mode: str):
    """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""

    class SingleMetricModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            # Log an epoch-indexed metric: 0.1, 0.2, 0.3 across the three epochs.
            self.log("val_metric", (self.current_epoch + 1) / 10)
            return out["x"]

    ckpt_cb = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[ckpt_cb],
        logger=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(SingleMetricModel())

    assert ckpt_cb.best_model_metrics is not None
    assert "val_metric" in ckpt_cb.best_model_metrics
    # "min" keeps the first epoch's value, "max" the last epoch's.
    assert ckpt_cb.best_model_metrics["val_metric"] == (0.1 if mode == "min" else 0.3)
1713+
16401714@pytest .mark .parametrize ("use_omegaconf" , [False , pytest .param (True , marks = RunIf (omegaconf = True ))])
16411715def test_hparams_type (tmp_path , use_omegaconf ):
16421716 class TestModel (BoringModel ):
0 commit comments