# Models.py
# Forked from SamLynnEvans/Transformer.
import torch
import torch.nn as nn
from Layers import EncoderLayer, DecoderLayer
from Embed import Embedder, PositionalEncoder
from Sublayers import Norm
import torch.nn.functional as F
import copy
def get_clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    Each entry is produced by ``copy.deepcopy``, so the copies do not share
    parameter objects with ``module`` or with one another.

    Args:
        module: The module to replicate.
        N: How many copies to create; ``0`` gives an empty list.

    Returns:
        A ``ModuleList`` of length ``N``.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class Encoder(nn.Module):
    """Encoder stack: embedding, positional encoding, N layers, final norm.

    Args:
        vocab_size: Vocabulary size handed to ``Embedder``.
        d_model: Model width handed to every sub-module built here.
        N: Number of ``EncoderLayer`` copies in the stack.
        heads: Head count handed to each ``EncoderLayer``.
        dropout: Dropout setting handed to ``PositionalEncoder`` and to each
            ``EncoderLayer``.
    """

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        # Kept as a public attribute for code that reads it; forward() walks
        # self.layers directly, which always holds exactly N entries here.
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, src, mask):
        """Run ``src`` through the whole encoder stack.

        Args:
            src: Source batch given to the embedder (presumably integer token
                ids -- confirm against ``Embedder``).
            mask: Passed unchanged to every encoder layer.

        Returns:
            The output of the last layer after the final ``Norm``.
        """
        x = self.pe(self.embed(src))
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
class Decoder(nn.Module):
    """Decoder stack: embedding, positional encoding, N layers, final norm.

    Args:
        vocab_size: Vocabulary size handed to ``Embedder``.
        d_model: Model width handed to every sub-module built here.
        N: Number of ``DecoderLayer`` copies in the stack.
        heads: Head count handed to each ``DecoderLayer``.
        dropout: Dropout setting handed to ``PositionalEncoder`` and to each
            ``DecoderLayer``.
    """

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        # Kept as a public attribute for code that reads it; forward() walks
        # self.layers directly, which always holds exactly N entries here.
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        """Run ``trg`` through the whole decoder stack.

        Args:
            trg: Target batch given to the embedder (presumably integer token
                ids -- confirm against ``Embedder``).
            e_outputs: Encoder output, passed unchanged to every layer.
            src_mask: Passed unchanged to every decoder layer.
            trg_mask: Passed unchanged to every decoder layer.

        Returns:
            The output of the last layer after the final ``Norm``.
        """
        x = self.pe(self.embed(trg))
        for layer in self.layers:
            x = layer(x, e_outputs, src_mask, trg_mask)
        return self.norm(x)
class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection and log-softmax output.

    Args:
        src_vocab_size: Vocabulary size given to the ``Encoder``.
        trg_vocab_size: Vocabulary size given to the ``Decoder``; also the
            output width of the final linear layer.
        d_model: Model width shared by encoder, decoder and the projection.
        N: Number of layers in both the encoder and the decoder.
        heads: Head count given to both the encoder and the decoder.
        dropout: Dropout setting given to both the encoder and the decoder.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab_size, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab_size)

    def forward(self, src, trg, src_mask, trg_mask):
        """Encode ``src``, decode ``trg`` against it, and score the vocabulary.

        Args:
            src: Source batch given to the encoder.
            trg: Target batch given to the decoder.
            src_mask: Given to the encoder and to the decoder.
            trg_mask: Given to the decoder only.

        Returns:
            Log-probabilities: ``log_softmax`` over the last dimension of the
            projected decoder output, so callers receive log-probabilities
            rather than raw logits.
        """
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        logits = self.out(d_output)
        return F.log_softmax(logits, dim=-1)
def init_model(opt, src_vocab_size, trg_vocab_size, checkpoint=None):
    """Build a ``Transformer`` on ``opt.device`` and initialise its weights.

    Args:
        opt: Options object; the attributes read here are ``d_model``,
            ``heads``, ``dropout``, ``n_layers`` and ``device``.
        src_vocab_size: Source vocabulary size given to ``Transformer``.
        trg_vocab_size: Target vocabulary size given to ``Transformer``.
        checkpoint: Optional mapping whose ``'model'`` entry is a state dict
            to load. When ``None``, every parameter with more than one
            dimension is initialised with Xavier-uniform instead.

    Returns:
        The constructed model, already moved to ``opt.device``.

    Raises:
        AssertionError: If ``opt.d_model`` is not divisible by ``opt.heads``
            or ``opt.dropout`` is not below 1.
    """
    # Raised explicitly instead of via `assert` so the checks still run under
    # `python -O`; the exception type is unchanged for existing callers.
    if opt.d_model % opt.heads != 0:
        raise AssertionError(
            f"d_model ({opt.d_model}) must be divisible by heads ({opt.heads})"
        )
    if not (opt.dropout < 1):
        raise AssertionError(f"dropout ({opt.dropout}) must be less than 1")

    model = Transformer(
        src_vocab_size,
        trg_vocab_size,
        opt.d_model,
        opt.n_layers,
        opt.heads,
        opt.dropout,
    ).to(opt.device)

    if checkpoint is not None:
        print('load weight ...')
        model.load_state_dict(checkpoint['model'])
        return model

    # Fresh model: only tensors with more than one dimension are
    # re-initialised; 1-D parameters keep their constructor defaults.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model