from .basic_utils import init_random_seed, set_random_seed, build_custom_hooks

Predefined_Control_Keys = ['max_iters', 'log_interval', 'val_interval', 'save_interval', 'max_save_num',
-                          'seed', 'cudnn_deterministic', 'pretrained_model', 'checkpoint', 'test_mode']
+                          'seed', 'cudnn_deterministic', 'pretrained_model', 'checkpoint', 'test_mode',
+                          'save_best_model']


def train(args):
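For reference, any key outside Predefined_Control_Keys fails the assert at the top of train(). A minimal sketch of a control section that passes the check, assuming a dict-style config (the concrete config format is not part of this diff, and all values here are made up):

control = dict(
    max_iters=200000,
    log_interval=50,
    val_interval=5000,
    save_interval=5000,
    max_save_num=2,
    seed=42,
    cudnn_deterministic=True,
    test_mode=False,
    save_best_model=True,  # the key this PR adds to the whitelist
)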
@@ -37,6 +38,15 @@ def train(args):
    control_cfg = cfg.control
    for key in control_cfg.keys():
        assert key in Predefined_Control_Keys, '{} is not allowed in control keys'.format(key)
+   # set default values for control keys
+   max_iters = control_cfg.get('max_iters', 100000)
+   log_interval = control_cfg.get('log_interval', 100)
+   val_interval = control_cfg.get('val_interval', 5000)
+   save_interval = control_cfg.get('save_interval', 5000)
+   max_save_num = control_cfg.get('max_save_num', 1)
+   cudnn_deter_flag = control_cfg.get('cudnn_deterministic', False)
+   test_mode = control_cfg.get('test_mode', False)
+   save_best_model = control_cfg.get('save_best_model', True)
    # create log dir
    run_id = random.randint(1, 100000)
    run_id_tensor = torch.ones((1,), device='cuda:{}'.format(local_rank)) * run_id
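The block of .get calls above makes every control key optional: dict.get returns the supplied default when a key is absent, so a sparse config no longer raises. A quick illustration with a plain dict (the real config object may differ):

control_cfg = {'max_iters': 50000}  # hypothetical partial config
print(control_cfg.get('max_iters', 100000))           # -> 50000, explicit value wins
print(control_cfg.get('cudnn_deterministic', False))  # -> False, default used
print(control_cfg.get('save_best_model', True))       # -> True, default used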
@@ -72,8 +82,8 @@ def train(args):
    seed = control_cfg.get('seed', None)
    random_seed = init_random_seed(seed)
    logger.info(f'Set random seed to {random_seed}, '
-               f'deterministic: {control_cfg.cudnn_deterministic}')
-   set_random_seed(random_seed, deterministic=control_cfg.cudnn_deterministic)
+               f'deterministic: {cudnn_deter_flag}')
+   set_random_seed(random_seed, deterministic=cudnn_deter_flag)
    #
    # build dataloader
    train_loaders, test_loaders = parse_args_for_multiple_datasets(cfg['datasets'],
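A rough sketch of what set_random_seed plausibly does, modeled on similarly named helpers in other training codebases; the actual implementation lives in .basic_utils and is not shown in this diff:

import random

import numpy as np
import torch


def set_random_seed(seed, deterministic=False):
    # seed every RNG the training loop may touch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # reproducible cuDNN kernels at the cost of speed
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False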
@@ -92,7 +102,7 @@ def train(args):
        'scheduler_dict': scheduler_dict,
        'train_loaders': train_loaders,
        'logdir': logdir,
-       'log_interval': control_cfg.log_interval
+       'log_interval': log_interval,
    }
    training_args = cfg.train
    training_hook_args = training_args.pop('custom_hooks', None)
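Two details in the hunk above are easy to miss. First, the old control_cfg.log_interval attribute access raises when the key is unset, which is exactly what the pre-resolved locals avoid. Second, pop('custom_hooks', None) both reads the hook config and removes it from training_args, so the remaining arguments can be forwarded cleanly. A hypothetical illustration (the keys and the 'EMAHook' entry are invented):

training_args = {'n_epochs': 1, 'custom_hooks': [{'type': 'EMAHook'}]}
hooks = training_args.pop('custom_hooks', None)
print(hooks)          # -> [{'type': 'EMAHook'}]
print(training_args)  # -> {'n_epochs': 1}, hook config no longer present
print(training_args.pop('custom_hooks', None))  # -> None, key already gone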
@@ -143,28 +153,27 @@ def train(args):
    # build custom validator hooks
    build_custom_hooks(test_hook_args, validator)
    # test mode: only conduct test process
-   test_mode = control_cfg.get('test_mode', False)
    if test_mode:
        validator(trainer.iteration)
        exit(0)
    ########################################
    # register training hooks
-   log_interval = control_cfg.log_interval
+
    updater_iter = control_cfg.get('update_iter', 1)
    train_time_recoder = TrainTimeLogger(log_interval)
    trainer.register_hook(train_time_recoder)
    scheduler_step = SchedulerStep(updater_iter)
    lr_recoder = LrLogger(log_interval)
    trainer.register_hook(lr_recoder, priority='HIGH')
    trainer.register_hook(scheduler_step, priority='VERY_LOW')
-   save_model_hook = SaveCheckpoint(control_cfg['max_save_num'], save_interval=control_cfg['save_interval'])
+   save_model_hook = SaveCheckpoint(max_save_num=max_save_num, save_interval=save_interval)
    trainer.register_hook(save_model_hook,
                          priority='LOWEST')  # save model after scheduler step to get the right iteration number
    ########################################
    # build custom training hooks
    build_custom_hooks(training_hook_args, trainer)
    # deal with val_interval
-   val_point_list = deal_with_val_interval(control_cfg['val_interval'], max_iters=control_cfg['max_iters'],
+   val_point_list = deal_with_val_interval(val_interval, max_iters=max_iters,
                                            trained_iteration=trained_iteration)
    # start training and testing
    last_val_point = trained_iteration
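For the simple integer case, deal_with_val_interval presumably expands the interval into explicit validation points between the resumed iteration and max_iters. A guess at that behavior, for intuition only (the real helper is not shown in this diff and may support richer formats):

def deal_with_val_interval(val_interval, max_iters, trained_iteration=0):
    # validation points every val_interval steps after the resumed iteration
    return list(range(trained_iteration + val_interval, max_iters + 1, val_interval))

print(deal_with_val_interval(5000, max_iters=20000))
# -> [5000, 10000, 15000, 20000]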
@@ -175,7 +184,7 @@ def train(args):
    # test
    save_flag, early_stop_flag = validator(trainer.iteration)
    #
-   if save_flag:
+   if save_flag and save_best_model:
        save_path = os.path.join(trainer.logdir, "best_model.pth")
        torch.save(trainer.state_dict(), save_path)
    # early stop
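With save_best_model left at its default of True the behavior is unchanged; setting it to False in the control section skips only this best-model snapshot, while the periodic SaveCheckpoint hook keeps saving. Restoring the snapshot later might look like this (a sketch reusing the logdir and trainer from train(); it assumes the trainer exposes a load_state_dict counterpart, which this diff does not show):

state = torch.load(os.path.join(logdir, 'best_model.pth'), map_location='cpu')
trainer.load_state_dict(state)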