main.py
# -*- coding:utf-8 -*-
from __future__ import absolute_import, division, print_function
from codecs import open
import torch
from preprocess import process_poems, start_token, pos2PE
import pdb
import tqdm
import numpy as np
import argparse
import sys
import os
import random
torch.manual_seed(0)
def sequence_collate(batch):
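    """Collate (input, target) index-sequence pairs into two PackedSequences.

    sorted() is stable and each input/target pair shares one length, so
    sorting inputs and targets independently by length applies the same
    permutation to both and the pairs stay aligned.
    """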
    transposed = zip(*batch)
    ret = [torch.nn.utils.rnn.pack_sequence(sorted(samples, key=len, reverse=True))
           for samples in transposed]
    return ret

def prob_sample(w_list, topn=10):
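    """Top-n sampling: for each probability row in w_list, draw one index from
    the topn most probable entries, proportionally to their weights.
    """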
    samples = []
    for weights in w_list:
        idx = np.argsort(weights)[::-1]
        t = np.cumsum(weights[idx[:topn]])
        s = np.sum(weights[idx[:topn]])
        sample = int(np.searchsorted(t, np.random.rand() * s))
        samples.append(idx[sample])
    return np.array(samples)

def infer(model, final, words, word2int, emb, hidden_size=256, start=u'春', n=1, num=5):
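    """Sample n poems of four lines with num characters per line.

    Successive characters of `start` seed the first character of successive
    lines (an acrostic when len(start) > 1). Assumes, as in training, that
    pos2PE(k) from preprocess.py returns the positional encoding for the k-th
    position within a line and that `emb` is the pretrained word-embedding
    matrix indexed by word2int.
    """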
    device = final.weight.device
    h = torch.zeros((1, n, hidden_size))
    x = torch.nn.functional.embedding(torch.full((n,), word2int[start[0]], dtype=torch.long), emb).unsqueeze(0)
    ret = [[start[0]] for _ in range(n)]
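    # the first character is given, so autoregressively generate the
    # remaining num * 4 - 1 characters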
    for i in range(num * 4 - 1):
        # add PE dims
        pe = torch.tensor(pos2PE((i % num) + 1), dtype=torch.float).repeat(1, n, 1)
        x, h, pe = x.to(device), h.to(device), pe.to(device)
        x = torch.cat((x, pe), dim=2)
        x, h = model(x, h)
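        # at a line boundary, force the next character from `start` (acrostic);
        # otherwise sample from the output distribution over the vocabulary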
        if i % num == num - 1 and i // num + 1 < len(start):
            w = np.array([word2int[start[i // num + 1]] for _ in range(n)])
        else:
            w = prob_sample(torch.nn.functional.softmax(final(x.view(-1, hidden_size)), dim=-1).data.cpu().numpy())
        x = torch.nn.functional.embedding(torch.from_numpy(w), emb).unsqueeze(0)
        for j in range(len(w)):
            ret[j].append(words[w[j]])
            if i % num == num - 2:
                ret[j].append(u"，" if i // num % 2 == 0 else u"。")
    return [u"".join(chars) for chars in ret]

def main(epoch=10, batch_size=4, hidden_size=256, save_dir='./model', save_name='current.pth'):
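    """Train the GRU language model on the poem corpus and checkpoint it after
    every epoch. Assumes process_poems() from preprocess.py returns a dataset
    of index sequences together with the vocabulary list and word-to-index
    map, and that the dataset exposes emb, emb_dim and voc_size.
    """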
    dataset, words, word2int = process_poems('./data/poems.txt', './data/sgns.sikuquanshu.word')
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=sequence_collate)
    model = torch.nn.GRU(input_size=dataset.emb_dim, hidden_size=hidden_size)
    final = torch.nn.Linear(hidden_size, dataset.voc_size, bias=False)
    opt = torch.optim.Adam(list(model.parameters()) + list(final.parameters()))
    if torch.cuda.is_available():
        model, final = model.cuda(), final.cuda()
    for ep in range(epoch):
        data_iter = tqdm.tqdm(enumerate(loader),
                              desc="EP_%d" % ep,
                              total=len(loader),
                              bar_format="{l_bar}{r_bar}")
        for i, data in data_iter:
            data, label = data
            # pdb.set_trace()
            if torch.cuda.is_available():
                data, label = data.cuda(), label.cuda()
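            # data and label are PackedSequences; their .data attribute is the
            # flattened (total_timesteps, ...) tensor, giving one prediction
            # and one target per packed timestep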
            pred, _ = model(data)
            pred = final(pred.data)
            loss = torch.nn.functional.cross_entropy(pred, label.data)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if i % 100 == 0:
                post_fix = {
                    "epoch": ep,
                    "iter": i,
                    "loss": loss.item(),
                    "example": infer(model, final, words, word2int, dataset.emb, hidden_size=hidden_size, num=5 if random.random() < 0.5 else 7)
                }
                if sys.version_info.major == 2:
                    data_iter.write(unicode(post_fix))
                else:
                    data_iter.write(str(post_fix))
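        # after each epoch, print a few five- and seven-character samples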
        for sample_num in (5, 7):
            tmp_infer_rst = u"\n".join(infer(model, final, words, word2int, dataset.emb, hidden_size=hidden_size, n=5, num=sample_num))
            if sys.version_info.major == 2:
                tmp_infer_rst = tmp_infer_rst.encode('utf-8')
            print(tmp_infer_rst)
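        # torch.save() pickles the modules; .cpu() moves them in place first,
        # so move them back afterwards to keep training on the GPU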
        print('Saving...')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        torch.save({
            'model': model.cpu(),
            'final': final.cpu(),
            'words': words,
            'word2int': word2int,
            'emb': dataset.emb
        }, os.path.join(save_dir, save_name))
        if torch.cuda.is_available():
            model, final = model.cuda(), final.cuda()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--epoch", type=int, default=10, help="number of training epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("-hs", "--hidden_size", type=int, default=256, help="hidden size of the RNN")
    parser.add_argument('-d', "--save_dir", type=str, default='./model', help='directory to save files in')
    parser.add_argument('-n', "--name", type=str, default='current.pth', help='checkpoint file name')
    args = parser.parse_args()
    print(args)
    main(epoch=args.epoch, batch_size=args.batch_size, hidden_size=args.hidden_size, save_dir=args.save_dir, save_name=args.name)
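
# Example invocation (paths are the defaults assumed above; the poem corpus
# and the sikuquanshu word vectors must already exist under ./data):
#   python main.py -e 10 -b 4 -hs 256 -d ./model -n current.pth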