diff --git a/CMakeLists.txt b/CMakeLists.txt index 23d1fe4..d4b819e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,15 +4,27 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) CMAKE_POLICY(VERSION 2.6) FIND_PACKAGE(Torch REQUIRED) + +IF (NOT ("$ENV{CUDA}" STREQUAL "NO")) + FIND_PACKAGE(CUDA 6.5) +ENDIF() + SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}") -SET(src - src/rnn.c - ) +SET(src src/rnn.c) +IF (CUDA_FOUND) + LIST(APPEND src src/rnn.cu) + IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS OR MSVC) + INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake) + ENDIF() + CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST}) + LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) +ENDIF() INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/src) + SET(luasrc init.lua AbstractRecurrent.lua @@ -79,12 +91,16 @@ SET(luasrc deprecated/LSTM.lua ) -ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch") +ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN module for Torch") -TARGET_LINK_LIBRARIES(rnn luaT TH) +IF (CUDA_FOUND) + SET(CMAKE_C_FLAGS "-DUSE_CUDA ${CMAKE_C_FLAGS}") + TARGET_LINK_LIBRARIES(rnn luaT TH THC ${CUDA_LIBRARIES}) +ELSE() + TARGET_LINK_LIBRARIES(rnn luaT TH) +ENDIF() IF (BUILD_STATIC OR "$ENV{STATIC_TH}" STREQUAL "YES") SET_TARGET_PROPERTIES(rnn_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH") ENDIF() - INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/rnn") diff --git a/cmake/select_compute_arch.cmake b/cmake/select_compute_arch.cmake new file mode 100644 index 0000000..bff85de --- /dev/null +++ b/cmake/select_compute_arch.cmake @@ -0,0 +1,200 @@ +# Synopsis: +# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) +# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures +# target_CUDA_architectures : Auto | 
Common | All | LIST(ARCH_AND_PTX ...)
+# - "Auto" detects local machine GPU compute arch at runtime.
+# - "Common" and "All" cover common and entire subsets of architectures
+# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
+# NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
+# NUM: Any number. Only those pairs are currently accepted by NVCC though:
+# 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2
+# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
+# Additionally, sets ${out_variable}_readable to the resulting numeric list
+# Example:
+# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
+# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
+#
+# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
+#

