@@ -33,9 +33,14 @@ def __init__(self, emb_matrix,
         self.batch_size = batch_size
         self.hidden_size = hidden_size

-        self.encoder = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, bidirectional=bidirectional,
-                              num_layers=num_layers,
-                              batch_first=True)
+        # self.encoder = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, bidirectional=bidirectional,
+        #                       num_layers=num_layers,
+        #                       batch_first=True)
+
+        self.encoder = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, bidirectional=bidirectional,
+                               num_layers=num_layers,
+                               batch_first=True)
+
         self.hidden = self.init_hidden()

         self.sentinel = nn.Parameter(torch.rand(hidden_size,))
@@ -70,7 +75,10 @@ def forward(self, inputs, mask):
         return output

     def init_hidden(self):
-        return torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size)
+        # return torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size)
+
+        return (torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size),
+                torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size))


 # TODO: Takes input and produces output of the same dimension as our reference implementation
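
Why init_hidden changes along with the encoder: nn.GRU carries its recurrent state as a single tensor, while nn.LSTM carries a (hidden, cell) pair, so the zero-initialized state must become a tuple. A minimal standalone sketch of the two signatures (sizes are illustrative, not the repository's defaults):

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, num_layers = 2, 5, 8, 1
num_directions = 2  # bidirectional=True

x = torch.rand(batch_size, seq_len, hidden_size)  # batch_first layout
h0 = torch.zeros(num_directions * num_layers, batch_size, hidden_size)

gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size,
             num_layers=num_layers, bidirectional=True, batch_first=True)
gru_out, h_n = gru(x, h0)  # state in and out: one tensor

lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
               num_layers=num_layers, bidirectional=True, batch_first=True)
c0 = torch.zeros_like(h0)
lstm_out, (h_n, c_n) = lstm(x, (h0, c0))  # state in and out: (h, c) tuple
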
@@ -86,9 +94,14 @@ def __init__(self, dropout_rate,
         self.batch_size = batch_size
         self.hidden_size = hidden_size

-        self.fusion_bilstm = nn.GRU(num_layers=num_layers, input_size=hidden_size * 3, hidden_size=hidden_size,
-                                    batch_first=True,
-                                    bidirectional=True)
+        # self.fusion_bilstm = nn.GRU(num_layers=num_layers, input_size=hidden_size * 3, hidden_size=hidden_size,
+        #                             batch_first=True,
+        #                             bidirectional=True)
+
+        self.fusion_bilstm = nn.LSTM(num_layers=num_layers, input_size=hidden_size * 3, hidden_size=hidden_size,
+                                     batch_first=True,
+                                     bidirectional=True)
+
         self.hidden = self.init_hidden()
         self.dropout = nn.Dropout(p=dropout_rate)
@@ -122,7 +135,10 @@ def forward(self, inputs, mask):
         return output

     def init_hidden(self):
-        return torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size)
+        # return torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size)
+
+        return (torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size),
+                torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size))

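
A quick shape check for the fusion BiLSTM (a standalone illustration with made-up sizes, not the repository's code): it consumes the 3l-dimensional fused document representations, and because it is bidirectional its per-token outputs are 2l-dimensional, while each direction's hidden and cell states stay l-dimensional.

import torch
import torch.nn as nn

l, batch_size, doc_len = 8, 2, 7
fusion = nn.LSTM(input_size=3 * l, hidden_size=l, num_layers=1,
                 batch_first=True, bidirectional=True)

inputs = torch.rand(batch_size, doc_len, 3 * l)
state = (torch.zeros(2, batch_size, l),  # num_directions * num_layers = 2
         torch.zeros(2, batch_size, l))
output, (h_n, c_n) = fusion(inputs, state)
print(output.shape)  # (2, 7, 16): batch x doc_len x 2l
print(h_n.shape)     # (2, 2, 8): directions x batch x l
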
class DynamicDecoder(nn.Module):
@@ -194,10 +210,10 @@ def forward(self, U, d_mask, target_span):

             # Get hidden state
             # TODO: There could be a problem with the dimension
-            h_i = self.gru(u_cat.unsqueeze(1), h_i)[1]
+            output, h_i = self.gru(u_cat.unsqueeze(1), h_i)

             # Get new start estimate and start loss
-            s_i, _, start_loss_i = self.start_hmn(h_i, U, None, s_i, u_cat, None, s_target)
+            s_i, _, start_loss_i = self.start_hmn(output, U, None, s_i, u_cat, None, s_target)
             # s_i, start_loss_i = self.start_hmn(h_i, U, u_cat, s_target)

             # Update embedding at start estimate
@@ -207,7 +223,7 @@ def forward(self, U, d_mask, target_span):
             u_cat = torch.cat((u_s_i, u_e_i), 1)  # batch_size x 4l

             # Get new end estimate and end loss
-            e_i, _, end_loss_i = self.end_hmn(h_i, U, None, e_i, u_cat, None, e_target)
+            e_i, _, end_loss_i = self.end_hmn(output, U, None, e_i, u_cat, None, e_target)
             # e_i, end_loss_i = self.end_hmn(h_i, U, u_cat, e_target)

             # Update cumulative loss if computing loss
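
The two edits above feed the HMNs the GRU's per-step output rather than its hidden state h_i. Assuming the decoder GRU is unidirectional and single-layer (consistent with the one-step unsqueeze(1) call), the two tensors hold identical values but in different layouts, so the change is about shape rather than content. A sketch under those assumptions, with illustrative sizes:

import torch
import torch.nn as nn

batch_size, l = 2, 8
gru = nn.GRU(input_size=4 * l, hidden_size=l, batch_first=True)

u_cat = torch.rand(batch_size, 4 * l)  # batch_size x 4l, as in the hunk above
h_i = torch.zeros(1, batch_size, l)    # num_layers * num_directions = 1

output, h_i = gru(u_cat.unsqueeze(1), h_i)
print(output.shape)  # (2, 1, 8): batch x seq_len x l
print(h_i.shape)     # (1, 2, 8): layers*directions x batch x l
print(torch.allclose(output.squeeze(1), h_i.squeeze(0)))  # True
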
@@ -220,6 +236,7 @@ def forward(self, U, d_mask, target_span):
         loss = cumulative_loss / self.max_dec_steps
         return loss, s_i, e_i

+
 class CoattentionNetwork(nn.Module):
     def __init__(self, device,
                  hidden_size,