diff --git a/hungarian_net/models.py b/hungarian_net/models.py index 64ed4b1..f817bd8 100644 --- a/hungarian_net/models.py +++ b/hungarian_net/models.py @@ -121,7 +121,9 @@ def validation_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]: return outputs - def test_step(self, batch: Dict[str, torch.Tensor], batch_idx) -> Dict[str, torch.Tensor]: + def test_step( + self, batch: Dict[str, torch.Tensor], batch_idx + ) -> Dict[str, torch.Tensor]: """ Lightning test step. @@ -157,10 +159,7 @@ def on_test_start(self) -> None: # 1. Reset test metrics self.f1.reset() - # 2. Log test configurations - self.logger.log_dict({"test_config": self.hparams}) - - # 3. Disable certain training-specific settings + # 2. Disable certain training-specific settings self.model.eval() def on_train_batch_end( diff --git a/hungarian_net/train_hnet.py b/hungarian_net/train_hnet.py index 204eab2..374a6d8 100644 --- a/hungarian_net/train_hnet.py +++ b/hungarian_net/train_hnet.py @@ -163,6 +163,8 @@ def main( trainer.fit(lightning_module, datamodule=lightning_datamodule) + trainer.test(ckpt_path="best", datamodule=lightning_datamodule) + def set_seed(seed=42): L.seed_everything(seed, workers=True)