-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
131 lines (115 loc) · 5.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from util.evaluator import IC15Evaluator
from util import util
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import gc
import warnings
# Ignore every Python warning for the rest of the process, including warnings
# raised inside imported libraries (no category/module filter is given).
# NOTE(review): this also hides DeprecationWarning and similar signals —
# consider narrowing with `category=` / `module=` once the noisy source is known.
warnings.filterwarnings("ignore")
def eval(opt, dataset, model, evaluator=None, select_iou=0.5):
    """Run ``model`` over ``dataset`` and summarise its predictions with ``evaluator``.

    Args:
        opt: parsed options; not read here, kept for call-site compatibility.
        dataset: iterable of input batches accepted by ``model.set_input``.
        model: object exposing ``eval()``, ``set_input(data)`` and ``test()``.
        evaluator: object exposing ``reset()``, ``update(preds)`` and
            ``summary(select_iou=...)`` returning ``(metric_text, score)``.
            Required despite the ``None`` default.
        select_iou: threshold forwarded to ``evaluator.summary``
            (defaults to the previously hard-coded 0.5).

    Returns:
        ``(report, score)`` where ``report`` is an evaluation-time line
        followed by the evaluator's metric text and ``score`` is the second
        value returned by ``evaluator.summary``.

    Raises:
        ValueError: if ``evaluator`` is None.

    Note:
        Calls ``model.eval()`` and never reverses it; a caller that resumes
        training afterwards must restore the model's training state itself.
    """
    if evaluator is None:
        raise ValueError('eval() requires an evaluator')
    n_threads = torch.get_num_threads()
    # Evaluate single-threaded; the caller's setting is restored in `finally`
    # so an exception mid-evaluation cannot leave the process at 1 thread.
    torch.set_num_threads(1)
    try:
        model.eval()
        evaluator.reset()
        eval_start_time = time.time()
        # torch.cuda.synchronize() raises without CUDA, so only sync when present.
        sync_cuda = torch.cuda.is_available()
        for data in dataset:
            if sync_cuda:
                torch.cuda.synchronize()
            model.set_input(data)
            evaluator.update(model.test())
        eval_time = time.time() - eval_start_time
        res = '==>Evaluation time: {:.0f}, \n'.format(eval_time)
        metric, select_score = evaluator.summary(select_iou=select_iou)
        res += metric
    finally:
        torch.set_num_threads(n_threads)
    return res, select_score
def average_running_losses(running_losses, num_batches, batch_size):
    """Average loss sums accumulated over an epoch, for logging.

    Args:
        running_losses: mapping of display name -> sum of the per-batch loss
            values accumulated so far in the current epoch.
        num_batches: number of batches that contributed to each sum.
        batch_size: per-process batch size; each sum is also divided by it,
            keeping the scale of the original logging code.
            NOTE(review): this is only a true per-sample mean if the model's
            losses are summed over the batch — confirm against the model.

    Returns:
        A new dict with the same keys, in the same order, holding
        ``total / (batch_size * num_batches)``.

    Raises:
        ValueError: if ``batch_size * num_batches`` is not positive.
    """
    denom = batch_size * num_batches
    if denom <= 0:
        raise ValueError('num_batches and batch_size must be positive, got %r and %r'
                         % (num_batches, batch_size))
    return {name: total / denom for name, total in running_losses.items()}


if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    util.init_distributed_mode(opt)
    # Fixed seeds so runs are repeatable.
    torch.manual_seed(10)
    if opt.device == 'cuda':
        torch.cuda.manual_seed_all(10)

    # train dataset
    train_dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    train_size = len(train_dataset)  # get the number of images in the dataset.
    print('The number of training images = %d. Trainset: %s' % (train_size, opt.dataroot))

    # Optional validation / test datasets (disabled). To enable validation,
    # uncomment this and the periodic evaluation inside the iteration loop.
    # opt.phase = 'val'
    # val_dataset = create_dataset(opt)
    # val_evaluator = IC15Evaluator(opt)
    # val_size = len(val_dataset)
    # print('The number of test images = %d. Valset: %s' % (val_size, opt.dataroot))
    # opt.phase = 'train'
    # opt.phase = 'test'
    # test_dataset = create_dataset(opt)
    # test_evaluator = IC15Evaluator(opt)
    # test_size = len(test_dataset)
    # print('The number of test images = %d. Testset: %s' % (test_size, opt.dataroot))
    # opt.phase = 'train'

    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_iters = 0
    best_score = 0.0  # only used by the disabled validation code below

    # Display name in the log -> key in model.get_current_losses().
    logged_losses = {'L_row': 'rel_cls_row', 'L_col': 'rel_cls_col', 'L_all': 'rel_all'}

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        running_losses = dict.fromkeys(logged_losses, 0)
        if opt.distributed:
            train_dataset.set_epoch(epoch)
        for i, data in enumerate(train_dataset):
            # NOTE(review): a full GC plus CUDA cache flush on every iteration
            # is expensive; presumably a workaround for memory growth/OOM —
            # confirm before removing.
            gc.collect()
            torch.cuda.empty_cache()
            visualizer.reset()
            total_iters += opt.batch_size * opt.world_size
            epoch_iter += opt.batch_size * opt.world_size
            model.set_input(data)
            model.optimize_parameters()
            losses = model.get_current_losses()
            for name, key in logged_losses.items():
                running_losses[name] += losses[key]
            if epoch_iter % opt.print_freq == 0:
                # `i` is 0-based, so i + 1 batches have been accumulated. The old
                # divisor (batch_size * i) raised ZeroDivisionError when the first
                # log fell on i == 0 and overstated every later average.
                tmp_losses = average_running_losses(running_losses, i + 1, opt.batch_size)
                visualizer.print_current_losses(epoch, epoch_iter, tmp_losses)
            # Periodic checkpoint + validation (disabled; needs val_dataset /
            # val_evaluator from the setup above):
            # if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
            #     print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            #     save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
            #     model.save_networks(save_suffix)
            #     val_res, metric_score = eval(opt, val_dataset, model, val_evaluator)
            #     visualizer.print_current_val(epoch, epoch_iter, val_res)
            #     if metric_score > best_score:
            #         best_score = metric_score
            #         print('saving the best model (epoch %d, total_iters %d)' % (epoch, total_iters))
            #         model.save_networks('best')
            #     best_res = 'current avg score: {:.4f}, best score: {:.6f}'.format(metric_score, best_score)
            #     visualizer.print_current_val(epoch, epoch_iter, best_res)
        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()