Skip to content

Commit db4244e

Browse files
Jonathan Uesato authored and soumith committed
Make CMaxTable and CMinTable cunn-compatible (#954)
1 parent 02ebb69 commit db4244e

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

CMaxTable.lua

Lines changed: 19 additions & 6 deletions
Original file line number · Diff line number · Diff line change
function CMaxTable:__init()
   parent.__init(self)
   -- One gradInput tensor per table entry; populated lazily in updateGradInput.
   self.gradInput = {}
   -- For every output element, the index of the input tensor that supplied it.
   self.maxIdx = torch.Tensor()
   -- Reusable buffers so updateOutput/updateGradInput allocate nothing per call.
   -- self.mask holds the comparison result in the module's own tensor type
   -- (it is copied into a ByteTensor before the masked* calls).
   self.mask = torch.Tensor()
   self.maxVals = torch.Tensor()
   self.gradMaxVals = torch.Tensor()
end
811

912
function CMaxTable:updateOutput(input)
   -- Element-wise maximum over a table of same-sized tensors.
   -- Also records, in self.maxIdx, which table entry supplied each output
   -- element, so updateGradInput can route gradients back to the argmax input.
   self.output:resizeAs(input[1]):copy(input[1])
   self.maxIdx:resizeAs(input[1]):fill(1)

   -- maskedFill/maskedCopy require a ByteTensor mask (CudaByteTensor on GPU).
   -- The backend check is loop-invariant, so the lazy init is hoisted out of
   -- the loop instead of being re-evaluated on every iteration.
   self.maskByteTensor = self.maskByteTensor or
      (torch.type(self.output) == 'torch.CudaTensor' and
       torch.CudaByteTensor() or torch.ByteTensor())

   for i=2,#input do
      -- self.mask shares the module's tensor type so :gt() works on CUDA too;
      -- it is then copied into the byte mask expected by the masked* methods.
      self.mask:gt(input[i], self.output)
      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
      self.maxIdx:maskedFill(self.maskByteTensor, i)
      -- Select into a reusable buffer rather than input[i][mask], avoiding a
      -- fresh allocation each iteration.
      self.maxVals:maskedSelect(input[i], self.maskByteTensor)
      self.output:maskedCopy(self.maskByteTensor, self.maxVals)
   end
   return self.output
end
1927

2028
function CMaxTable:updateGradInput(input, gradOutput)
2129
for i=1,#input do
22-
self.gradInput[i] = input[i].new()
30+
self.gradInput[i] = self.gradInput[i] or input[i].new()
2331
self.gradInput[i]:resizeAs(input[i]):fill(0.0)
24-
local mask = torch.eq(self.maxIdx, i)
25-
self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
32+
self.maskByteTensor = self.maskByteTensor or
33+
(torch.type(self.output) == 'torch.CudaTensor' and
34+
torch.CudaByteTensor() or torch.ByteTensor())
35+
self.mask:eq(self.maxIdx, i)
36+
self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
37+
self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
38+
self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
2639
end
2740

2841
for i=#input+1, #self.gradInput do

CMinTable.lua

Lines changed: 19 additions & 6 deletions
Original file line number · Diff line number · Diff line change
function CMinTable:__init()
   parent.__init(self)
   -- One gradInput tensor per table entry; populated lazily in updateGradInput.
   self.gradInput = {}
   -- For every output element, the index of the input tensor that supplied it.
   self.minIdx = torch.Tensor()
   -- Reusable buffers so updateOutput/updateGradInput allocate nothing per call.
   -- self.mask holds the comparison result in the module's own tensor type
   -- (it is copied into a ByteTensor before the masked* calls).
   self.mask = torch.Tensor()
   self.minVals = torch.Tensor()
   -- NOTE(review): name mirrors CMaxTable (copy-paste); updateGradInput reads
   -- self.gradMaxVals, so the field is kept as-is rather than renamed.
   self.gradMaxVals = torch.Tensor()
end
811

912
function CMinTable:updateOutput(input)
   -- Element-wise minimum over a table of same-sized tensors.
   -- Also records, in self.minIdx, which table entry supplied each output
   -- element, so updateGradInput can route gradients back to the argmin input.
   self.output:resizeAs(input[1]):copy(input[1])
   self.minIdx:resizeAs(input[1]):fill(1)

   -- maskedFill/maskedCopy require a ByteTensor mask (CudaByteTensor on GPU).
   -- The backend check is loop-invariant, so the lazy init is hoisted out of
   -- the loop instead of being re-evaluated on every iteration.
   self.maskByteTensor = self.maskByteTensor or
      (torch.type(self.output) == 'torch.CudaTensor' and
       torch.CudaByteTensor() or torch.ByteTensor())

   for i=2,#input do
      -- self.mask shares the module's tensor type so :lt() works on CUDA too;
      -- it is then copied into the byte mask expected by the masked* methods.
      self.mask:lt(input[i], self.output)
      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
      self.minIdx:maskedFill(self.maskByteTensor, i)
      -- Select into a reusable buffer rather than input[i][mask], avoiding a
      -- fresh allocation each iteration.
      self.minVals:maskedSelect(input[i], self.maskByteTensor)
      self.output:maskedCopy(self.maskByteTensor, self.minVals)
   end
   return self.output
end
1927

2028
function CMinTable:updateGradInput(input, gradOutput)
2129
for i=1,#input do
22-
self.gradInput[i] = torch.Tensor()
30+
self.gradInput[i] = self.gradInput[i] or input[i].new()
2331
self.gradInput[i]:resizeAs(input[i]):fill(0.0)
24-
local mask = torch.eq(self.minIdx, i)
25-
self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
32+
self.maskByteTensor = self.maskByteTensor or
33+
(torch.type(self.output) == 'torch.CudaTensor' and
34+
torch.CudaByteTensor() or torch.ByteTensor())
35+
self.mask:eq(self.minIdx, i)
36+
self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
37+
self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
38+
self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
2639
end
2740

2841
for i=#input+1, #self.gradInput do

0 commit comments

Comments
 (0)