Showing 6 changed files with 752 additions and 218 deletions.
@@ -0,0 +1,356 @@
# -*- coding: utf-8 -*-
""" Implement a pyTorch LSTM with hard sigmoid recurrent activation functions.
    Adapted from the non-cuda variant of pyTorch LSTM at
    https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
"""

from __future__ import print_function, division
import math
import torch

from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
import torch.nn.functional as F

class LSTMHardSigmoid(Module):

    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(LSTMHardSigmoid, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        gate_size = 4 * hidden_size

        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.
        Right now, this is a no-op since we don't use CUDA acceleration.
        """
        self._data_ptrs = []

    def _apply(self, fn):
        ret = super(LSTMHardSigmoid, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes = input
            max_batch_size = batch_sizes[0]
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.autograd.Variable(input.data.new(self.num_layers *
                                                        num_directions,
                                                        max_batch_size,
                                                        self.hidden_size).zero_(), requires_grad=False)
            hx = (hx, hx)

        has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
        if has_flat_weights:
            first_data = next(self.parameters()).data
            assert first_data.storage().size() == self._param_buf_size
            flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
        else:
            flat_weight = None
        func = AutogradRNN(
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state,
            flat_weight=flat_weight
        )
        output, hidden = func(input, self.all_weights, hx)
        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def __setstate__(self, d):
        super(LSTMHardSigmoid, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]

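# Builds the forward function for the whole network: a hard-sigmoid LSTMCell wrapped in
# Recurrent (or the PackedSequence-aware variable-length variant) for each direction,
# then stacked across layers by StackedRNN.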
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):

    cell = LSTMCell

    if batch_sizes is None:
        rec_factory = Recurrent
    else:
        rec_factory = variable_recurrent_factory(batch_sizes)

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      True,
                      dropout=dropout,
                      train=train)

    def forward(input, weight, hidden):
        if batch_first and batch_sizes is None:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight)

        if batch_first and batch_sizes is None:
            output = output.transpose(0, 1)

        return output, nexth

    return forward

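# Runs `inner` over every time step of a fixed-length (non-packed) sequence, optionally
# in reverse, and stacks the per-step outputs along the time dimension.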
def Recurrent(inner, reverse=False):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # hack to handle LSTM
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward


def variable_recurrent_factory(batch_sizes):
    def fac(inner, reverse=False):
        if reverse:
            return VariableRecurrentReverse(batch_sizes, inner)
        else:
            return VariableRecurrent(batch_sizes, inner)
    return fac

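# Forward pass over a PackedSequence: the effective batch size can only shrink from one
# time step to the next, so hidden states of finished sequences are set aside and
# reassembled in original batch order once the loop is done.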
def VariableRecurrent(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output

    return forward

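# Reverse-direction counterpart of VariableRecurrent: it walks the packed sequence from
# the last time step backwards, pulling rows from `initial_hidden` into the working
# hidden state as shorter sequences become active at earlier steps.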
def VariableRecurrentReverse(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        initial_hidden = hidden
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
        for batch_size in reversed(batch_sizes):
            inc = batch_size - last_batch_size
            if inc > 0:
                hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
                               for h, ih in zip(hidden, initial_hidden))
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size:input_offset]
            input_offset -= batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])

        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output

    return forward

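# Stacks the per-direction recurrent functions into a multi-layer network: direction
# outputs are concatenated along the feature dimension, dropout is applied between
# layers, and the final hidden (and, for LSTM, cell) states are stacked across
# layers and directions.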
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):

    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert(len(weight) == total_layers)
        next_hidden = []

        if lstm:
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j

                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)

            input = torch.cat(all_output, input.dim() - 1)

            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size())

        return next_hidden, input

    return forward

def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = hard_sigmoid(ingate)
    forgetgate = hard_sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = hard_sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy

def hard_sigmoid(x):
    """
    Computes element-wise hard sigmoid of x.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    x = (0.2 * x) + 0.5
    x = F.threshold(-x, -1, -1)
    x = F.threshold(-x, 0, 0)
    return x
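
# Example usage (an illustrative sketch; the sizes below are arbitrary, and the input
# follows the default seq-first layout of this module's Variable-based API):
#
#     from torch.autograd import Variable
#
#     rnn = LSTMHardSigmoid(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
#     x = Variable(torch.randn(5, 3, 10))      # (seq_len, batch, input_size)
#     output, (h_n, c_n) = rnn(x)
#     # output: (5, 3, 40)   h_n, c_n: (4, 3, 20)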