diff --git a/ON_LSTM.py b/ON_LSTM.py index 8c259f8..4984015 100644 --- a/ON_LSTM.py +++ b/ON_LSTM.py @@ -179,6 +179,6 @@ def forward(self, input, hidden): if __name__ == "__main__": x = torch.Tensor(10, 10, 10) x.data.normal_() - lstm = LSTMCellStack([10, 10, 10]) + lstm = ONLSTMStack([10, 10, 10], chunk_size=10) print(lstm(x, lstm.init_hidden(10))[1])