Skip to content

Commit

Permalink
Refine test_step formatting and trainer configuration
Browse files Browse the repository at this point in the history
- Update test_step method signature formatting in models.py for better readability
- Simplify on_test_start by removing unnecessary test_config logging
- Add trainer.test call in train_hnet.py to execute testing after training
  • Loading branch information
MaloOLIVIER committed Dec 2, 2024
1 parent 8e9b7aa commit 88f9df2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 4 additions & 5 deletions hungarian_net/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def validation_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]:

return outputs

def test_step(self, batch: Dict[str, torch.Tensor], batch_idx) -> Dict[str, torch.Tensor]:
def test_step(
self, batch: Dict[str, torch.Tensor], batch_idx
) -> Dict[str, torch.Tensor]:
"""
Lightning test step.
Expand Down Expand Up @@ -157,10 +159,7 @@ def on_test_start(self) -> None:
# 1. Reset test metrics
self.f1.reset()

# 2. Log test configurations
self.logger.log_dict({"test_config": self.hparams})

# 3. Disable certain training-specific settings
# 2. Disable certain training-specific settings
self.model.eval()

def on_train_batch_end(
Expand Down
2 changes: 2 additions & 0 deletions hungarian_net/train_hnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def main(

trainer.fit(lightning_module, datamodule=lightning_datamodule)

trainer.test(ckpt_path="best", datamodule=lightning_datamodule)


def set_seed(seed=42):
L.seed_everything(seed, workers=True)
Expand Down

0 comments on commit 88f9df2

Please sign in to comment.