@@ -97,30 +97,30 @@ def write_tiff(img_wrt, filename, metadata):
97
97
return filename
98
98
99
99
def add_default_checkpointing_config (config ):
100
-
101
- subcommand = config ["subcommand" ]
102
- enable_checkpointing = config [subcommand + ".trainer.enable_checkpointing" ]
103
- callbacks = config [subcommand + ".trainer.callbacks" ]
104
- check_callbacks = [op for op in callbacks if "ModelCheckpoint" in op .class_path ]
105
-
106
- if len (check_callbacks ) > 0 :
107
- there_is_checkpointing = True
108
- else :
109
- there_is_checkpointing = False
110
-
111
- if enable_checkpointing :
112
- if not there_is_checkpointing :
113
- logger .info ("Enabling ModelCheckpoint since the user defined enable_checkpointing=True." )
114
-
115
- config ["ModelCheckpoint" ] = StateDictAwareModelCheckpoint
116
- config ["ModelCheckpoint.filename" ] = "{epoch}"
117
- config ["ModelCheckpoint.monitor" ] = "val/loss"
118
- config ["StateDictModelCheckpoint" ] = StateDictAwareModelCheckpoint
119
- config ["StateDictModelCheckpoint.filename" ] = "{epoch}_state_dict"
120
- config ["StateDictModelCheckpoint.save_weights_only" ] = True
121
- config ["StateDictModelCheckpoint.monitor" ] = "val/loss"
100
+ subcommand = config .get ("subcommand" , None )
101
+ if subcommand is not None :
102
+ enable_checkpointing = config [subcommand + ".trainer.enable_checkpointing" ]
103
+ callbacks = config [subcommand + ".trainer.callbacks" ]
104
+ check_callbacks = [op for op in callbacks if "ModelCheckpoint" in op .class_path ]
105
+
106
+ if len (check_callbacks ) > 0 :
107
+ there_is_checkpointing = True
122
108
else :
123
- logger .info ("No extra checkpoint config will be added, since the user already defined it in the callbacks." )
109
+ there_is_checkpointing = False
110
+
111
+ if enable_checkpointing :
112
+ if not there_is_checkpointing :
113
+ logger .info ("Enabling ModelCheckpoint since the user defined enable_checkpointing=True." )
114
+
115
+ config ["ModelCheckpoint" ] = StateDictAwareModelCheckpoint
116
+ config ["ModelCheckpoint.filename" ] = "{epoch}"
117
+ config ["ModelCheckpoint.monitor" ] = "val/loss"
118
+ config ["StateDictModelCheckpoint" ] = StateDictAwareModelCheckpoint
119
+ config ["StateDictModelCheckpoint.filename" ] = "{epoch}_state_dict"
120
+ config ["StateDictModelCheckpoint.save_weights_only" ] = True
121
+ config ["StateDictModelCheckpoint.monitor" ] = "val/loss"
122
+ else :
123
+ logger .info ("No extra checkpoint config will be added, since the user already defined it in the callbacks." )
124
124
125
125
return config
126
126
0 commit comments