Commit f9500d3
Nicholas Leonard committed Feb 23, 2017
1 parent: ef98a97
Showing 10 changed files with 791 additions and 37 deletions.
CMakeLists.txt

```diff
@@ -1,15 +1,61 @@
 SET(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR})

 CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
 CMAKE_POLICY(VERSION 2.6)
 IF(LUAROCKS_PREFIX)
   MESSAGE(STATUS "Installing Torch through Luarocks")
   STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}")
   MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}")
 ENDIF()

 FIND_PACKAGE(Torch REQUIRED)

-SET(src)
-FILE(GLOB luasrc *.lua)
-SET(luasrc ${luasrc})
-ADD_SUBDIRECTORY(test)
-ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "Recurrent Neural Networks")
+SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE
+
+SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
+SET(src
+  init.c
+)
+SET(luasrc
+  init.lua
+  AbstractRecurrent.lua
+  AbstractSequencer.lua
+  BiSequencer.lua
+  BiSequencerLM.lua
+  CopyGrad.lua
+  Dropout.lua
+  ExpandAs.lua
+  FastLSTM.lua
+  GRU.lua
+  LinearNoBias.lua
+  LookupTableMaskZero.lua
+  LSTM.lua
+  MaskZero.lua
+  MaskZeroCriterion.lua
+  Module.lua
+  Mufuru.lua
+  NormStabilizer.lua
+  Padding.lua
+  Recurrence.lua
+  Recurrent.lua
+  RecurrentAttention.lua
+  recursiveUtils.lua
+  Recursor.lua
+  Repeater.lua
+  RepeaterCriterion.lua
+  SAdd.lua
+  SeqBRNN.lua
+  SeqGRU.lua
+  SeqLSTM.lua
+  SeqLSTMP.lua
+  SeqReverseSequence.lua
+  Sequencer.lua
+  SequencerCriterion.lua
+  TrimZero.lua
+  ZeroGrad.lua
+  test/bigtest.lua
+  test/test.lua
+)
+
+ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
+
+TARGET_LINK_LIBRARIES(rnn luaT TH)
+
+SET_TARGET_PROPERTIES(rnn_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH")
+
+INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/rnn")
```
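Once the rock is built and installed, the files in `luasrc` land under `${Torch_INSTALL_LUA_PATH_SUBDIR}/rnn`, so the package loads with a plain `require`. A minimal smoke test, assuming a working Torch install (nothing below is part of the commit itself):

```lua
-- Minimal sketch: verify the installed rnn package exposes the modules
-- listed in luasrc above. Assumes Torch and this rock are installed.
require 'rnn'

-- Each file in luasrc registers a class in the nn namespace, e.g.:
assert(nn.LSTM and nn.GRU and nn.Sequencer, "rnn package did not load its modules")
print('rnn package loaded OK')
```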
VariableLength.lua (new file)

```lua
local VariableLength, parent = torch.class("nn.VariableLength", "nn.Decorator")

-- make sure your module has been set up for zero-masking (that is, module:maskZero())
function VariableLength:__init(module, lastOnly)
   parent.__init(self, module)
   -- only extract the last element of each sequence
   self.lastOnly = lastOnly -- defaults to false
end

-- recursively masks input (in place)
function VariableLength.recursiveMask(input, mask)
   if torch.type(input) == 'table' then
      for k,v in ipairs(input) do
         -- no self here: this is a dot-declared (static) function
         VariableLength.recursiveMask(v, mask)
      end
   else
      assert(torch.isTensor(input))

      -- make sure mask has the same dimension as the input tensor
      assert(mask:dim() == 2, "Expecting batchsize x seqlen mask tensor")
      -- expand mask to input (if necessary)
      local zeroMask
      if input:dim() == 2 then
         zeroMask = mask
      elseif input:dim() > 2 then
         local inputSize = input:size():fill(1)
         inputSize[1] = input:size(1)
         inputSize[2] = input:size(2)
         mask:resize(inputSize)
         zeroMask = mask:expandAs(input)
      else
         error("Expecting batchsize x seqlen [ x ...] input tensor")
      end
      -- zero-mask input in between sequences
      input:maskedFill(zeroMask, 0)
   end
end

function VariableLength:updateOutput(input)
   -- input is a table of batchSize tensors
   assert(torch.type(input) == 'table')
   assert(torch.isTensor(input[1]))
   local batchSize = #input

   self._input = self._input or input[1].new()
   -- mask is a binary tensor with 1 where self._input is zero (between-sequence zero-mask)
   self._mask = self._mask or torch.ByteTensor()

   -- now we process input into _input.
   -- indexes and mappedLengths are meta-information tables, explained below.
   self.indexes, self.mappedLengths = self._input.nn.VariableLength_FromSamples(input, self._input, self._mask)

   -- zero-mask the _input where mask is 1
   self.recursiveMask(self._input, self._mask)

   -- feedforward the zero-mask format through the decorated module
   local output = self.modules[1]:updateOutput(self._input)

   if self.lastOnly then
      -- Extract the last time step of each sample.
      -- self.output tensor has shape: batchSize [x outputSize]
      self.output = torch.isTensor(self.output) and self.output or output.new()
      self.output.nn.VariableLength_ToFinal(self.indexes, self.mappedLengths, output, self.output)
   else
      -- This is the reverse operation of everything before updateOutput.
      -- Dispatch through a tensor: input is a plain table, so it has no .nn
      self.output = self._input.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, output)
   end

   return self.output
end

function VariableLength:updateGradInput(input, gradOutput)
   -- backward pass is still a stub in this commit
   return self.gradInput
end

function VariableLength:accGradParameters(input, gradOutput, scale)
   -- stub: not yet implemented
end
```
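For context, a hypothetical forward-pass sketch of how this decorator would be used. It assumes the C helpers added elsewhere in this commit (VariableLength_FromSamples and friends) are built, and that the decorated module supports zero-masking via maskZero(), per the comment above __init; the variable names and sizes are illustrative only:

```lua
-- Hypothetical usage sketch, not part of the commit: decorate a zero-masked
-- module so it accepts a table of variable-length sequences.
require 'rnn'

local inputsize, outputsize = 5, 10
local lstm = nn.SeqLSTM(inputsize, outputsize)
lstm:maskZero() -- set up zero-masking before decorating

-- lastOnly = true: return only the last time step of each sequence
local vl = nn.VariableLength(lstm, true)

-- a batch of 3 sequences of different lengths; each row is one time step
local input = {
   torch.randn(4, inputsize), -- length 4
   torch.randn(2, inputsize), -- length 2
   torch.randn(3, inputsize), -- length 3
}

local output = vl:forward(input) -- with lastOnly: a batchSize x outputsize tensor
print(output:size()) -- expected: 3 x 10
```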
error.h (new file)

```c
#ifndef _ERROR_H_
#define _ERROR_H_

#include "luaT.h"
#include <string.h>

/* Raise a Lua error for an errno-style return code (sign-normalized). */
static inline int _lua_error(lua_State *L, int ret, const char *file, int line) {
  int pos_ret = ret >= 0 ? ret : -ret;
  return luaL_error(L, "ERROR: (%s, %d): (%d, %s)\n", file, line, pos_ret, strerror(pos_ret));
}

/* Raise a Lua error with a plain message. */
static inline int _lua_error_str(lua_State *L, const char *str, const char *file, int line) {
  return luaL_error(L, "ERROR: (%s, %d): (%s)\n", file, line, str);
}

/* Raise a Lua error with a message plus extra detail. */
static inline int _lua_error_str_str(lua_State *L, const char *str, const char *file, int line, const char *extra) {
  return luaL_error(L, "ERROR: (%s, %d): (%s: %s)\n", file, line, str, extra);
}

/* Convenience macros that capture the C call site automatically. */
#define LUA_HANDLE_ERROR(L, ret) _lua_error(L, ret, __FILE__, __LINE__)
#define LUA_HANDLE_ERROR_STR(L, str) _lua_error_str(L, str, __FILE__, __LINE__)
#define LUA_HANDLE_ERROR_STR_STR(L, str, extra) _lua_error_str_str(L, str, __FILE__, __LINE__, extra)

#endif
```
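Since luaL_error raises an ordinary Lua error, failures reported from C through these macros can be trapped on the Lua side with pcall. A sketch, assuming the rnn C library is built; whether this particular call fails depends on C validation code not shown in this excerpt:

```lua
-- Sketch: errors raised in C via LUA_HANDLE_ERROR* surface as normal Lua
-- errors, so they can be caught with pcall.
require 'rnn'

local buf, mask = torch.Tensor(), torch.ByteTensor()
local ok, err = pcall(function()
   -- an empty batch is assumed here to fail validation in the C binding
   return buf.nn.VariableLength_FromSamples({}, buf, mask)
end)
if not ok then
   print(err) -- "ERROR: (<file>, <line>): (<message>)" per the format strings above
end
```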