Merge remote-tracking branch 'dpnn/master' into merge-dpnn
Nicholas Leonard committed Apr 20, 2017
2 parents 2bb5129 + ca0e99f commit a155e43
Showing 68 changed files with 9,730 additions and 59 deletions.
57 changes: 57 additions & 0 deletions ArgMax.lua
@@ -0,0 +1,57 @@
------------------------------------------------------------------------
--[[ ArgMax ]]--
-- Returns the index of the maxima for dimension dim.
-- Cannot backpropagate through this module.
-- Created for use with ReinforceCategorical.
------------------------------------------------------------------------
local ArgMax, parent = torch.class("nn.ArgMax", "nn.Module")

function ArgMax:__init(dim, nInputDim, asLong)
   parent.__init(self)
   self.dim = dim or 1
   self.nInputDim = nInputDim or 9999
   self.asLong = (asLong == nil) and true or asLong
   if self.asLong then
      self.output = torch.LongTensor()
   end
end

function ArgMax:updateOutput(input)
   self._value = self._value or input.new()
   self._indices = self._indices
      or (torch.type(input) == 'torch.CudaTensor'
         and (torch.CudaLongTensor and torch.CudaLongTensor() or torch.CudaTensor())
         or torch.LongTensor())
   local dim = (input:dim() > self.nInputDim) and (self.dim + 1) or self.dim

   torch.max(self._value, self._indices, input, dim)
   if input:dim() > 1 then
      local idx = self._indices:select(dim, 1)
      self.output:resize(idx:size()):copy(idx)
   else
      self.output:resize(self._indices:size()):copy(self._indices)
   end
   return self.output
end

function ArgMax:updateGradInput(input, gradOutput)
   -- cannot backprop from an index so just return a dummy zero tensor
   self.gradInput:resizeAs(input):zero()
   return self.gradInput
end

function ArgMax:type(type)
   -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
   if type == 'torch.CudaTensor' then
      parent.type(self, type)
   else
      -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
      -- unnecessary memory allocations.
      local indices
      indices, self._indices = self._indices, nil
      parent.type(self, type)
      self._indices = indices and indices:long() or nil
   end
   if self.asLong then
      self.output = torch.LongTensor()
   end
   return self
end
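
For reference, a minimal usage sketch of this module; the input shape and values below are illustrative:

require 'rnn'

-- argmax over dimension 2 of a batch of rows (nInputDim = 2)
local m = nn.ArgMax(2, 2)
local input = torch.Tensor{{0.1, 0.9, 0.0},
                           {0.7, 0.2, 0.1}}
print(m:forward(input)) -- LongTensor of size 2: indices 2 and 1
-- backward returns a zero tensor, since argmax is not differentiable
print(m:backward(input, torch.Tensor(2):zero()))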
18 changes: 18 additions & 0 deletions BatchNormalization.lua
@@ -0,0 +1,18 @@
local _ = require 'moses'
local BN, parent = nn.BatchNormalization, nn.Module

-- buffers that can be emptied before serialization (see nn.Serial's medium mode)
local empty = _.clone(parent.dpnn_mediumEmpty)
table.insert(empty, 'buffer')
table.insert(empty, 'buffer2')
table.insert(empty, 'centered')
table.insert(empty, 'std')
table.insert(empty, 'normalized')
table.insert(empty, 'output')
table.insert(empty, 'gradInput')
BN.dpnn_mediumEmpty = empty

-- for sharedClone
local params = _.clone(parent.dpnn_parameters)
table.insert(params, 'running_mean')
table.insert(params, 'running_var')
BN.dpnn_parameters = params
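
These two lists steer the dpnn extensions to nn.Module. A hedged sketch of what they affect, assuming the dpnn-style nn.Serial decorator and nn.Module:sharedClone included in this merge:

require 'rnn'

local bn = nn.BatchNormalization(16)

-- sharedClone consults dpnn_parameters, so running_mean and running_var
-- are shared with the clone instead of being deep-copied
local clone = bn:sharedClone()

-- medium serialization consults dpnn_mediumEmpty, so the listed buffers
-- are emptied before the module is written to disk
local serial = nn.Serial(bn):mediumSerial()
torch.save('bn.t7', serial)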
82 changes: 82 additions & 0 deletions BinaryClassReward.lua
@@ -0,0 +1,82 @@
------------------------------------------------------------------------
--[[ BinaryClassReward ]]--
-- Variance reduced binary classification reinforcement criterion.
-- The binary class version of VRClassReward.
-- input : {class prediction, baseline reward}
-- Reward is 1 for success, Reward is 0 otherwise.
-- reward = scale*(Reward - baseline) where baseline is 2nd input element
-- Note : for RNNs with R = 1 for last step in sequence, encapsulate it
-- in nn.ModuleCriterion(BinaryClassReward, nn.SelectTable(-1))
------------------------------------------------------------------------
local BinaryClassReward, parent = torch.class("nn.BinaryClassReward", "nn.Criterion")

function BinaryClassReward:__init(module, scale, criterion)
   parent.__init(self)
   self.module = module -- so it can call module:reinforce(reward)
   self.scale = scale or 1 -- scale of reward
   self.criterion = criterion or nn.MSECriterion() -- baseline criterion
   self.sizeAverage = true
   self.gradInput = {torch.Tensor()}
end

function BinaryClassReward:updateOutput(input, target)
   assert(torch.type(input) == 'table')
   local input = input[1]
   assert(input:dim() == 1)
   assert(target:dim() == 1)

   -- binary class prediction : threshold the input at 0.5
   self._binary = self._binary or input.new()
   self._binary:gt(input, 0.5)

   -- cast target to the same type as the prediction, if necessary
   if torch.type(self._binary) ~= torch.type(target) then
      self._target = self._target or self._binary.new()
      self._target:resize(target:size()):copy(target)
      target = self._target
   end

   -- reward = scale when correctly classified
   self._reward = self._reward or input.new()
   self._reward:eq(self._binary, target)
   self.reward = self.reward or input.new()
   self.reward:resize(self._reward:size(1)):copy(self._reward)
   self.reward:mul(self.scale)

   -- loss = -sum(reward)
   self.output = -self.reward:sum()
   if self.sizeAverage then
      self.output = self.output/input:size(1)
   end
   return self.output
end

function BinaryClassReward:updateGradInput(inputTable, target)
   local input, baseline = unpack(inputTable)

   -- reduce variance of reward using baseline
   self.vrReward = self.vrReward or self.reward.new()
   self.vrReward:resizeAs(self.reward):copy(self.reward)
   self.vrReward:add(-1, baseline)
   if self.sizeAverage then
      self.vrReward:div(input:size(1))
   end
   -- broadcast reward to modules
   self.module:reinforce(self.vrReward)

   -- zero gradInput (this criterion has no gradInput for the class prediction)
   self.gradInput[1]:resizeAs(input):zero()

   -- learn the baseline reward
   self.gradInput[2] = self.criterion:backward(baseline, self.reward)

   return self.gradInput
end

function BinaryClassReward:type(type)
   self._binary = nil
   self._target = nil
   local module = self.module
   self.module = nil
   local ret = parent.type(self, type)
   self.module = module
   return ret
end
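
A hedged end-to-end sketch of the intended wiring. The policy architecture, the constant stand-in baseline, and all sizes below are illustrative (nn.Squeeze comes from the base nn package, and in practice the baseline comes from a learned module trained through gradInput[2]):

require 'rnn'

-- stochastic binary policy whose :reinforce(reward) the criterion calls
local policy = nn.Sequential()
   :add(nn.Linear(10, 1))
   :add(nn.Sigmoid())
   :add(nn.ReinforceBernoulli())
   :add(nn.Squeeze()) -- the criterion expects a 1D prediction

local crit = nn.BinaryClassReward(policy, 1)

local x = torch.rand(4, 10)
local target = torch.Tensor(4):bernoulli()
local pred = policy:forward(x) -- 1D tensor of size 4
local baseline = torch.Tensor(4):fill(0.5) -- stand-in baseline predictions
local loss = crit:forward({pred, baseline}, target)
local gradInput = crit:backward({pred, baseline}, target)
policy:backward(x, gradInput[1]) -- REINFORCE gradient flows through the policy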
91 changes: 91 additions & 0 deletions BinaryLogisticRegression.lua
@@ -0,0 +1,91 @@
------------------------------------------------------------------------
--[[ BinaryLogisticRegression ]]--
-- Takes an input of size batchSize x 1 or just batchSize.
-- Computes the binary logistic regression cost
-- loss = log(1 + exp(-target*input)), where targets are -1 or +1.
-- Useful for two-class classification.
------------------------------------------------------------------------

local BinaryLogisticRegression, parent = torch.class('nn.BinaryLogisticRegression', 'nn.Criterion')

function BinaryLogisticRegression:__init(sizeAverage)
   parent.__init(self)
   if sizeAverage ~= nil then
      self.sizeAverage = sizeAverage
   else
      self.sizeAverage = true
   end
end

function BinaryLogisticRegression:updateOutput(input, target)
   local inputDim = input:nDimension()
   local targetDim = target:nDimension()

   -- check dimensions of input and target
   assert(inputDim == 1 or inputDim == 2,
      "Input: expecting batchSize or batchSize x 1")
   assert(targetDim == 1 or targetDim == 2,
      "Target: expecting batchSize or batchSize x 1")
   if inputDim == 2 then
      assert(input:size(1)==1 or input:size(2)==1,
         "Input: expecting batchSize x 1.")
   end
   if targetDim == 2 then
      assert(target:size(1)==1 or target:size(2)==1,
         "Target: expecting batchSize x 1.")
   end

   local inputElements = input:nElement()
   local targetElements = target:nElement()

   assert(inputElements == targetElements,
      "Number of input and target elements should be the same.")

   self._k = inputElements
   local input = input:view(-1)
   local target = target:view(-1)

   self._baseExponents = self._baseExponents or input.new()
   self._coeff = self._coeff or input.new()
   self._logCoeff = self._logCoeff or input.new()

   -- compute exponent = -target*input
   self._baseExponents:resize(input:size()):copy(input)
   self._baseExponents:cmul(target)
   self._baseExponents:mul(-1)
   -- compute exp(exponent)
   self._baseExponents:exp()

   -- coeff = 1 + exp(-target*input)
   self._coeff:resize(input:size()):copy(self._baseExponents)
   self._coeff:add(1)

   -- logCoeff = log(1 + exp(-target*input))
   self._logCoeff:resize(input:size()):copy(self._coeff)
   self._logCoeff:log()

   if self.sizeAverage then
      self.output = self._logCoeff:sum()/(self._k)
   else
      self.output = self._logCoeff:sum()
   end
   return self.output
end

function BinaryLogisticRegression:updateGradInput(input, target)
   self.gradInput = self.gradInput or input.new()
   local gradInput = self.gradInput
   -- d/dx log(1+exp(-t*x)) = -t*exp(-t*x)/(1+exp(-t*x))
   gradInput:resize(input:size()):copy(target)
   gradInput:mul(-1)
   gradInput:cmul(self._baseExponents)
   gradInput:cdiv(self._coeff)
   if self.sizeAverage then
      gradInput:div(self._k)
   end
   return gradInput
end

function BinaryLogisticRegression:type(type, tensorCache)
   if type then
      self._baseExponents = nil
      self._coeff = nil
      self._logCoeff = nil
   end
   return parent.type(self, type, tensorCache)
end
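
A minimal sketch of the cost on raw scores with -1/+1 targets; the values below are illustrative:

require 'rnn'

local crit = nn.BinaryLogisticRegression()
local input = torch.Tensor{2.0, -1.0, 0.5} -- raw scores
local target = torch.Tensor{1, -1, -1}
-- mean of log(1+exp(-2.0)), log(1+exp(-1.0)) and log(1+exp(0.5))
local loss = crit:forward(input, target)
local grad = crit:backward(input, target)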
43 changes: 43 additions & 0 deletions CAddTensorTable.lua
@@ -0,0 +1,43 @@

local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module')

function CAddTensorTable:__init()
   parent.__init(self)
   self.output = {}
   self.gradInput = {}
end

-- The input is a table with two entries: input[1] is the vector to be added
-- and input[2] is the table of tensors to which the vector is added.
function CAddTensorTable:updateOutput(input)
   local currentOutput = self.output
   for i = 1, #input[2] do
      -- reuse output buffers across calls instead of reallocating them
      currentOutput[i] = currentOutput[i] or input[1].new()
      currentOutput[i]:resizeAs(input[1])
      currentOutput[i]:copy(input[2][i])
      currentOutput[i]:add(input[1])
   end
   -- remove stale entries when the input table shrinks
   for i = #input[2]+1, #currentOutput do
      currentOutput[i] = nil
   end
   return self.output
end

function CAddTensorTable:updateGradInput(input, gradOutput)
   -- gradient w.r.t. the vector is the sum of all output gradients
   self.gradInput[1] = self.gradInput[1] or input[1].new()
   self.gradInput[1]:resizeAs(input[1])
   self.gradInput[1]:copy(gradOutput[1])
   for i = 2, #input[2] do
      self.gradInput[1]:add(gradOutput[i])
   end
   -- gradient w.r.t. each table entry is the corresponding output gradient
   self.gradInput[2] = self.gradInput[2] or {}
   for i = 1, #input[2] do
      self.gradInput[2][i] = self.gradInput[2][i] or input[1].new()
      self.gradInput[2][i]:resizeAs(input[1])
      self.gradInput[2][i]:copy(gradOutput[i])
   end
   for i = #input[2]+1, #self.gradInput[2] do
      self.gradInput[2][i] = nil
   end
   return self.gradInput
end
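
A minimal usage sketch; sizes and values below are illustrative:

require 'rnn'

local mod = nn.CAddTensorTable()
local vector = torch.Tensor{1, 2, 3}
local tensors = {torch.zeros(3), torch.ones(3)}
local out = mod:forward{vector, tensors}
-- out[1] = {1, 2, 3}, out[2] = {2, 3, 4}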
54 changes: 53 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -56,6 +56,58 @@ SET(luasrc
utils.lua
LinearRNN.lua
LookupRNN.lua
ArgMax.lua
BatchNormalization.lua
BinaryClassReward.lua
BinaryLogisticRegression.lua
CAddTensorTable.lua
CategoricalEntropy.lua
Clip.lua
Collapse.lua
Constant.lua
Container.lua
Convert.lua
Criterion.lua
Decorator.lua
Dictionary.lua
DontCast.lua
FireModule.lua
Inception.lua
Kmeans.lua
LookupTable.lua
ModuleCriterion.lua
NCECriterion.lua
NCEModule.lua
NaN.lua
OneHot.lua
PCAColorTransform.lua
ParallelTable.lua
PrintSize.lua
Profile.lua
Reinforce.lua
ReinforceBernoulli.lua
ReinforceCategorical.lua
ReinforceGamma.lua
ReinforceNormal.lua
ReverseTable.lua
Sequential.lua
Serial.lua
SimpleColorTransform.lua
SpatialBatchNormalization.lua
SpatialBinaryConvolution.lua
SpatialBinaryLogisticRegression.lua
SpatialConvolution.lua
SpatialConvolutionMM.lua
SpatialFeatNormalization.lua
SpatialGlimpse.lua
SpatialMaxPooling.lua
SpatialRegionDropout.lua
SpatialUniformCrop.lua
TotalDropout.lua
VRClassReward.lua
WhiteNoise.lua
ZipTable.lua
ZipTableOneToMany.lua
)

ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
@@ -64,4 +116,4 @@ TARGET_LINK_LIBRARIES(rnn luaT TH)

SET_TARGET_PROPERTIES(rnn_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH")

INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/rnn")