|
17 | 17 | import operator
|
18 | 18 | import os
|
19 | 19 | import sys
|
20 |
| -from contextlib import ExitStack, contextmanager, redirect_stdout |
| 20 | +from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout |
21 | 21 | from io import StringIO
|
22 | 22 | from pathlib import Path
|
23 | 23 | from typing import Callable, Optional, Union
|
@@ -487,6 +487,45 @@ def test_lightning_cli_print_config():
|
487 | 487 | assert outval["ckpt_path"] is None
|
488 | 488 |
|
489 | 489 |
|
| 490 | +class BoringCkptPathModel(BoringModel): |
| 491 | + def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None: |
| 492 | + super().__init__() |
| 493 | + self.save_hyperparameters() |
| 494 | + self.layer = torch.nn.Linear(32, out_dim) |
| 495 | + |
| 496 | + |
| 497 | +def test_lightning_cli_ckpt_path_argument_hparams(cleandir): |
| 498 | + class CkptPathCLI(LightningCLI): |
| 499 | + def add_arguments_to_parser(self, parser): |
| 500 | + parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) |
| 501 | + |
| 502 | + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] |
| 503 | + with mock.patch("sys.argv", ["any.py"] + cli_args): |
| 504 | + cli = CkptPathCLI(BoringCkptPathModel) |
| 505 | + |
| 506 | + assert cli.config.fit.model.out_dim == 3 |
| 507 | + assert cli.config.fit.model.hidden_dim == 6 |
| 508 | + hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" |
| 509 | + assert hparams_path.is_file() |
| 510 | + hparams = yaml.safe_load(hparams_path.read_text()) |
| 511 | + assert hparams["out_dim"] == 3 |
| 512 | + assert hparams["hidden_dim"] == 6 |
| 513 | + |
| 514 | + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) |
| 515 | + cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] |
| 516 | + with mock.patch("sys.argv", ["any.py"] + cli_args): |
| 517 | + cli = CkptPathCLI(BoringCkptPathModel) |
| 518 | + |
| 519 | + assert cli.config.predict.model.out_dim == 3 |
| 520 | + assert cli.config.predict.model.hidden_dim == 6 |
| 521 | + assert cli.config_init.predict.model.layer.out_features == 3 |
| 522 | + |
| 523 | + err = StringIO() |
| 524 | + with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit): |
| 525 | + cli = LightningCLI(BoringModel) |
| 526 | + assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue() |
| 527 | + |
| 528 | + |
490 | 529 | def test_lightning_cli_submodules(cleandir):
|
491 | 530 | class MainModule(BoringModel):
|
492 | 531 | def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
|
|
0 commit comments