-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
101 lines (82 loc) · 3.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import caffe
from model import Net
import torch
import torch.nn as nn
from torch.autograd import Variable
from skimage import color
import time
def to_gray(color_img):
    """Return a greyscale version of *color_img* in CIE Lab space.

    The RGB input is converted to Lab, then the a/b chrominance channels
    are zeroed so only the L (lightness) channel survives.
    """
    lab = color.rgb2lab(color_img)
    greyed = lab.copy()
    # Zero everything past channel 0 (the a and b channels) along axis 2.
    greyed[:, :, 1:] = 0
    return greyed
# --------------------------------------------------------------------------
# Command-line interface.
# FIX: argparse option strings must be passed as separate positional
# arguments; the original wrapped each pair in a tuple, e.g.
# add_argument(('-p', '--data_path'), ...), which raises ValueError on start.
parser = argparse.ArgumentParser()


def _str2bool(value):
    """Parse a boolean CLI value; plain ``type=bool`` treats any non-empty
    string (including "False") as True."""
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('1', 'true', 't', 'yes', 'y')


# Generic
parser.add_argument('-p', '--data_path', type=str, default='',
                    help='input data path')
# Optimization
parser.add_argument('-b', '--batch', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('-e', '--epoch', type=int, default=200, metavar='N',
                    help='the number of training epoches')
parser.add_argument('-c', '--cuda', type=_str2bool, default=True, metavar='N',
                    help='whether to use cuda or not')
parser.add_argument('-lr', '--learning-rate', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('-d', '--learning-rate-decay-factor', type=float, default=0.5,
                    help='learning rate decay factor (default: 0.5)')
# Checkpoint
parser.add_argument('-cp', '--checkpoint-every', type=int, default=200,
                    help='checkpoint frequency')
parser.add_argument('-rc', '--resume-from-checkpoint', type=str, default='',
                    help='checkpoint file to resume from')
args = parser.parse_args()

path = args.data_path
batch_size = args.batch
num_epoches = args.epoch
use_cuda = args.cuda
lr = args.learning_rate
lr_decay = args.learning_rate_decay_factor  # parsed but not yet wired to a scheduler
checkpoint_every = args.checkpoint_every
checkpoint_file = args.resume_from_checkpoint

# Resume the whole model object from a checkpoint when one was given,
# otherwise build a fresh network.
if checkpoint_file != '':
    print('Loading checkpoint from %s' % checkpoint_file)
    model = torch.load(checkpoint_file)
else:
    print('Initializing model from scratch')
    model = Net(use_cuda)

criterion = nn.MSELoss()
# FIX: the optimizer was left as None, so optimizer.zero_grad()/step() in the
# training loop raised AttributeError.  Plain SGD with the configured rate.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
if __name__ == "__main__":
    # Load the image stack as a numpy array and build its greyscale pair.
    # NOTE(review): the code treats axis 0 as the image index, i.e. it
    # assumes a stacked (N, H, W, C) array — confirm against the data file.
    color_img = caffe.io.load_image(path)
    grey_img = to_gray(color_img)
    # FIX: convert to torch tensors once up front — the original indexed the
    # numpy arrays with a torch LongTensor (randperm) and passed raw numpy
    # to the model, neither of which is supported.
    grey_img = torch.from_numpy(grey_img).float()
    color_img = torch.from_numpy(color_img).float()
    num_imgs = grey_img.shape[0]
    num_batches = num_imgs // batch_size
    loss_list = []  # mean training loss per epoch
    for i in range(num_epoches):
        # Random shuffle every epoch; the same permutation is applied to both
        # stacks so grey/color pairs stay aligned.
        perm = torch.randperm(num_imgs)
        grey_img, color_img = grey_img[perm], color_img[perm]
        loss_sum = 0.0
        for j in range(num_batches):
            start, end = j * batch_size, (j + 1) * batch_size
            grey_batch = Variable(grey_img[start:end])
            color_batch = Variable(color_img[start:end])
            output_batch = model(grey_batch)
            optimizer.zero_grad()
            err = criterion(output_batch, color_batch)
            # FIX: err.data[0] was removed in modern PyTorch; item() extracts
            # the Python float from the 0-dim loss tensor.
            loss_sum += err.item()
            err.backward()
            optimizer.step()
            t = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            print("[%s] Epoch %d, Batch %d, Loss = %lf" %
                  (t, i + 1, j + 1, err.item()))
        # FIX: guard against num_imgs < batch_size (num_batches == 0), which
        # previously raised ZeroDivisionError.
        loss_sum /= max(num_batches, 1)
        loss_list.append(loss_sum)
        # FIX: checkpoint on the (i+1)-th completed epoch; the original
        # `i % checkpoint_every == 0` saved only after epoch 1, so with the
        # defaults (200 epochs, every 200) the final model was never saved.
        if (i + 1) % checkpoint_every == 0:
            torch.save(model, 'TrainedModel')
    print(loss_list)