Skip to content

Commit e72fc19

Browse files
mauvilsalantiga
authored and committed
Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes (#21116)
* Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes
* Add message about ckpt_path hyperparameters when parsing fails
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka B <[email protected]>

(cherry picked from commit 4824cc1)
1 parent 5f43767 commit e72fc19

File tree

3 files changed

+63
-1
lines changed

3 files changed

+63
-1
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717

1818
### Fixed
1919

20+
- Fixed `LightningCLI` not using `ckpt_path` hyperparameters to instantiate classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116))
21+
22+
2023
- Fixed callbacks by deferring step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
2124

2225

src/lightning/pytorch/cli.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
from collections.abc import Iterable
1818
from functools import partial, update_wrapper
19+
from pathlib import Path
1920
from types import MethodType
2021
from typing import Any, Callable, Optional, TypeVar, Union
2122

@@ -397,6 +398,7 @@ def __init__(
397398
main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
398399
self.setup_parser(run, main_kwargs, subparser_kwargs)
399400
self.parse_arguments(self.parser, args)
401+
self._parse_ckpt_path()
400402

401403
self.subcommand = self.config["subcommand"] if run else None
402404

@@ -551,6 +553,24 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
551553
else:
552554
self.config = parser.parse_args(args)
553555

556+
def _parse_ckpt_path(self) -> None:
    """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config.

    Nothing happens unless all of the following hold: a subcommand was parsed, its ``ckpt_path`` option points to
    an existing local file, and the checkpoint contains a non-empty ``hyper_parameters`` entry. In that case the
    stored values are handed to ``self.parser.parse_object`` as the ``model`` settings of the active subcommand,
    using the current ``self.config`` as the base, and the result replaces ``self.config``.

    Raises:
        SystemExit: Re-raised from the parser when the stored hyperparameters cannot be parsed. A hint naming
            ``ckpt_path`` as the origin of the failure is written to ``sys.stderr`` first.

    """
    subcommand = self.config.get("subcommand")
    if not subcommand:
        return

    ckpt_path = self.config[subcommand].get("ckpt_path")
    # Only existing local files are inspected; any other value is left for later stages to resolve.
    if not ckpt_path or not Path(ckpt_path).is_file():
        return

    ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    # ``or {}`` also covers an explicit ``None`` entry; the copy keeps the loaded checkpoint dict untouched.
    hparams = dict(ckpt.get("hyper_parameters") or {})
    hparams.pop("_instantiator", None)
    if not hparams:
        return

    # NOTE(review): ``_class_path`` is presumably stored by the CLI instantiator when the model was configured
    # through a class path (subclass mode) -- confirm against the code that writes the hyperparameters. It is an
    # internal marker rather than an init argument, so it is translated into the ``class_path``/``dict_kwargs``
    # form instead of being passed through as a model parameter.
    class_path = hparams.pop("_class_path", None)
    if class_path is not None:
        hparams = {"class_path": class_path, "dict_kwargs": hparams}

    try:
        self.config = self.parser.parse_object({subcommand: {"model": hparams}}, self.config)
    except SystemExit:
        # Tell the user that the failing values came from the checkpoint, not from the command line.
        sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
        raise
573+
554574
def _dump_config(self) -> None:
555575
if hasattr(self, "config_dump"):
556576
return

tests/tests_pytorch/test_cli.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import operator
1818
import os
1919
import sys
20-
from contextlib import ExitStack, contextmanager, redirect_stdout
20+
from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
2121
from io import StringIO
2222
from pathlib import Path
2323
from typing import Callable, Optional, Union
@@ -487,6 +487,45 @@ def test_lightning_cli_print_config():
487487
assert outval["ckpt_path"] is None
488488

489489

490+
class BoringCkptPathModel(BoringModel):
    """``BoringModel`` variant for the ``ckpt_path`` tests: records hyperparameters and sizes its layer by
    ``out_dim``."""

    def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
        super().__init__()
        # ``hidden_dim`` builds no layer in this class; it is only an init argument seen by
        # ``save_hyperparameters``.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(in_features=32, out_features=out_dim)
495+
496+
497+
def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
    """Check that hyperparameters stored in a checkpoint are picked up when only ``--ckpt_path`` is passed."""

    class CkptPathCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

    # Phase 1: fit with an explicit ``out_dim``; the linked ``hidden_dim`` must come out doubled.
    fit_argv = ["any.py", "fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", fit_argv):
        fit_cli = CkptPathCLI(BoringCkptPathModel)

    fit_model_cfg = fit_cli.config.fit.model
    assert fit_model_cfg.out_dim == 3
    assert fit_model_cfg.hidden_dim == 6

    log_dir = Path(fit_cli.trainer.log_dir)
    hparams_file = log_dir / "hparams.yaml"
    assert hparams_file.is_file()
    saved_hparams = yaml.safe_load(hparams_file.read_text())
    assert saved_hparams["out_dim"] == 3
    assert saved_hparams["hidden_dim"] == 6

    # Phase 2: predict with nothing but the checkpoint path; the model settings must match the fit run.
    checkpoint_path = next((log_dir / "checkpoints").glob("*.ckpt"))
    predict_argv = ["any.py", "predict", f"--ckpt_path={checkpoint_path}"]
    with mock.patch("sys.argv", predict_argv):
        predict_cli = CkptPathCLI(BoringCkptPathModel)

    predict_model_cfg = predict_cli.config.predict.model
    assert predict_model_cfg.out_dim == 3
    assert predict_model_cfg.hidden_dim == 6
    assert predict_cli.config_init.predict.model.layer.out_features == 3

    # Phase 3: the same arguments with the plain ``LightningCLI``/``BoringModel`` pair are expected to exit,
    # and the stderr output must point at the checkpoint hyperparameters.
    captured_err = StringIO()
    with mock.patch("sys.argv", predict_argv), redirect_stderr(captured_err), pytest.raises(SystemExit):
        LightningCLI(BoringModel)
    assert "Parsing of ckpt_path hyperparameters failed" in captured_err.getvalue()
527+
528+
490529
def test_lightning_cli_submodules(cleandir):
491530
class MainModule(BoringModel):
492531
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):

0 commit comments

Comments
 (0)