@@ -57,14 +57,14 @@ def forward(self, inputs, mask):
 
         # Convert the numbers into embeddings
         inputs = self.embeddings(inputs.to('cpu'))
-        packed = inputs
+        # packed = inputs
 
         # Get the sorted version of inputs as required for pack_padded_sequence
-        # inputs_sorted = torch.index_select(inputs, 0, lens_argsort)
+        inputs_sorted = torch.index_select(inputs, 0, lens_argsort)
 
-        # packed = pack_padded_sequence(inputs_sorted, lens, batch_first=True)
+        packed = pack_padded_sequence(inputs_sorted, lens_sorted, batch_first=True)
 
         output, self.hidden = self.encoder(packed, self.hidden)
-        # output, _ = pad_packed_sequence(output, batch_first=True)
+        output, _ = pad_packed_sequence(output, batch_first=True)
 
         # Restore batch elements to original order
         # output = torch.index_select(output, 0, lens_argsort_argsort)
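For reference, the sort -> pack -> encode -> pad -> unsort pattern this hunk enables looks like the following minimal, self-contained sketch. The shapes, `lens`, and the standalone `encoder` GRU are assumptions for illustration; they stand in for the module's own embeddings and encoder:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, max_len, emb_dim, hidden_size = 4, 7, 8, 16
inputs = torch.randn(batch_size, max_len, emb_dim)   # already-embedded, padded batch
lens = torch.tensor([5, 7, 3, 6])                    # true (unpadded) sequence lengths

# Sort batch elements by length, descending, as pack_padded_sequence requires
lens_sorted, lens_argsort = lens.sort(descending=True)
inputs_sorted = torch.index_select(inputs, 0, lens_argsort)

packed = pack_padded_sequence(inputs_sorted, lens_sorted, batch_first=True)
encoder = nn.GRU(emb_dim, hidden_size, batch_first=True)
output, hidden = encoder(packed)                     # h_0 omitted: defaults to zeros
output, _ = pad_packed_sequence(output, batch_first=True)

# argsort of the argsort is the inverse permutation: restores original batch order
_, lens_argsort_argsort = lens_argsort.sort()
output = torch.index_select(output, 0, lens_argsort_argsort)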
@@ -119,8 +119,7 @@ def forward(self, inputs, mask):
         output = torch.index_select(output, 0, lens_argsort_argsort.to('cpu'))
 
         # Make output contiguous for speed of future operations
-        # TODO: Try without and time to see if this actually speeds up
-        output = output.contiguous()
+        # output = output.contiguous()
 
         output = self.dropout(output)
         return output
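A quick check supporting the `.contiguous()` removal above: `torch.index_select` materializes a fresh tensor, so its result is already contiguous and the extra call is a no-op here. The shapes below are hypothetical:

import torch

x = torch.randn(4, 7, 16)
perm = torch.tensor([2, 0, 3, 1])
y = torch.index_select(x, 0, perm)   # copies into a new, contiguous tensor
print(y.is_contiguous())             # True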
@@ -161,8 +160,6 @@ def __init__(self, device,
                           num_layers=num_layers,
                           bidirectional=bidirectional)
 
-        # self.hidden = self.initHidden()  # for GRU
-
     def forward(self, U, d_mask, target_span):
 
         batch_indices = torch.arange(self.batch_size, out=torch.LongTensor(self.batch_size))
@@ -226,10 +223,6 @@ def forward(self, U, d_mask, target_span):
         loss = cumulative_loss / self.max_dec_steps
         return loss, s_i, e_i
 
-    # def initHidden(self):
-    #     return torch.zeros(self.num_directions * self.num_layers, self.batch_size, self.hidden_size)
-
-
 class CoattentionNetwork(nn.Module):
     def __init__(self, device,
                  hidden_size,
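On the `initHidden()` removals in the two hunks above: PyTorch's recurrent layers substitute a zero tensor for the initial hidden state when none is passed, so the explicit zero initialization was redundant. A minimal check, with assumed sizes:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(4, 7, 8)
out_default, _ = gru(x)      # no h_0 passed: PyTorch uses zeros internally
h0 = torch.zeros(1, 4, 16)   # (num_layers * num_directions, batch, hidden)
out_explicit, _ = gru(x, h0)
assert torch.equal(out_default, out_explicit)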