DontCast supports castin and castout args
nicholas-leonard committed Jul 16, 2015
1 parent 401e338 commit 8d7e6c0
Showing 2 changed files with 116 additions and 0 deletions.
76 changes: 76 additions & 0 deletions DontCast.lua
@@ -1,6 +1,82 @@
local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator")

function DontCast:__init(module, castin, castout, moduleType)
   parent.__init(self, module)
   self.castin = castin
   -- castout defaults to castin when not explicitly provided
   self.castout = (castout == nil) and castin or castout
   self.moduleType = moduleType
   if not self.moduleType then
      assert(torch.isTensor(module.output), "cannot extrapolate module type")
      self.moduleType = torch.typename(module.output)
   end
end

function DontCast:updateOutput(input)
   if self.castin and torch.type(input) ~= self.moduleType then
      self._input = self._input or torch.getmetatable(self.moduleType).new()
      self._input:resize(input:size()):copy(input)
      input = self._input
   end

   local output = self.module:updateOutput(input)

   if self.castout then
      self.output:resize(output:size()):copy(output)
   else
      self.output = output
   end
   return self.output
end

function DontCast:updateGradInput(input, gradOutput)
   if self.castin and torch.type(input) ~= self.moduleType then
      input = self._input
   end
   if self.castout and torch.type(gradOutput) ~= self.moduleType then
      self._gradOutput = self._gradOutput or torch.getmetatable(self.moduleType).new()
      self._gradOutput:resize(gradOutput:size()):copy(gradOutput)
      gradOutput = self._gradOutput
   end

   local gradInput = self.module:updateGradInput(input, gradOutput)

   if self.castin then
      self.gradInput:resize(gradInput:size()):copy(gradInput)
   else
      self.gradInput = gradInput
   end
   return self.gradInput
end

function DontCast:accGradParameters(input, gradOutput, scale)
   if self.castin and torch.type(input) ~= self.moduleType then
      input = self._input
   end
   if self.castout and torch.type(gradOutput) ~= self.moduleType then
      gradOutput = self._gradOutput
   end

   self.module:accGradParameters(input, gradOutput, scale)
end

function DontCast:accUpdateGradParameters(input, gradOutput, lr)
   if self.castin and torch.type(input) ~= self.moduleType then
      input = self._input
   end
   if self.castout and torch.type(gradOutput) ~= self.moduleType then
      gradOutput = self._gradOutput
   end

   self.module:accUpdateGradParameters(input, gradOutput, lr)
end

-- dont cast the wrapped module; only this decorator's own buffers change type
function DontCast:type(type)
   if self.castout then
      self.output = self.output:type(type)
   end
   if self.castin then
      self.gradInput = self.gradInput:type(type)
   end
   return self
end
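
For reference, a minimal usage sketch of the new castin/castout arguments (not part of this commit; the layer sizes and input tensors below are illustrative assumptions):

require 'dpnn'

-- wrap a float module so the surrounding network can keep feeding doubles
local linear = nn.Linear(4, 2):float()

-- castin=true, castout=true: the double input is cast to the wrapped
-- module's type before forward, and the float output is cast back to double
local both = nn.DontCast(linear, true, true)
local out = both:forward(torch.randn(3, 4))     -- torch.DoubleTensor

-- castin=true, castout=false: the input is still cast in, but the output
-- keeps the wrapped module's type (castout defaults to castin when omitted)
local inOnly = nn.DontCast(linear, true, false)
local out2 = inOnly:forward(torch.randn(3, 4))  -- torch.FloatTensor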
40 changes: 40 additions & 0 deletions test/test.lua
@@ -319,6 +319,46 @@ function dpnntest.SpatialUniformCrop()
end
end

function dpnntest.DontCast()
   local input = torch.randn(3,4)
   local gradOutput = torch.randn(3,2)
   local linear = nn.Linear(4,2):float()
   local mlp = nn.DontCast(linear, true)
   linear:zeroGradParameters()
   local linear = linear:clone()
   local output = mlp:forward(input)
   local gradInput = mlp:backward(input, gradOutput)
   mytester:assert(torch.type(output) == 'torch.DoubleTensor')
   mytester:assert(torch.type(gradInput) == 'torch.DoubleTensor')
   local output2 = linear:forward(input:float())
   local gradInput2 = linear:backward(input:float(), gradOutput:float())
   mytester:assertTensorEq(output:float(), output2, 0.000001)
   mytester:assertTensorEq(gradInput:float(), gradInput2, 0.000001)
   local mlp3 = nn.DontCast(linear:clone())
   mlp3:zeroGradParameters()
   local output3 = mlp3:forward(input:float())
   local gradInput3 = mlp3:backward(input:float(), gradOutput:float())
   mytester:assert(torch.type(output3) == 'torch.FloatTensor')
   mytester:assert(torch.type(gradInput3) == 'torch.FloatTensor')
   mytester:assertTensorEq(output3, output2, 0.000001)
   mytester:assertTensorEq(gradInput3, gradInput2, 0.000001)
   mlp:float()
   local output4 = mlp:forward(input:float())
   local gradInput4 = mlp:backward(input:float(), gradOutput:float())
   mytester:assert(torch.type(output4) == 'torch.FloatTensor')
   mytester:assert(torch.type(gradInput4) == 'torch.FloatTensor')
   mytester:assertTensorEq(output3, output4, 0.000001)
   mytester:assertTensorEq(gradInput3, gradInput4, 0.000001)
   mlp:double()
   mytester:assert(torch.type(linear.output) == 'torch.FloatTensor')
   local output = mlp:forward(input)
   local gradInput = mlp:backward(input, gradOutput)
   mytester:assert(torch.type(output4) == 'torch.FloatTensor')
   mytester:assert(torch.type(gradInput4) == 'torch.FloatTensor')
   mytester:assertTensorEq(output3, output:float(), 0.000001)
   mytester:assertTensorEq(gradInput3, gradInput:float(), 0.000001)
end
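
The new test can be run through the package's usual test harness, for example (assuming dpnn.test accepts a list of test names, as the other dpnn tests do):

-- run only the DontCast test from the Torch REPL or th -e
dpnn = require 'dpnn'
dpnn.test{'DontCast'}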

function dpnntest.ModuleCriterion()
local input = torch.randn(8,4)
local target = torch.randn(8,4)
