Commit

reinitialize code
yikangshen committed Nov 12, 2018
1 parent 7c64238 commit 4b4e515
Showing 28 changed files with 2,156 additions and 1,882 deletions.
29 changes: 29 additions & 0 deletions LICENSE
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2017,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
171 changes: 145 additions & 26 deletions LSTMCell.py
@@ -1,7 +1,8 @@
-import torch
-import torch.nn as nn
import torch.nn.functional as F
-from torch.nn.modules.rnn import RNNCellBase
import torch.nn as nn
import torch

from locked_dropout import LockedDropout


class LayerNorm(nn.Module):
@@ -18,48 +19,166 @@ def forward(self, x):
        return self.gamma * (x - mean) / (std + self.eps) + self.beta


-class LSTMCell(RNNCellBase):
class LinearDropConnect(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, dropout=0.):
        super(LinearDropConnect, self).__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias
        )
        self.dropout = dropout

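    # Sample a fresh DropConnect mask: zero each weight independently with
    # probability self.dropout and cache the masked copy in self._weight.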
    def sample_mask(self):
        if self.dropout == 0.:
            self._weight = self.weight
        else:
            mask = self.weight.new_empty(
                self.weight.size(),
                dtype=torch.uint8
            )
            mask.bernoulli_(self.dropout)
            self._weight = self.weight.masked_fill(mask, 0.)

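    # Training uses the cached (optionally resampled) masked weights; at eval
    # time the full weights are scaled by the keep probability 1 - dropout.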
    def forward(self, input, sample_mask=False):
        if self.training:
            if sample_mask:
                self.sample_mask()
            return F.linear(input, self._weight, self.bias)
        else:
            return F.linear(input, self.weight * (1 - self.dropout),
                            self.bias)


-    def __init__(self, input_size, hidden_size, dropout=0):
def cumsoftmax(x, dim=-1):
    # Cumulative softmax ("cumax"): monotonically non-decreasing from ~0 to 1
    # along dim; e.g. four equal logits give [0.25, 0.50, 0.75, 1.00].
    return torch.cumsum(F.softmax(x, dim=dim), dim=dim)


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, chunk_size, dropconnect=0.):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.n_chunk = int(hidden_size / chunk_size)

-        self.ih = nn.Sequential(nn.Linear(input_size, 3 * hidden_size, bias=True), LayerNorm(3 * hidden_size))
-        self.hh = nn.Sequential(nn.Linear(hidden_size, 3 * hidden_size, bias=True), LayerNorm(3 * hidden_size))
        self.ih = nn.Sequential(
            nn.Linear(input_size, 4 * hidden_size + self.n_chunk * 2, bias=True),
            # LayerNorm(3 * hidden_size)
        )
        self.hh = LinearDropConnect(hidden_size, hidden_size * 4 + self.n_chunk * 2, bias=True, dropout=dropconnect)

-        self.c_norm = LayerNorm(hidden_size)
-        self.drop = nn.Dropout(dropout)
        # self.c_norm = LayerNorm(hidden_size)

-        self.dst = nn.Sequential(nn.Linear(hidden_size + input_size, hidden_size),
-                                 # LayerNorm(1),
-                                 nn.Softmax(dim=-1))
        self.drop_weight_modules = [self.hh]

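    # In the new signature below, `transformed_input` lets a caller batch
    # self.ih(input) over a whole sequence (see LSTMStack.forward) instead of
    # projecting at every timestep.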
-    def forward(self, input, hidden, rmask):
    def forward(self, input, hidden,
                transformed_input=None):
        hx, cx = hidden

-        input = self.drop(input)
-        hx = hx * rmask
-        gates = self.ih(input) + self.hh(hx)  # + self.bias
        if transformed_input is None:
            transformed_input = self.ih(input)

        gates = transformed_input + self.hh(hx)
        # First 2 * n_chunk units are the chunk-level "master" gates; the rest
        # are the four standard LSTM gates, reshaped to (batch, n_chunk, chunk_size).
        cingate, cforgetgate = gates[:, :self.n_chunk * 2].chunk(2, 1)
        outgate, cell, ingate, forgetgate = gates[:, self.n_chunk * 2:].view(-1, self.n_chunk * 4, self.chunk_size).chunk(4, 1)

-        cell, ingate, outgate = gates.chunk(3, 1)
        # Master gates via cumulative softmax: cforgetgate is non-decreasing
        # across chunks, cingate non-increasing.
        cingate = 1. - cumsoftmax(cingate)
        cforgetgate = cumsoftmax(cforgetgate)

-        dst = self.dst(torch.cat([input, hx], dim=-1))
-        fgate = torch.cumsum(dst, dim=-1)
        # Expected master-gate boundary positions, exposed as per-step distances.
        distance_cforget = 1. - cforgetgate.sum(dim=-1) / self.n_chunk
        distance_cin = cingate.sum(dim=-1) / self.n_chunk

-        distance = fgate.sum(dim=-1) / self.hidden_size
        # Broadcast the chunk-level master gates over neurons within each chunk.
        cingate = cingate[:, :, None]
        cforgetgate = cforgetgate[:, :, None]

-        ingate = F.sigmoid(ingate) * fgate
-        forgetgate = (1 - ingate)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cell = F.tanh(cell)
        outgate = F.sigmoid(outgate)

        # cy = cforgetgate * forgetgate * cx + cingate * ingate * cell

        # Inside the overlap of the two master gates the standard LSTM gates
        # operate; outside it, the master gates alone decide erasing / writing.
        overlap = cforgetgate * cingate
        forgetgate = forgetgate * overlap + (cforgetgate - overlap)
        ingate = ingate * overlap + (cingate - overlap)
        cy = forgetgate * cx + ingate * cell
-        hy = outgate * F.tanh(self.c_norm(cy))

-        return hy, cy, distance
        # hy = outgate * F.tanh(self.c_norm(cy))
        hy = outgate * F.tanh(cy)
        return hy.view(-1, self.hidden_size), cy, (distance_cforget, distance_cin)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
-        return weight.new(bsz, self.hidden_size).zero_(), \
-               weight.new(bsz, self.hidden_size).zero_()
        return (weight.new(bsz, self.hidden_size).zero_(),
                weight.new(bsz, self.n_chunk, self.chunk_size).zero_())

    def sample_masks(self):
        for m in self.drop_weight_modules:
            m.sample_mask()


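# Stack of LSTMCells: applies LockedDropout between layers and collects the
# per-layer outputs plus both distance streams for every timestep.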
class LSTMStack(nn.Module):
    def __init__(self, layer_sizes, chunk_size, dropout=0., dropconnect=0.):
        super(LSTMStack, self).__init__()
        self.cells = nn.ModuleList([LSTMCell(layer_sizes[i],
                                             layer_sizes[i + 1],
                                             chunk_size,
                                             dropconnect=dropconnect)
                                    for i in range(len(layer_sizes) - 1)])
        self.lockdrop = LockedDropout()
        self.dropout = dropout
        self.sizes = layer_sizes

    def init_hidden(self, bsz):
        return [c.init_hidden(bsz) for c in self.cells]

    def forward(self, input, hidden):
        length, batch_size, _ = input.size()

        if self.training:
            # One DropConnect mask per forward pass, shared across timesteps.
            for c in self.cells:
                c.sample_masks()

        prev_state = list(hidden)
        prev_layer = input

        raw_outputs = []
        outputs = []
        distances_forget = []
        distances_in = []
        for l in range(len(self.cells)):
            curr_layer = [None] * length
            dist = [None] * length
            # Precompute the input projection for the whole sequence at once.
            t_input = self.cells[l].ih(prev_layer)

            for t in range(length):
                hidden, cell, d = self.cells[l](
                    None, prev_state[l],
                    transformed_input=t_input[t]
                )
                prev_state[l] = hidden, cell  # overwritten every timestep
                curr_layer[t] = hidden
                dist[t] = d

            prev_layer = torch.stack(curr_layer)
            dist_cforget, dist_cin = zip(*dist)
            dist_layer_cforget = torch.stack(dist_cforget)
            dist_layer_cin = torch.stack(dist_cin)
            raw_outputs.append(prev_layer)
            if l < len(self.cells) - 1:
                prev_layer = self.lockdrop(prev_layer, self.dropout)
            outputs.append(prev_layer)
            distances_forget.append(dist_layer_cforget)
            distances_in.append(dist_layer_cin)
        output = prev_layer

        return output, prev_state, raw_outputs, outputs, (torch.stack(distances_forget), torch.stack(distances_in))


if __name__ == "__main__":
    x = torch.Tensor(10, 10, 10)
    x.data.normal_()
    # chunk_size is required and must divide hidden_size (10); 2 is an
    # arbitrary smoke-test value. The class defined above is LSTMStack.
    lstm = LSTMStack([10, 10, 10], chunk_size=2)
    print(lstm(x, lstm.init_hidden(10))[1])

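For reference, the new cell can also be exercised on its own. A minimal
smoke-test sketch (all sizes are arbitrary choices for illustration, and it
assumes the file is importable as LSTMCell):

    import torch
    from LSTMCell import LSTMCell

    cell = LSTMCell(input_size=8, hidden_size=12, chunk_size=3)  # n_chunk = 4
    h0, c0 = cell.init_hidden(bsz=5)    # h0: (5, 12), c0: (5, 4, 3)
    cell.sample_masks()                 # cache DropConnect weights (train mode)
    x = torch.randn(5, 8)
    hy, cy, (d_forget, d_in) = cell(x, (h0, c0))
    print(hy.shape, cy.shape, d_forget.shape)  # (5, 12), (5, 4, 3), (5,)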
63 changes: 0 additions & 63 deletions LSTMCell_new.py

This file was deleted.

58 changes: 0 additions & 58 deletions LSTMCell_normal.py

This file was deleted.
