
Commit fd2778d

V0.0.18

1. Add default values for several control keys: max_iters, log_interval, val_interval, save_interval, max_save_num, cudnn_deterministic.
2. Add "save_best_model" to the control keys.
3. Raise a RuntimeError when val_interval is larger than max_iters.
4. Allow max_save_num=0.
1 parent 200d459 commit fd2778d
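
With these defaults in place, a config's control block can now omit most keys. A minimal sketch of such a block, assuming the usual dict-style fastda config (the override values below are illustrative, not required):

    control = dict(
        # omitted keys fall back to the new defaults:
        # max_iters=100000, log_interval=100, val_interval=5000,
        # save_interval=5000, max_save_num=1, cudnn_deterministic=False,
        # test_mode=False, save_best_model=True
        max_iters=40000,
        val_interval=2000,  # must not exceed max_iters, or train() raises RuntimeError
    )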

File tree

5 files changed: +25 −12 lines


fastda/hooks/training_hooks.py

+2 −1

@@ -168,4 +168,5 @@ def after_train_iter(self, runner):
         if len(saved_files) >= self.max_save_num:
             sorted_files_by_ctime = sorted(saved_files, key=lambda x: os.path.getctime(x))
             os.remove(sorted_files_by_ctime[0])
-        torch.save(runner.state_dict(), save_path)
+        if self.max_save_num > 0:
+            torch.save(runner.state_dict(), save_path)
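
The new guard means SaveCheckpoint only writes a checkpoint when max_save_num is positive, so max_save_num=0 effectively disables periodic saving. A standalone sketch of the guard's effect (the FakeRunner stand-in and maybe_save wrapper are illustrative, not fastda's hook):

    import torch

    class FakeRunner:
        def state_dict(self):
            return {'iteration': 0}

    def maybe_save(runner, save_path, max_save_num):
        # mirrors the reordered hook logic: skip the write entirely
        # when max_save_num is 0
        if max_save_num > 0:
            torch.save(runner.state_dict(), save_path)

    maybe_save(FakeRunner(), 'iter_0.pth', max_save_num=0)  # nothing written
    maybe_save(FakeRunner(), 'iter_0.pth', max_save_num=1)  # checkpoint saved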

fastda/utils/basic_utils.py

+3

@@ -65,6 +65,9 @@ def reduce_trained_iteration(val_checkpoint):

     if isinstance(val_interval, (int, float)):
         val_times = int(max_iters / val_interval)
+        if val_times == 0:
+            raise RuntimeError(
+                'max_iters number {} should be larger than val_interval {}'.format(max_iters, val_interval))
         for i in range(1, val_times + 1):
             fine_grained_val_checkpoint.append(i * int(val_interval))
     if fine_grained_val_checkpoint[-1] != max_iters:
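
The check fails fast instead of silently producing an empty validation schedule: int(max_iters / val_interval) is 0 exactly when max_iters < val_interval. A standalone sketch of the same arithmetic (the wrapper function is illustrative, not the fastda API):

    def val_times_or_raise(max_iters, val_interval):
        val_times = int(max_iters / val_interval)
        if val_times == 0:
            raise RuntimeError(
                'max_iters number {} should be larger than val_interval {}'.format(
                    max_iters, val_interval))
        return val_times

    print(val_times_or_raise(100000, 5000))  # 20 validation points
    val_times_or_raise(1000, 5000)           # raises RuntimeError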

fastda/utils/train_api.py

+18 −9

@@ -19,7 +19,8 @@
 from .basic_utils import init_random_seed, set_random_seed, build_custom_hooks

 Predefined_Control_Keys = ['max_iters', 'log_interval', 'val_interval', 'save_interval', 'max_save_num',
-                           'seed', 'cudnn_deterministic', 'pretrained_model', 'checkpoint', 'test_mode']
+                           'seed', 'cudnn_deterministic', 'pretrained_model', 'checkpoint', 'test_mode',
+                           'save_best_model']


 def train(args):
@@ -37,6 +38,15 @@ def train(args):
     control_cfg = cfg.control
     for key in control_cfg.keys():
         assert key in Predefined_Control_Keys, '{} is not allowed appeared in control keys'.format(key)
+    # set default values for control keys
+    max_iters = control_cfg.get('max_iters', 100000)
+    log_interval = control_cfg.get('log_interval', 100)
+    val_interval = control_cfg.get('val_interval', 5000)
+    save_interval = control_cfg.get('save_interval', 5000)
+    max_save_num = control_cfg.get('max_save_num', 1)
+    cudnn_deter_flag = control_cfg.get('cudnn_deterministic', False)
+    test_mode = control_cfg.get('test_mode', False)
+    save_best_model = control_cfg.get('save_best_model', True)
     # create log dir
     run_id = random.randint(1, 100000)
     run_id_tensor = torch.ones((1,), device='cuda:{}'.format(local_rank)) * run_id
@@ -72,8 +82,8 @@ def train(args):
     seed = control_cfg.get('seed', None)
     random_seed = init_random_seed(seed)
     logger.info(f'Set random random_seed to {random_seed}, '
-                f'deterministic: {control_cfg.cudnn_deterministic}')
-    set_random_seed(random_seed, deterministic=control_cfg.cudnn_deterministic)
+                f'deterministic: {cudnn_deter_flag}')
+    set_random_seed(random_seed, deterministic=cudnn_deter_flag)
     #
     # build dataloader
     train_loaders, test_loaders = parse_args_for_multiple_datasets(cfg['datasets'],
@@ -92,7 +102,7 @@ def train(args):
         'scheduler_dict': scheduler_dict,
         'train_loaders': train_loaders,
         'logdir': logdir,
-        'log_interval': control_cfg.log_interval
+        'log_interval': log_interval,
     }
     training_args = cfg.train
     training_hook_args = training_args.pop('custom_hooks', None)
@@ -143,28 +153,27 @@ def train(args):
     # build custom validator hooks
     build_custom_hooks(test_hook_args, validator)
     # test mode: only conduct test process
-    test_mode = control_cfg.get('test_mode', False)
     if test_mode:
         validator(trainer.iteration)
         exit(0)
     ########################################
     # register training hooks
-    log_interval = control_cfg.log_interval
+
     updater_iter = control_cfg.get('update_iter', 1)
     train_time_recoder = TrainTimeLogger(log_interval)
     trainer.register_hook(train_time_recoder)
     scheduler_step = SchedulerStep(updater_iter)
     lr_recoder = LrLogger(log_interval)
     trainer.register_hook(lr_recoder, priority='HIGH')
     trainer.register_hook(scheduler_step, priority='VERY_LOW')
-    save_model_hook = SaveCheckpoint(control_cfg['max_save_num'], save_interval=control_cfg['save_interval'])
+    save_model_hook = SaveCheckpoint(max_save_num=max_save_num, save_interval=save_interval)
     trainer.register_hook(save_model_hook,
                           priority='LOWEST')  # save model after scheduler step to get the right iteration number
     ########################################
     # build custom training hooks
     build_custom_hooks(training_hook_args, trainer)
     # deal with val_interval
-    val_point_list = deal_with_val_interval(control_cfg['val_interval'], max_iters=control_cfg['max_iters'],
+    val_point_list = deal_with_val_interval(val_interval, max_iters=max_iters,
                                             trained_iteration=trained_iteration)
     # start training and testing
     last_val_point = trained_iteration
@@ -175,7 +184,7 @@ def train(args):
         # test
         save_flag, early_stop_flag = validator(trainer.iteration)
         #
-        if save_flag:
+        if save_flag and save_best_model:
             save_path = os.path.join(trainer.logdir, "best_model.pth".format(trainer.iteration))
             torch.save(trainer.state_dict(), save_path)
         # early stop
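
Since save_best_model defaults to True, existing configs keep the old behavior; setting it to False suppresses only the best_model.pth write, while periodic checkpoints from SaveCheckpoint are untouched. A hedged sketch of the gating, with trainer as a stand-in object rather than fastda's trainer:

    import os
    import torch

    def maybe_save_best(trainer, save_flag, save_best_model=True):
        # best-model write now requires both the validator's save_flag
        # and the save_best_model control key
        if save_flag and save_best_model:
            save_path = os.path.join(trainer.logdir, 'best_model.pth')
            torch.save(trainer.state_dict(), save_path)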

fastda/version.py

+1 −1

@@ -1,6 +1,6 @@
 # Copyright (c) VIMLab. All rights reserved.

-__version__ = '0.0.17'
+__version__ = '0.0.18'
 short_version = __version__


setup.py

+1 −1

@@ -5,7 +5,7 @@

 setuptools.setup(
     name="fastda",
-    version="0.0.17",
+    version="0.0.18",
     author="Yixin Zhang",
     author_email="[email protected]",
     description="A simple framework for unsupervised domain adaptation",