+# This list will be used for CUDA_ARCH_NAME = All option
+set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell")
+
+# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
+set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0")
+
+if (CUDA_VERSION VERSION_GREATER "6.5")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
+endif ()
+
+if (CUDA_VERSION VERSION_GREATER "7.5")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX")
+else()
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
+endif ()
+
+
+
+################################################################################################
+# A function for automatic detection of GPUs installed (if autodetection is enabled)
+# Usage:
+#   CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
+#
+function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
+
+    file(WRITE ${cufile} ""
+      "#include <cstdio>\n"
+      "int main()\n"
+      "{\n"
+      "  int count = 0;\n"
+      "  if (cudaSuccess != 
cudaGetDeviceCount(&count)) return -1;\n"
+      "  if (count == 0) return -1;\n"
+      "  for (int device = 0; device < count; ++device)\n"
+      "  {\n"
+      "    cudaDeviceProp prop;\n"
+      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
+      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
+      "  }\n"
+      "  return 0;\n"
+      "}\n")
+
+    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}"
+                    "-ccbin" ${CMAKE_CXX_COMPILER}
+                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
+                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
+                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+    if(nvcc_res EQUAL 0)
+      # only keep the last line of nvcc_out
+      string(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
+      string(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
+      list(GET nvcc_out -1 nvcc_out)
+      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
+      set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
+    endif()
+  endif()
+
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    message(STATUS "Automatic GPU detection failed. 
Building for common architectures.") + set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE) + else() + set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE) + endif() +endfunction() + + +################################################################################################ +# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list +# Usage: +# SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs]) +function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) + set(CUDA_ARCH_LIST "${ARGN}") + + if("X${CUDA_ARCH_LIST}" STREQUAL "X" ) + set(CUDA_ARCH_LIST "Auto") + endif() + + set(cuda_arch_bin) + set(cuda_arch_ptx) + + if("${CUDA_ARCH_LIST}" STREQUAL "All") + set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES}) + elseif("${CUDA_ARCH_LIST}" STREQUAL "Common") + set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES}) + elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto") + CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST) + message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}") + endif() + + # Now process the list and look for names + string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") + list(REMOVE_DUPLICATES CUDA_ARCH_LIST) + foreach(arch_name ${CUDA_ARCH_LIST}) + set(arch_bin) + set(add_ptx FALSE) + # Check to see if we are compiling PTX + if(arch_name MATCHES "(.*)\\+PTX$") + set(add_ptx TRUE) + set(arch_name ${CMAKE_MATCH_1}) + endif() + if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$") + set(arch_bin ${CMAKE_MATCH_1}) + set(arch_ptx ${arch_bin}) + else() + # Look for it in our list of known architectures + if(${arch_name} STREQUAL "Fermi") + set(arch_bin "2.0 2.1(2.0)") + elseif(${arch_name} STREQUAL "Kepler+Tegra") + set(arch_bin 3.2) + elseif(${arch_name} STREQUAL "Kepler+Tesla") + set(arch_bin 3.7) + elseif(${arch_name} STREQUAL "Kepler") + set(arch_bin 3.0 3.5) + set(arch_ptx 3.5) + elseif(${arch_name} STREQUAL "Maxwell+Tegra") + set(arch_bin 5.3) + 
elseif(${arch_name} STREQUAL "Maxwell") + set(arch_bin 5.0 5.2) + set(arch_ptx 5.2) + elseif(${arch_name} STREQUAL "Pascal") + set(arch_bin 6.0 6.1) + set(arch_ptx 6.1) + else() + message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") + endif() + endif() + if(NOT arch_bin) + message(SEND_ERROR "arch_bin wasn't set for some reason") + endif() + list(APPEND cuda_arch_bin ${arch_bin}) + if(add_ptx) + if (NOT arch_ptx) + set(arch_ptx ${arch_bin}) + endif() + list(APPEND cuda_arch_ptx ${arch_ptx}) + endif() + endforeach() + + # remove dots and convert to lists + string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}") + string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}") + + if(cuda_arch_bin) + list(REMOVE_DUPLICATES cuda_arch_bin) + endif() + if(cuda_arch_ptx) + list(REMOVE_DUPLICATES cuda_arch_ptx) + endif() + + set(nvcc_flags "") + set(nvcc_archs_readable "") + + # Tell NVCC to add binaries for the specified GPUs + foreach(arch ${cuda_arch_bin}) + if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") + # User explicitly specified ARCH for the concrete CODE + list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) + list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) + else() + # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) + list(APPEND nvcc_archs_readable sm_${arch}) + endif() + endforeach() + + # Tell NVCC to add PTX intermediate code for the specified architectures + foreach(arch ${cuda_arch_ptx}) + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) + list(APPEND nvcc_archs_readable compute_${arch}) + endforeach() + + string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") + set(${out_variable} ${nvcc_flags} 
PARENT_SCOPE) + set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) +endfunction() diff --git a/init.lua b/init.lua index dd222d1..6cdb1e0 100644 --- a/init.lua +++ b/init.lua @@ -1,6 +1,7 @@ require 'torchx' local _ = require 'moses' require 'nn' +pcall(require, 'cunn') -- create global rnn table: rnn = {} @@ -113,4 +114,4 @@ require('rnn.BiSequencerLM') -- prevent likely name conflicts nn.rnn = rnn -return rnn \ No newline at end of file +return rnn diff --git a/src/generic/cuda/StepGRU.cu b/src/generic/cuda/StepGRU.cu new file mode 100644 index 0000000..8fee288 --- /dev/null +++ b/src/generic/cuda/StepGRU.cu @@ -0,0 +1,207 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/cuda/StepGRU.cu" +#else + +#if defined(THC_REAL_IS_HALF) +#define _REAL(val) THC_float2half(val) +#else +#define _REAL(val) (val) +#endif + +static int nn_(StepGRU_updateOutput)(lua_State *L) { + THCState *state = getCudaState(L); + THCTensor *weight = (THCTensor *)luaT_checkudata(L, 1, torch_Tensor); + THCTensor *bias = (THCTensor *)luaT_checkudata(L, 2, torch_Tensor); + THCTensor *gates = (THCTensor *)luaT_checkudata(L, 3, torch_Tensor); + THCTensor *cur_x = (THCTensor *)luaT_checkudata(L, 4, torch_Tensor); + THCTensor *prev_h = (THCTensor *)luaT_checkudata(L, 5, torch_Tensor); + int inputsize = luaL_checkinteger(L, 6); + int outputsize = luaL_checkinteger(L, 7); + THCTensor *next_h = (THCTensor *)luaT_checkudata(L, 8, torch_Tensor); + + int batchsize = THCTensor_(size)(state, cur_x, 0); + if (THCTensor_(size)(state, cur_x, 1) != inputsize) + return LUA_HANDLE_ERROR_STR(L, "expected input[1]:size(2) == inputsize"); + + THLongStorage* size = THLongStorage_newWithSize2(1, 3 * outputsize); + THCTensor *buffer = THCTensor_(newView)(state, bias, size); + buffer->stride[0] = 0; + buffer->size[0] = batchsize; + + THCTensor_(resize2d)(state, next_h, batchsize, outputsize); + long nElement = THCTensor_(nElement)(state, gates); + THCTensor_(resize2d)(state, gates, batchsize, 
3 * outputsize); + if (nElement != batchsize * 3 * outputsize) + THCTensor_(fill)(state, gates, _REAL(0)); + + THCTensor *Wx = THCTensor_(newNarrow)(state, weight, 0, 0, inputsize); + THCTensor *Wh = THCTensor_(newNarrow)(state, weight, 0, inputsize, outputsize); + THCTensor *sub_gates = THCTensor_(newNarrow)(state, gates, 1, 0, 2 * outputsize); + THCTensor *sub_Wh = THCTensor_(newNarrow)(state, Wh, 1, 0, 2 * outputsize); + // r = sig(Wx * x + Wh * prev_h + b) + THCTensor *reset_gate = THCTensor_(newNarrow)(state, gates, 1, 0, outputsize); + // u = sig(Wx * x + Wh * prev_h + b) + THCTensor *update_gate = THCTensor_(newNarrow)(state, gates, 1, outputsize, outputsize); + // hc = tanh(Wx * x + Wh * r . prev_h + b) + THCTensor *hidden_candidate = THCTensor_(newNarrow)(state, gates, 1, 2*outputsize, outputsize); + + // forward + THCTensor_(addmm)(state, gates, _REAL(1), buffer, _REAL(1), cur_x, Wx); + THCTensor_(addmm)(state, sub_gates, _REAL(1), sub_gates, _REAL(1), prev_h, sub_Wh); + THCTensor_(sigmoid)(state, sub_gates, sub_gates); + + // temporary buffer : r . prev_h + THCTensor_(cmul)(state, next_h, reset_gate, prev_h); + THCTensor_(narrow)(state, sub_Wh, Wh, 1, 2 * outputsize, outputsize); + // hc += Wh * r . prev_h + THCTensor_(addmm)(state, hidden_candidate, _REAL(1), hidden_candidate, _REAL(1), next_h, sub_Wh); + // hc = tanh(Wx * x + Wh * r . prev_h + b) + THCTensor_(tanh)(state, hidden_candidate, hidden_candidate); + // (1-u) . hc = hc - (u . hc) + THCTensor_(addcmul)(state, next_h, hidden_candidate, _REAL(-1), update_gate, hidden_candidate); + // next_h = (1-u) . hc + u . 
prev_h + THCTensor_(addcmul)(state, next_h, next_h, _REAL(1), update_gate, prev_h); + + THCTensor_(free)(state, Wx); + THCTensor_(free)(state, Wh); + THCTensor_(free)(state, buffer); + THCTensor_(free)(state, reset_gate); + THCTensor_(free)(state, update_gate); + THCTensor_(free)(state, hidden_candidate); + THCTensor_(free)(state, sub_gates); + THCTensor_(free)(state, sub_Wh); + THLongStorage_free(size); + + return 1; +} + +static int nn_(StepGRU_backward)(lua_State *L) { + THCState *state = getCudaState(L); + THCTensor *weight = (THCTensor *)luaT_checkudata(L, 1, torch_Tensor); + THCTensor *gates = (THCTensor *)luaT_checkudata(L, 2, torch_Tensor); + THCTensor *gradWeight = (THCTensor *)luaT_checkudata(L, 3, torch_Tensor); + THCTensor *grad_b = (THCTensor *)luaT_checkudata(L, 4, torch_Tensor); + THCTensor *grad_gates = (THCTensor *)luaT_checkudata(L, 5, torch_Tensor); + THCTensor *buffer = (THCTensor *)luaT_checkudata(L, 6, torch_Tensor); + THCTensor *cur_x = (THCTensor *)luaT_checkudata(L, 7, torch_Tensor); + THCTensor *prev_h = (THCTensor *)luaT_checkudata(L, 8, torch_Tensor); + THCTensor *grad_next_h = (THCTensor *)luaT_checkudata(L, 9, torch_Tensor); + lua_Number scale = luaL_checknumber(L, 10); + int inputsize = luaL_checkinteger(L, 11); + int outputsize = luaL_checkinteger(L, 12); + THCTensor *grad_cur_x = (THCTensor *)luaT_checkudata(L, 13, torch_Tensor); + THCTensor *grad_prev_h = (THCTensor *)luaT_checkudata(L, 14, torch_Tensor); + + int batchsize = THCTensor_(size)(state, cur_x, 0); + if (THCTensor_(size)(state, cur_x, 1) != inputsize) + return LUA_HANDLE_ERROR_STR(L, "expected input[1]:size(2) == inputsize"); + if (THCTensor_(size)(state, grad_next_h, 1) != outputsize) + return LUA_HANDLE_ERROR_STR(L, "expected gradOutput[1]:size(2) == outputsize"); + + THCTensor_(resize2d)(state, grad_cur_x, batchsize, inputsize); + THCTensor_(resize2d)(state, grad_prev_h, batchsize, outputsize); + THCTensor_(resize2d)(state, grad_gates, batchsize, 3 * outputsize); + + 
THCTensor *Wx = THCTensor_(newNarrow)(state, weight, 0, 0, inputsize); + THCTensor *Wh = THCTensor_(newNarrow)(state, weight, 0, inputsize, outputsize); + THCTensor *reset_gate = THCTensor_(newNarrow)(state, gates, 1, 0, outputsize); + THCTensor *update_gate = THCTensor_(newNarrow)(state, gates, 1, outputsize, outputsize); + THCTensor *hidden_candidate = THCTensor_(newNarrow)(state, gates, 1, 2*outputsize, outputsize); + THCTensor *grad_Wx = THCTensor_(newNarrow)(state, gradWeight, 0, 0, inputsize); + THCTensor *grad_Wh = THCTensor_(newNarrow)(state, gradWeight, 0, inputsize, outputsize); + THCTensor *grad_reset_gate = THCTensor_(newNarrow)(state, grad_gates, 1, 0, outputsize); + THCTensor *grad_update_gate = THCTensor_(newNarrow)(state, grad_gates, 1, outputsize, outputsize); + THCTensor *grad_hidden_candidate = THCTensor_(newNarrow)(state, grad_gates, 1, 2*outputsize, outputsize); + + THCTensor *sub_Wh = THCTensor_(newNarrow)(state, Wh, 1, 2 * outputsize, outputsize); + THCTensor *sub_Wh_t = THCTensor_(newTranspose)(state, sub_Wh, 0, 1); + THCTensor *Wx_t = THCTensor_(newTranspose)(state, Wx, 0, 1); + THCTensor *cur_x_t = THCTensor_(newTranspose)(state, cur_x, 0, 1); + THCTensor *sub_grad_gates = THCTensor_(newNarrow)(state, grad_gates, 1, 0, 2 * outputsize); + THCTensor *sub_grad_Wh = THCTensor_(newNarrow)(state, grad_Wh, 1, 0, 2 * outputsize); + THCTensor *prev_h_t = THCTensor_(newTranspose)(state, prev_h, 0, 1); + + // use grad_update_gate as temporary buffer to compute grad_hidden_candidate and grad_reset_gate + THCTensor_(fill)(state, grad_update_gate, _REAL(0)); + THCTensor_(addcmul)(state, grad_update_gate, grad_next_h, _REAL(-1), update_gate, grad_next_h); + THCTensor_(fill)(state, grad_hidden_candidate, _REAL(1)); + THCTensor_(addcmul)(state, grad_hidden_candidate, grad_hidden_candidate, _REAL(-1), + hidden_candidate, hidden_candidate); + THCTensor_(cmul)(state, grad_hidden_candidate, grad_hidden_candidate, grad_update_gate); + + THCTensor_(fill)(state, 
grad_update_gate, _REAL(0)); + THCTensor_(addmm)(state, grad_update_gate, _REAL(1), grad_update_gate, _REAL(1), + grad_hidden_candidate, sub_Wh_t); + THCTensor_(cmul)(state, grad_update_gate, grad_update_gate, prev_h); + THCTensor_(fill)(state, grad_reset_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_reset_gate, grad_reset_gate, _REAL(-1), reset_gate); + THCTensor_(cmul)(state, grad_reset_gate, grad_reset_gate, reset_gate); + THCTensor_(cmul)(state, grad_reset_gate, grad_reset_gate, grad_update_gate); + + THCTensor_(cadd)(state, buffer, prev_h, _REAL(-1), hidden_candidate); + THCTensor_(fill)(state, grad_update_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_update_gate, grad_update_gate, _REAL(-1), update_gate); + THCTensor_(cmul)(state, grad_update_gate, grad_update_gate, update_gate); + THCTensor_(cmul)(state, grad_update_gate, grad_update_gate, buffer); + THCTensor_(cmul)(state, grad_update_gate, grad_update_gate, grad_next_h); + THCTensor_(addmm)(state, grad_cur_x, _REAL(0), grad_cur_x, _REAL(1), grad_gates, Wx_t); + THCTensor_(addmm)(state, grad_Wx, _REAL(scale), grad_Wx, _REAL(1), cur_x_t, grad_gates); + THCTensor_(addmm)(state, sub_grad_Wh, _REAL(scale), sub_grad_Wh, + _REAL(1), prev_h_t, sub_grad_gates); + + THCTensor_(resize1d)(state, buffer, outputsize); + THCTensor_(sum)(state, buffer, grad_gates, 0, 0); + THCTensor_(cadd)(state, grad_b, grad_b, _REAL(scale), buffer); + THCTensor_(cmul)(state, buffer, prev_h, reset_gate); + + THCTensor_(narrow)(state, sub_grad_Wh, grad_Wh, 1, 2 * outputsize, outputsize); + THCTensor_(transpose)(state, cur_x_t, buffer, 0, 1); // reuse cur_x_t as buffer_t + THCTensor_(addmm)(state, sub_grad_Wh, _REAL(scale), + sub_grad_Wh, _REAL(1), cur_x_t, grad_hidden_candidate); + THCTensor_(cmul)(state, grad_prev_h, grad_next_h, update_gate); + + THCTensor_(narrow)(state, sub_Wh, Wh, 1, 0, 2 * outputsize); + THCTensor_(transpose)(state, cur_x_t, sub_Wh, 0, 1); // reuse cur_x_t as sub_Wh_t + THCTensor_(addmm)(state, grad_prev_h, 
_REAL(1), grad_prev_h, _REAL(1), sub_grad_gates, cur_x_t); + + THCTensor_(addmm)(state, buffer, _REAL(0), buffer, _REAL(1), grad_hidden_candidate, sub_Wh_t); + THCTensor_(cmul)(state, buffer, buffer, reset_gate); + THCTensor_(cadd)(state, grad_prev_h, grad_prev_h, _REAL(1), buffer); + + THCTensor_(free)(state, Wx); + THCTensor_(free)(state, Wh); + THCTensor_(free)(state, reset_gate); + THCTensor_(free)(state, update_gate); + THCTensor_(free)(state, hidden_candidate); + + THCTensor_(free)(state, grad_Wx); + THCTensor_(free)(state, grad_Wh); + THCTensor_(free)(state, grad_reset_gate); + THCTensor_(free)(state, grad_update_gate); + THCTensor_(free)(state, grad_hidden_candidate); + + THCTensor_(free)(state, sub_Wh); + THCTensor_(free)(state, sub_Wh_t); + THCTensor_(free)(state, Wx_t); + THCTensor_(free)(state, cur_x_t); + THCTensor_(free)(state, sub_grad_gates); + THCTensor_(free)(state, sub_grad_Wh); + THCTensor_(free)(state, prev_h_t); + + return 2; +} + +static const struct luaL_Reg nn_(StepGRU__) [] = { + {"StepGRU_updateOutput", nn_(StepGRU_updateOutput)}, + {"StepGRU_backward", nn_(StepGRU_backward)}, + {NULL, NULL} +}; + +static void nn_(StepGRU_init)(lua_State *L) { + luaT_pushmetatable(L, torch_Tensor); + luaT_registeratname(L, nn_(StepGRU__), "nn"); + lua_pop(L,1); +} + +#undef _REAL +#endif diff --git a/src/generic/cuda/StepLSTM.cu b/src/generic/cuda/StepLSTM.cu new file mode 100644 index 0000000..5e60806 --- /dev/null +++ b/src/generic/cuda/StepLSTM.cu @@ -0,0 +1,255 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/cuda/StepLSTM.cu" +#else + +#if defined(THC_REAL_IS_HALF) +#define _REAL(val) THC_float2half(val) +#else +#define _REAL(val) (val) +#endif + +static int nn_(StepLSTM_updateOutput)(lua_State *L) { + THCState *state = getCudaState(L); + THCTensor *weight = (THCTensor *)luaT_checkudata(L, 1, torch_Tensor); + THCTensor *bias = (THCTensor *)luaT_checkudata(L, 2, torch_Tensor); + THCTensor *gates = (THCTensor *)luaT_checkudata(L, 3, 
torch_Tensor); + THCTensor *cur_x = (THCTensor *)luaT_checkudata(L, 4, torch_Tensor); + THCTensor *prev_h = (THCTensor *)luaT_checkudata(L, 5, torch_Tensor); + THCTensor *prev_c = (THCTensor *)luaT_checkudata(L, 6, torch_Tensor); + int inputsize = luaL_checkinteger(L, 7); + int hiddensize = luaL_checkinteger(L, 8); + int outputsize = luaL_checkinteger(L, 9); + THCTensor *next_h = (THCTensor *)luaT_checkudata(L, 10, torch_Tensor); // when LSTMP pass hidden[t] + THCTensor *next_c = (THCTensor *)luaT_checkudata(L, 11, torch_Tensor); + + int batchsize = THCTensor_(size)(state, cur_x, 0); + if (THCTensor_(size)(state, cur_x, 1) != inputsize) + return LUA_HANDLE_ERROR_STR(L, "expected input[1]:size(2) == inputsize"); + + THLongStorage* size = THLongStorage_newWithSize2(1, 4 * hiddensize); + THCTensor *buffer = THCTensor_(newView)(state, bias, size); + buffer->stride[0] = 0; + buffer->size[0] = batchsize; + + THCTensor *Wx = THCTensor_(newNarrow)(state, weight, 0, 0, inputsize); + THCTensor *Wh = THCTensor_(newNarrow)(state, weight, 0, inputsize, outputsize); + + THCTensor_(resize2d)(state, next_h, batchsize, hiddensize); + THCTensor_(resize2d)(state, next_c, batchsize, hiddensize); + long nElement = THCTensor_(nElement)(state, gates); + THCTensor_(resize2d)(state, gates, batchsize, 4 * hiddensize); + if (nElement != batchsize * 4 * hiddensize) + THCTensor_(fill)(state, gates, _REAL(0)); + + // forward + THCTensor_(addmm)(state, gates, _REAL(1), buffer, _REAL(1), cur_x, Wx); + THCTensor_(addmm)(state, gates, _REAL(1), gates, _REAL(1), prev_h, Wh); + + THCTensor_(narrow)(state, buffer, gates, 1, 0, 3 * hiddensize); + THCTensor_(sigmoid)(state, buffer, buffer); + + THCTensor_(narrow)(state, buffer, gates, 1, 3 * hiddensize, hiddensize); + THCTensor_(tanh)(state, buffer, buffer); + + THCTensor *input_gate = THCTensor_(newNarrow)(state, gates, 1, 0, hiddensize); + THCTensor *forget_gate = THCTensor_(newNarrow)(state, gates, 1, hiddensize, hiddensize); + THCTensor *output_gate 
= THCTensor_(newNarrow)(state, gates, 1, 2*hiddensize, hiddensize); + THCTensor *input_transform = THCTensor_(newNarrow)(state, gates, 1, 3*hiddensize, hiddensize); + + THCTensor_(cmul)(state, next_h, input_gate, input_transform); + THCTensor_(cmul)(state, next_c, forget_gate, prev_c); + THCTensor_(cadd)(state, next_c, next_c, _REAL(1), next_h); + THCTensor_(tanh)(state, next_h, next_c); + THCTensor_(cmul)(state, next_h, next_h, output_gate); + + THCTensor_(free)(state, Wx); + THCTensor_(free)(state, Wh); + THCTensor_(free)(state, buffer); + THCTensor_(free)(state, input_gate); + THCTensor_(free)(state, forget_gate); + THCTensor_(free)(state, output_gate); + THCTensor_(free)(state, input_transform); + THLongStorage_free(size); + + if (lua_gettop(L) > 11) // implements LSTMP (P stands for projection layer) + { + THCTensor *hidden = next_h; + THCTensor *weightO = (THCTensor *)luaT_checkudata(L, 12, torch_Tensor); + next_h = (THCTensor *)luaT_checkudata(L, 13, torch_Tensor); + THCTensor_(resize2d)(state, next_h, batchsize, outputsize); + THCTensor_(addmm)(state, next_h, _REAL(0), next_h, _REAL(1), hidden, weightO); + // push results onto stack + luaT_pushudata(L, next_c, torch_Tensor); + } + + return 2; +} + +static int nn_(StepLSTM_backward)(lua_State *L) { + THCState *state = getCudaState(L); + THCTensor *weight = (THCTensor *)luaT_checkudata(L, 1, torch_Tensor); + THCTensor *gates = (THCTensor *)luaT_checkudata(L, 2, torch_Tensor); + THCTensor *gradWeight = (THCTensor *)luaT_checkudata(L, 3, torch_Tensor); + THCTensor *grad_b = (THCTensor *)luaT_checkudata(L, 4, torch_Tensor); + THCTensor *grad_gates = (THCTensor *)luaT_checkudata(L, 5, torch_Tensor); + THCTensor *grad_gates_sum = (THCTensor *)luaT_checkudata(L, 6, torch_Tensor); + THCTensor *cur_x = (THCTensor *)luaT_checkudata(L, 7, torch_Tensor); + THCTensor *prev_h = (THCTensor *)luaT_checkudata(L, 8, torch_Tensor); + THCTensor *prev_c = (THCTensor *)luaT_checkudata(L, 9, torch_Tensor); + THCTensor *next_c = 
(THCTensor *)luaT_checkudata(L, 10, torch_Tensor); + THCTensor *grad_next_h = (THCTensor *)luaT_checkudata(L, 11, torch_Tensor); + THCTensor *grad_next_c = (THCTensor *)luaT_checkudata(L, 12, torch_Tensor); + lua_Number scale = luaL_checknumber(L, 13); + int inputsize = luaL_checkinteger(L, 14); + int hiddensize = luaL_checkinteger(L, 15); + int outputsize = luaL_checkinteger(L, 16); + THCTensor *grad_cur_x = (THCTensor *)luaT_checkudata(L, 17, torch_Tensor); + THCTensor *grad_prev_h = (THCTensor *)luaT_checkudata(L, 18, torch_Tensor); + THCTensor *grad_prev_c = (THCTensor *)luaT_checkudata(L, 19, torch_Tensor); + + int batchsize = THCTensor_(size)(state, cur_x, 0); + if (THCTensor_(size)(state, cur_x, 1) != inputsize) + return LUA_HANDLE_ERROR_STR(L, "expected input[1]:size(2) == inputsize"); + if (THCTensor_(size)(state, grad_next_h, 1) != outputsize) + return LUA_HANDLE_ERROR_STR(L, "expected gradOutput[1]:size(2) == outputsize"); + + if (lua_gettop(L) > 19) // LSTMP + { + THCTensor *weightO = (THCTensor *)luaT_checkudata(L, 20, torch_Tensor); + THCTensor *hidden = (THCTensor *)luaT_checkudata(L, 21, torch_Tensor); + THCTensor *gradWeightO = (THCTensor *)luaT_checkudata(L, 22, torch_Tensor); + THCTensor *grad_hidden = (THCTensor *)luaT_checkudata(L, 23, torch_Tensor); + + THCTensor *hidden_t = THCTensor_(newTranspose)(state, hidden, 0, 1); + THCTensor *weightO_t = THCTensor_(newTranspose)(state, weightO, 0, 1); + + THCTensor_(addmm)(state, gradWeightO, _REAL(scale), gradWeightO, _REAL(1), hidden_t, grad_next_h); + THCTensor_(resize2d)(state, grad_hidden, batchsize, hiddensize); + THCTensor_(addmm)(state, grad_hidden, _REAL(0), grad_hidden, _REAL(1), grad_next_h, weightO_t); + + grad_next_h = grad_hidden; + + THCTensor_(free)(state, hidden_t); + THCTensor_(free)(state, weightO_t); + + // push results to top of stack + luaT_pushudata(L, grad_cur_x, torch_Tensor); + luaT_pushudata(L, grad_prev_h, torch_Tensor); + luaT_pushudata(L, grad_prev_c, torch_Tensor); + } + 
+ THCTensor_(resize2d)(state, grad_cur_x, batchsize, inputsize); + THCTensor_(resize2d)(state, grad_prev_h, batchsize, outputsize); + THCTensor_(resize2d)(state, grad_prev_c, batchsize, hiddensize); + + // these tensors were set-up in updateOutput + THCTensor *Wx = THCTensor_(newNarrow)(state, weight, 0, 0, inputsize); + THCTensor *Wh = THCTensor_(newNarrow)(state, weight, 0, inputsize, outputsize); + + THCTensor *input_gate = THCTensor_(newNarrow)(state, gates, 1, 0, hiddensize); + THCTensor *forget_gate = THCTensor_(newNarrow)(state, gates, 1, hiddensize, hiddensize); + THCTensor *output_gate = THCTensor_(newNarrow)(state, gates, 1, 2*hiddensize, hiddensize); + THCTensor *input_transform = THCTensor_(newNarrow)(state, gates, 1, 3*hiddensize, hiddensize); + + // set-up grad tensors + THCTensor *grad_Wx = THCTensor_(newNarrow)(state, gradWeight, 0, 0, inputsize); + THCTensor *grad_Wh = THCTensor_(newNarrow)(state, gradWeight, 0, inputsize, outputsize); + + THCTensor_(resize2d)(state, grad_gates, batchsize, 4 * hiddensize); + + THCTensor *grad_input_gate = THCTensor_(newNarrow)(state, grad_gates, 1, 0, hiddensize); + THCTensor *grad_forget_gate = THCTensor_(newNarrow)(state, grad_gates, 1, hiddensize, hiddensize); + THCTensor *grad_output_gate = THCTensor_(newNarrow)(state, grad_gates, 1, 2*hiddensize, hiddensize); + THCTensor *grad_input_transform = THCTensor_(newNarrow)(state, grad_gates, 1, 3*hiddensize, hiddensize); + + // backward + + // we use grad_[input,forget,output]_gate as temporary buffers to compute grad_prev_c. 
+ THCTensor_(tanh)(state, grad_input_gate, next_c); + THCTensor_(cmul)(state, grad_forget_gate, grad_input_gate, grad_input_gate); + + THCTensor_(fill)(state, grad_output_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_output_gate, grad_output_gate, _REAL(-1), grad_forget_gate); + THCTensor_(cmul)(state, grad_output_gate, grad_output_gate, output_gate); + THCTensor_(cmul)(state, grad_output_gate, grad_output_gate, grad_next_h); + THCTensor_(cadd)(state, grad_prev_c, grad_next_c, _REAL(1), grad_output_gate); + + // we use above grad_input_gate to compute grad_output_gate + THCTensor_(fill)(state, grad_output_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_output_gate, grad_output_gate, _REAL(-1), output_gate); + THCTensor_(cmul)(state, grad_output_gate, grad_output_gate, output_gate); + THCTensor_(cmul)(state, grad_output_gate, grad_output_gate, grad_input_gate); + THCTensor_(cmul)(state, grad_output_gate, grad_output_gate, grad_next_h); + + // Use grad_input_gate as a temporary buffer for computing grad_input_transform + THCTensor_(cmul)(state, grad_input_gate, input_transform, input_transform); + THCTensor_(fill)(state, grad_input_transform, _REAL(1)); + THCTensor_(cadd)(state, grad_input_transform, grad_input_transform, _REAL(-1), grad_input_gate); + THCTensor_(cmul)(state, grad_input_transform, grad_input_transform, input_gate); + THCTensor_(cmul)(state, grad_input_transform, grad_input_transform, grad_prev_c); + + // We don't need any temporary storage for these so do them last + THCTensor_(fill)(state, grad_input_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_input_gate, grad_input_gate, _REAL(-1), input_gate); + THCTensor_(cmul)(state, grad_input_gate, grad_input_gate, input_gate); + THCTensor_(cmul)(state, grad_input_gate, grad_input_gate, input_transform); + THCTensor_(cmul)(state, grad_input_gate, grad_input_gate, grad_prev_c); + + THCTensor_(fill)(state, grad_forget_gate, _REAL(1)); + THCTensor_(cadd)(state, grad_forget_gate, grad_forget_gate, _REAL(-1), 
forget_gate); + THCTensor_(cmul)(state, grad_forget_gate, grad_forget_gate, forget_gate); + THCTensor_(cmul)(state, grad_forget_gate, grad_forget_gate, prev_c); + THCTensor_(cmul)(state, grad_forget_gate, grad_forget_gate, grad_prev_c); + + // now for the main dish + THCTensor *Wx_t = THCTensor_(newTranspose)(state, Wx, 0, 1); + THCTensor *Wh_t = THCTensor_(newTranspose)(state, Wh, 0, 1); + THCTensor *cur_x_t = THCTensor_(newTranspose)(state, cur_x, 0, 1); + THCTensor *prev_h_t = THCTensor_(newTranspose)(state, prev_h, 0, 1); + + THCTensor_(addmm)(state, grad_cur_x, _REAL(0), grad_cur_x, _REAL(1), grad_gates, Wx_t); + THCTensor_(addmm)(state, grad_Wx, _REAL(1), grad_Wx, _REAL(scale), cur_x_t, grad_gates); + THCTensor_(addmm)(state, grad_Wh, _REAL(1), grad_Wh, _REAL(scale), prev_h_t, grad_gates); + THCTensor_(resize2d)(state, grad_gates_sum, 1, 4 * hiddensize); + THCTensor_(sum)(state, grad_gates_sum, grad_gates, 0, 0); + THCTensor_(cadd)(state, grad_b, grad_b, _REAL(scale), grad_gates_sum); + + THCTensor_(addmm)(state, grad_prev_h, _REAL(0), grad_prev_h, _REAL(1), grad_gates, Wh_t); + THCTensor_(cmul)(state, grad_prev_c, grad_prev_c, forget_gate); + + THCTensor_(free)(state, Wx); + THCTensor_(free)(state, Wh); + THCTensor_(free)(state, input_gate); + THCTensor_(free)(state, forget_gate); + THCTensor_(free)(state, output_gate); + THCTensor_(free)(state, input_transform); + + THCTensor_(free)(state, grad_Wx); + THCTensor_(free)(state, grad_Wh); + THCTensor_(free)(state, grad_input_gate); + THCTensor_(free)(state, grad_forget_gate); + THCTensor_(free)(state, grad_output_gate); + THCTensor_(free)(state, grad_input_transform); + + THCTensor_(free)(state, Wx_t); + THCTensor_(free)(state, Wh_t); + THCTensor_(free)(state, cur_x_t); + THCTensor_(free)(state, prev_h_t); + + return 3; +} + +static const struct luaL_Reg nn_(StepLSTM__) [] = { + {"StepLSTM_updateOutput", nn_(StepLSTM_updateOutput)}, + {"StepLSTM_backward", nn_(StepLSTM_backward)}, + {NULL, NULL} +}; + +static 
void nn_(StepLSTM_init)(lua_State *L) { + luaT_pushmetatable(L, torch_Tensor); + luaT_registeratname(L, nn_(StepLSTM__), "nn"); + lua_pop(L,1); +} + +#undef _REAL +#endif diff --git a/src/generic/cuda/VariableLength.cu b/src/generic/cuda/VariableLength.cu new file mode 100644 index 0000000..5ade34c --- /dev/null +++ b/src/generic/cuda/VariableLength.cu @@ -0,0 +1,386 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/cuda/VariableLength.cu" +#else + +#if defined(THC_REAL_IS_HALF) +#define _REAL(val) THC_float2half(val) +#else +#define _REAL(val) (val) +#endif + +static int nn_(from_samples_to_structured)(lua_State *L) { + THCState *state = getCudaState(L); + // processes inputs + if (lua_gettop(L) != 3) + return LUA_HANDLE_ERROR_STR(L, "expected 3 arguments: samples, output, mask"); + const int samples_index = 1; + const int output_index = 2; + const int mask_index = 3; + if (!lua_istable(L, samples_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for first argument"); + THCTensor *output = (THCTensor *)luaT_checkudata(L, output_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, output)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + THCudaByteTensor *mask = (THCudaByteTensor *)luaT_checkudata(L, mask_index, "torch.CudaByteTensor"); + if (!THCudaByteTensor_isContiguous(state, mask)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + + // loads all samples from the table + long n_samples = lua_objlen(L, samples_index); + THCTensor *tensors[n_samples]; + lua_pushnil(L); + while (lua_next(L, samples_index) != 0) { + long index = lua_tointeger(L, -2); + THCTensor *tensor = (THCTensor *)luaT_checkudata(L, -1, torch_Tensor); + tensors[index-1] = tensor; + lua_pop(L, 1); + } + + // processes the samples to get some meta-info that will be used to determine the positioning in + // the dense tensor created in the output + Sample samples_info[n_samples]; + THCTensor* step = THCTensor_(new)(state); // a tensor 
that contains first step of first tensor
+  THCTensor* _step = THCTensor_(new)(state); // contains first step of other tensors (sizes must match)
+  for (long i = 0; i < n_samples; i++) {
+    THCTensor_(narrow)(state, _step, tensors[i], 0, 0, 1); // 1 [x ...]
+    if (i == 0)
+      THCTensor_(narrow)(state, step, tensors[i], 0, 0, 1);
+    else if (!THCTensor_(isSameSizeAs)(state, step, _step))
+      return LUA_HANDLE_ERROR_STR(L, "got tensors of different sizes");
+    samples_info[i].length = THCTensor_(size)(state, tensors[i], 0); // sequence length = first dimension
+    samples_info[i].index = i;
+    samples_info[i].assigned_row = -1; // -1 marks "not yet assigned to an output row"
+  }
+
+  // sorts samples by ascending length (sample_compare), so the longest sample is last
+  qsort(samples_info, n_samples, sizeof(Sample), sample_compare);
+
+  long max_length = samples_info[n_samples-1].length;
+
+  // creates the two tables with meta-info that will be output
+  lua_newtable(L);
+  const int indexes_index = lua_gettop(L);
+  int local_indexes_index = 0;
+  lua_newtable(L);
+  const int mapped_lengths_index = lua_gettop(L);
+  int local_mapped_lengths_index = 0;
+
+  long row_index = 0;
+  long length_available = max_length; // space remaining in the current output row
+  long count = 0, row_count = 0;
+  long start_index = 0;
+
+  // while there are unprocessed samples...
+ while (count < n_samples) { + // flag of whether a sample was added in this iteration + int added_sample = 0; + // for each sample provided + for (long i = n_samples-1; i >= 0; i--) { + // checks if the current sample hasn't been assigned yet and fits the space left in the line + if (samples_info[i].assigned_row == -1 && samples_info[i].length <= length_available) { + long sample_index = samples_info[i].index; + + // if first sample in the row, creates sub-tables with meta-info for each row + if (row_count == 0) { + lua_newtable(L); + local_indexes_index = lua_gettop(L); + lua_newtable(L); + local_mapped_lengths_index = lua_gettop(L); + } + + // places the meta-info about the sample (index and length) into the tables + row_count++; + lua_pushinteger(L, sample_index+1); + lua_rawseti(L, local_indexes_index, row_count); + lua_pushinteger(L, samples_info[i].length); + lua_rawseti(L, local_mapped_lengths_index, row_count); + + // assigns the sample to this row and updates the row and sample info + samples_info[i].assigned_row = row_index; + length_available -= samples_info[i].length + 1; + start_index += samples_info[i].length + 1; + count++; + added_sample = 1; + } + } + + // if no sample was added, it means no sample available can fit in the space left, so we have to + // add another table + if (!added_sample) { + // saves the current row-based meta-info + lua_rawseti(L, mapped_lengths_index, row_index+1); + lua_rawseti(L, indexes_index, row_index+1); + // and advances rows + row_index++; + length_available = max_length; + start_index = 0; + row_count = 0; + } + } + // saves the last row's meta-info + lua_rawseti(L, mapped_lengths_index, row_index+1); + lua_rawseti(L, indexes_index, row_index+1); + + // with the info available, resizes the output and mask + long n_rows = lua_objlen(L, indexes_index); + // output will have size: maxlen x nrows [x ...] 
+ long output_dim = THCTensor_(nDimension)(state, step) + 1; + THLongStorage* output_size = THLongStorage_newWithSize(output_dim); + output_size->data[0] = max_length; + output_size->data[1] = n_rows; + for (long i=2; i < output_dim; i++) { + output_size->data[i] = THCTensor_(size)(state, step, i-1); + } + THCTensor_(resize)(state, output, output_size, NULL); + THCudaByteTensor_resize2d(state, mask, max_length, n_rows); + // mask starts filled with ones indicating it's empty + THCudaByteTensor_fill(state, mask, 1); + + THCTensor *row = THCTensor_(new)(state), *section = THCTensor_(new)(state); + THCudaByteTensor *mrow = THCudaByteTensor_new(state), *msection = THCudaByteTensor_new(state); + // for each row in the output + for (long i = 0; i < n_rows; i++) { + THCTensor_(select)(state, row, output, 1, i); + THCudaByteTensor_select(state, mrow, mask, 1, i); + lua_rawgeti(L, indexes_index, i+1); + const int local_indexes_index = lua_gettop(L); + lua_rawgeti(L, mapped_lengths_index, i+1); + const int local_mapped_lengths_index = lua_gettop(L); + + long n_entries_in_row = lua_objlen(L, -1); + long start = 0; + // for each sample placed in that row + for (long j = 0; j < n_entries_in_row; j++) { + lua_rawgeti(L, local_indexes_index, j+1); + lua_rawgeti(L, local_mapped_lengths_index, j+1); + long index = lua_tointeger(L, -2); + long length = lua_tointeger(L, -1); + lua_pop(L, 2); + + // copies the data from the input and fills the mask + THCTensor_(narrow)(state, section, row, 0, start, length); + THCudaByteTensor_narrow(state, msection, mrow, 0, start, length); + THCTensor_(copy)(state, section, tensors[index-1]); + THCudaByteTensor_fill(state, msection, 0); + start += length + 1; + } + lua_pop(L, 2); + } + THCTensor_(free)(state, row); + THCTensor_(free)(state, section); + THCTensor_(free)(state, step); + THCTensor_(free)(state, step); + THLongStorage_free(output_size); + THCudaByteTensor_free(state, mrow); + THCudaByteTensor_free(state, msection); + + return 2; +} + 
+// converts the dense tensor `input` into a list of samples `output`, each with its correct length. +static int nn_(from_structured_to_samples)(lua_State *L) { + THCState *state = getCudaState(L); + // processes inputs + if (lua_gettop(L) != 3) + return LUA_HANDLE_ERROR_STR(L, "expected 3 arguments: indexing, lengths, input"); + const int indexes_index = 1; + const int mapped_lengths_index = 2; + const int input_index = 3; + if (!lua_istable(L, indexes_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for first argument"); + if (!lua_istable(L, mapped_lengths_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for second argument"); + + THCTensor *input = (THCTensor *)luaT_checkudata(L, input_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, input)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + + lua_newtable(L); + const int output_index = lua_gettop(L); + + long n_rows = lua_objlen(L, indexes_index); + THCTensor *row = THCTensor_(new)(state); + // for each row in the input + for (long i = 0; i < n_rows; i++) { + THCTensor_(select)(state, row, input, 1, i); + lua_rawgeti(L, indexes_index, i+1); + const int local_indexes_index = lua_gettop(L); + lua_rawgeti(L, mapped_lengths_index, i+1); + const int local_mapped_lengths_index = lua_gettop(L); + + long n_entries_in_row = lua_objlen(L, -1); + long start = 0; + // for each sample placed in that row + for (long j = 0; j < n_entries_in_row; j++) { + lua_rawgeti(L, local_indexes_index, j+1); + lua_rawgeti(L, local_mapped_lengths_index, j+1); + long index = lua_tointeger(L, -2); + long length = lua_tointeger(L, -1); + lua_pop(L, 2); + + // gets the sub-tensor of the row that corresponds to the sample and places in the table + THCTensor *dest = THCTensor_(new)(state); + THCTensor_(narrow)(state, dest, row, 0, start, length); + start += length + 1; + luaT_pushudata(L, dest, torch_Tensor); + lua_rawseti(L, output_index, index); + } + lua_pop(L, 2); + } + THCTensor_(free)(state, 
row); + + return 1; +} + +static int nn_(from_structured_to_final)(lua_State *L) { + THCState *state = getCudaState(L); + // processes inputs + if (lua_gettop(L) != 4) + return LUA_HANDLE_ERROR_STR(L, "expected 4 arguments: indexing, lengths, input, output"); + const int indexes_index = 1; + const int mapped_lengths_index = 2; + const int input_index = 3; + const int output_index = 4; + if (!lua_istable(L, indexes_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for first argument"); + if (!lua_istable(L, mapped_lengths_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for second argument"); + + THCTensor *input = (THCTensor *)luaT_checkudata(L, input_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, input)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + THCTensor *output = (THCTensor *)luaT_checkudata(L, output_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, output)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + + long n_samples = get_n_samples(L, mapped_lengths_index); + long output_dim = THCTensor_(nDimension)(state, input) - 1; + THLongStorage* output_size = THLongStorage_newWithSize(output_dim); // n_samples [x ...] 
+ output_size->data[0] = n_samples; + for (long i=1;i < output_dim; i++){ + output_size->data[i] = THCTensor_(size)(state, input, i+1); + } + THCTensor_(resize)(state, output, output_size, NULL); + + long n_rows = lua_objlen(L, indexes_index); + THCTensor *row = THCTensor_(new)(state), *section = THCTensor_(new)(state); + THCTensor *output_section = THCTensor_(new)(state); + // for each row in the output + for (long i = 0; i < n_rows; i++) { + THCTensor_(select)(state, row, input, 1, i); + lua_rawgeti(L, indexes_index, i+1); + const int local_indexes_index = lua_gettop(L); + lua_rawgeti(L, mapped_lengths_index, i+1); + const int local_mapped_lengths_index = lua_gettop(L); + + long n_entries_in_row = lua_objlen(L, -1); + long start = 0; + // for each sample placed in that row + for (long j = 0; j < n_entries_in_row; j++) { + lua_rawgeti(L, local_indexes_index, j+1); + lua_rawgeti(L, local_mapped_lengths_index, j+1); + long index = lua_tointeger(L, -2); + long length = lua_tointeger(L, -1); + lua_pop(L, 2); + + // gets the sub-tensor of the row that corresponds to the sample and places in the table + THCTensor_(select)(state, section, row, 0, start + length-1); + THCTensor_(select)(state, output_section, output, 0, index-1); + THCTensor_(copy)(state, output_section, section); + start += length + 1; + } + lua_pop(L, 2); + } + THCTensor_(free)(state, row); + THCTensor_(free)(state, section); + THCTensor_(free)(state, output_section); + THLongStorage_free(output_size); + + return 0; +} + +static int nn_(from_final_to_structured)(lua_State *L) { + THCState *state = getCudaState(L); + if (lua_gettop(L) != 4) + return LUA_HANDLE_ERROR_STR(L, "expected 4 arguments: indexing, lengths, input, output"); + const int indexes_index = 1; + const int mapped_lengths_index = 2; + const int input_index = 3; + const int output_index = 4; + if (!lua_istable(L, indexes_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for first argument"); + if (!lua_istable(L, 
mapped_lengths_index)) + return LUA_HANDLE_ERROR_STR(L, "expected table for second argument"); + + THCTensor *input = (THCTensor *)luaT_checkudata(L, input_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, input)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + THCTensor *output = (THCTensor *)luaT_checkudata(L, output_index, torch_Tensor); + if (!THCTensor_(isContiguous)(state, output)) + return LUA_HANDLE_ERROR_STR(L, "tensor should be contiguous"); + + long max_length = get_max_length(L, mapped_lengths_index); + long n_rows = lua_objlen(L, mapped_lengths_index); + + long output_dim = THCTensor_(nDimension)(state, input) + 1; + THLongStorage* output_size = THLongStorage_newWithSize(output_dim); // max_length x n_rows [x ...] + output_size->data[0] = max_length; + output_size->data[1] = n_rows; + for (long i=2;i < output_dim; i++){ + output_size->data[i] = THCTensor_(size)(state, input, i-1); + } + THCTensor_(resize)(state, output, output_size, NULL); + THCTensor_(fill)(state, output, _REAL(0)); + + THCTensor *row = THCTensor_(new)(state), *section = THCTensor_(new)(state); + THCTensor *input_section = THCTensor_(new)(state); + // for each row in the input + for (long i = 0; i < n_rows; i++) { + THCTensor_(select)(state, row, output, 1, i); + lua_rawgeti(L, indexes_index, i+1); + const int local_indexes_index = lua_gettop(L); + lua_rawgeti(L, mapped_lengths_index, i+1); + const int local_mapped_lengths_index = lua_gettop(L); + + long n_entries_in_row = lua_objlen(L, -1); + long start = 0; + // for each sample placed in that row + for (long j = 0; j < n_entries_in_row; j++) { + lua_rawgeti(L, local_indexes_index, j+1); + lua_rawgeti(L, local_mapped_lengths_index, j+1); + long index = lua_tointeger(L, -2); + long length = lua_tointeger(L, -1); + lua_pop(L, 2); + + // copies the data from the input + THCTensor_(select)(state, section, row, 0, start + length-1); + THCTensor_(select)(state, input_section, input, 0, index-1); + 
THCTensor_(copy)(state, section, input_section); + start += length + 1; + } + lua_pop(L, 2); + } + THCTensor_(free)(state, row); + THCTensor_(free)(state, section); + THCTensor_(free)(state, input_section); + THLongStorage_free(output_size); + + return 0; +} + +static const struct luaL_Reg nn_(VariableLength__) [] = { + {"VariableLength_FromSamples", nn_(from_samples_to_structured)}, + {"VariableLength_ToSamples", nn_(from_structured_to_samples)}, + {"VariableLength_ToFinal", nn_(from_structured_to_final)}, + {"VariableLength_FromFinal", nn_(from_final_to_structured)}, + {NULL, NULL} +}; + +static void nn_(VariableLength_init)(lua_State *L) { + luaT_pushmetatable(L, torch_Tensor); + luaT_registeratname(L, nn_(VariableLength__), "nn"); + lua_pop(L,1); +} + +#undef _REAL +#endif diff --git a/src/rnn.c b/src/rnn.c index cc2f210..900b9b5 100644 --- a/src/rnn.c +++ b/src/rnn.c @@ -1,60 +1,9 @@ -#include "luaT.h" -#include "TH.h" +#include "rnn.h" #ifdef _OPENMP #include "omp.h" #endif -#include "error.h" -#include "utils.h" -#include - -typedef struct { - long length; - long index; - long assigned_row; -} Sample; - -static int sample_compare(const void *a_, const void *b_) { - Sample *a = (Sample*) a_; - Sample *b = (Sample*) b_; - return a->length < b->length ? 
-1 : a->length > b->length; -} - -static long get_max_length(lua_State *L, int lengths_index) { - const int current_top = lua_gettop(L); - long max_length = 0; - - lua_pushnil(L); - while (lua_next(L, lengths_index) != 0) { - const int inner_index = current_top + 2; - - lua_pushnil(L); - while (lua_next(L, inner_index) != 0) { - long length = lua_tointeger(L, -1); - if (length > max_length) - max_length = length; - lua_pop(L, 1); - } - - lua_pop(L, 1); - } - - return max_length; -} - -static long get_n_samples(lua_State *L, int lengths_index) { - long count = 0; - - lua_pushnil(L); - while (lua_next(L, lengths_index) != 0) { - count += lua_objlen(L, -1); - lua_pop(L, 1); - } - - return count; -} - #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) #define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor) #define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME) @@ -78,5 +27,9 @@ DLL_EXPORT int luaopen_librnn(lua_State *L) nn_DoubleStepLSTM_init(L); nn_DoubleStepGRU_init(L); +#if defined(USE_CUDA) + return cuda_librnn_init(L); +#else return 1; +#endif } diff --git a/src/rnn.cu b/src/rnn.cu new file mode 100644 index 0000000..c3fe574 --- /dev/null +++ b/src/rnn.cu @@ -0,0 +1,52 @@ +#include "rnn.h" +#include "THC/THC.h" + +static THCState *getCudaState(lua_State *L) +{ + lua_getglobal(L, "cutorch"); + lua_getfield(L, -1, "getState"); + lua_call(L, 0, 1); + THCState * state = (THCState *)lua_touserdata(L, -1); + lua_pop(L, 2); + return state; +} + +#define torch_(NAME) TH_CONCAT_3(torch_, CReal, NAME) +#define torch_Tensor TH_CONCAT_STRING_3(torch., CReal, Tensor) +#define nn_(NAME) TH_CONCAT_3(nn_, CReal, NAME) + +#include "generic/cuda/VariableLength.cu" +#include "THC/THCGenerateFloatTypes.h" + +#include "generic/cuda/StepLSTM.cu" +#include "THC/THCGenerateFloatTypes.h" + +#include "generic/cuda/StepGRU.cu" +#include "THC/THCGenerateFloatTypes.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +int cuda_librnn_init(lua_State *L) +{ + nn_CudaVariableLength_init(L); + 
nn_CudaStepLSTM_init(L); + nn_CudaStepGRU_init(L); + + nn_CudaDoubleVariableLength_init(L); + nn_CudaDoubleStepLSTM_init(L); + nn_CudaDoubleStepGRU_init(L); + +#ifdef CUDA_HALF_TENSOR + nn_CudaHalfVariableLength_init(L); + nn_CudaHalfStepLSTM_init(L); + nn_CudaHalfStepGRU_init(L); +#endif + + return 1; +} + +#if defined(__cplusplus) +} +#endif diff --git a/src/rnn.h b/src/rnn.h new file mode 100644 index 0000000..49f8b1c --- /dev/null +++ b/src/rnn.h @@ -0,0 +1,62 @@ +#include "luaT.h" +#include "TH.h" + +#include "error.h" +#include "utils.h" +#include + +typedef struct { + long length; + long index; + long assigned_row; +} Sample; + +static int sample_compare(const void *a_, const void *b_) { + Sample *a = (Sample*) a_; + Sample *b = (Sample*) b_; + return a->length < b->length ? -1 : a->length > b->length; +} + +static long get_max_length(lua_State *L, int lengths_index) { + const int current_top = lua_gettop(L); + long max_length = 0; + + lua_pushnil(L); + while (lua_next(L, lengths_index) != 0) { + const int inner_index = current_top + 2; + + lua_pushnil(L); + while (lua_next(L, inner_index) != 0) { + long length = lua_tointeger(L, -1); + if (length > max_length) + max_length = length; + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + return max_length; +} + +static long get_n_samples(lua_State *L, int lengths_index) { + long count = 0; + + lua_pushnil(L); + while (lua_next(L, lengths_index) != 0) { + count += lua_objlen(L, -1); + lua_pop(L, 1); + } + + return count; +} + +#if defined(USE_CUDA) +#if defined(__cplusplus) +extern "C" { +#endif +int cuda_librnn_init(lua_State *L); +#if defined(__cplusplus) +} +#endif +#endif