# train.py
import torch


def train(model, dataloader, criterion, optimizer, opt,
          training_epoch=20, vis=None, name='None'):
    """Train `model` for `training_epoch` epochs and report the mean loss per epoch."""
    for epoch in range(training_epoch):
        loss_epoch = 0
        step = 0
        for step, (data, label) in enumerate(dataloader, 1):
            if opt.cuda is not None:
                data = data.cuda()
                label = label.cuda()
            out = model(data)
            # Flatten (batch, seq_len, vocab) logits and (batch, seq_len) targets
            # so the criterion scores one prediction per token.
            out = out.view(-1, out.size(-1))
            label = label.view(-1)
            loss = criterion(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
        if vis is not None:
            vis.update(loss_epoch / step)
        print("{0} Loss {1}".format(name, loss_epoch / step))

def get_loss(model, dataloader, criterion, vis=None):
    """Return the summed loss of `model` over `dataloader` without updating weights."""
    loss_epoch = 0
    step = 0
    with torch.no_grad():  # evaluation only, so skip gradient bookkeeping
        for step, (data, label) in enumerate(dataloader, 1):
            if model.use_cuda:
                data = data.cuda()
                label = label.cuda()
            out = model(data)
            out = out.view(-1, out.size(-1))
            label = label.view(-1)
            loss = criterion(out, label)
            loss_epoch += loss.item()
    if vis is not None:
        vis.update(loss_epoch / step)
    return loss_epoch

def sample_data(model, save_path='./real_data.txt',
                sample_num=1000, batch_size=100, seq_len=20):
    """Draw `sample_num` sequences from `model` and write them to `save_path`,
    one space-separated token sequence per line."""
    with open(save_path, 'w') as f:
        for _ in range(sample_num // batch_size):
            out = model.sample(batch_size, seq_len).tolist()
            for sample in out:
                f.write("%s\n" % ' '.join(map(str, sample)))