Skip to content

Commit affa9b0

Browse files
Merge pull request #501 from IBM/romeokienzler-patch-2
Update cli_tools.py, make subcommand optional
2 parents 884e43e + 6a4cc08 commit affa9b0

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

terratorch/cli_tools.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -97,30 +97,30 @@ def write_tiff(img_wrt, filename, metadata):
9797
return filename
9898

9999
def add_default_checkpointing_config(config):
100-
101-
subcommand = config["subcommand"]
102-
enable_checkpointing = config[subcommand + ".trainer.enable_checkpointing"]
103-
callbacks = config[subcommand + ".trainer.callbacks"]
104-
check_callbacks = [op for op in callbacks if "ModelCheckpoint" in op.class_path]
105-
106-
if len(check_callbacks) > 0:
107-
there_is_checkpointing = True
108-
else:
109-
there_is_checkpointing = False
110-
111-
if enable_checkpointing:
112-
if not there_is_checkpointing:
113-
logger.info("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.")
114-
115-
config["ModelCheckpoint"] = StateDictAwareModelCheckpoint
116-
config["ModelCheckpoint.filename"] = "{epoch}"
117-
config["ModelCheckpoint.monitor"] = "val/loss"
118-
config["StateDictModelCheckpoint"] = StateDictAwareModelCheckpoint
119-
config["StateDictModelCheckpoint.filename"] = "{epoch}_state_dict"
120-
config["StateDictModelCheckpoint.save_weights_only"] = True
121-
config["StateDictModelCheckpoint.monitor"] = "val/loss"
100+
subcommand = config.get("subcommand", None)
101+
if subcommand is not None:
102+
enable_checkpointing = config[subcommand + ".trainer.enable_checkpointing"]
103+
callbacks = config[subcommand + ".trainer.callbacks"]
104+
check_callbacks = [op for op in callbacks if "ModelCheckpoint" in op.class_path]
105+
106+
if len(check_callbacks) > 0:
107+
there_is_checkpointing = True
122108
else:
123-
logger.info("No extra checkpoint config will be added, since the user already defined it in the callbacks.")
109+
there_is_checkpointing = False
110+
111+
if enable_checkpointing:
112+
if not there_is_checkpointing:
113+
logger.info("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.")
114+
115+
config["ModelCheckpoint"] = StateDictAwareModelCheckpoint
116+
config["ModelCheckpoint.filename"] = "{epoch}"
117+
config["ModelCheckpoint.monitor"] = "val/loss"
118+
config["StateDictModelCheckpoint"] = StateDictAwareModelCheckpoint
119+
config["StateDictModelCheckpoint.filename"] = "{epoch}_state_dict"
120+
config["StateDictModelCheckpoint.save_weights_only"] = True
121+
config["StateDictModelCheckpoint.monitor"] = "val/loss"
122+
else:
123+
logger.info("No extra checkpoint config will be added, since the user already defined it in the callbacks.")
124124

125125
return config
126126

0 commit comments

Comments
 (0)