diff --git a/CMaxTable.lua b/CMaxTable.lua
index 62cede95f..845e38d23 100644
--- a/CMaxTable.lua
+++ b/CMaxTable.lua
@@ -4,25 +4,38 @@
 function CMaxTable:__init()
    parent.__init(self)
    self.gradInput = {}
    self.maxIdx = torch.Tensor()
+   self.mask = torch.Tensor()
+   self.maxVals = torch.Tensor()
+   self.gradMaxVals = torch.Tensor()
 end

 function CMaxTable:updateOutput(input)
    self.output:resizeAs(input[1]):copy(input[1])
    self.maxIdx:resizeAs(input[1]):fill(1)
    for i=2,#input do
-      local mask = torch.gt(input[i], self.output)
-      self.maxIdx:maskedFill(mask, i)
-      self.output:maskedCopy(mask, input[i][mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:gt(input[i], self.output)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.maxIdx:maskedFill(self.maskByteTensor, i)
+      self.maxVals:maskedSelect(input[i], self.maskByteTensor)
+      self.output:maskedCopy(self.maskByteTensor, self.maxVals)
    end
    return self.output
 end

 function CMaxTable:updateGradInput(input, gradOutput)
    for i=1,#input do
-      self.gradInput[i] = input[i].new()
+      self.gradInput[i] = self.gradInput[i] or input[i].new()
       self.gradInput[i]:resizeAs(input[i]):fill(0.0)
-      local mask = torch.eq(self.maxIdx, i)
-      self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:eq(self.maxIdx, i)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
+      self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
    end

    for i=#input+1, #self.gradInput do
diff --git a/CMinTable.lua b/CMinTable.lua
index a8385e849..25b9a19a2 100644
--- a/CMinTable.lua
+++ b/CMinTable.lua
@@ -4,25 +4,38 @@
 function CMinTable:__init()
    parent.__init(self)
    self.gradInput = {}
    self.minIdx = torch.Tensor()
+   self.mask = torch.Tensor()
+   self.minVals = torch.Tensor()
+   self.gradMaxVals = torch.Tensor()
 end

 function CMinTable:updateOutput(input)
    self.output:resizeAs(input[1]):copy(input[1])
    self.minIdx:resizeAs(input[1]):fill(1)
    for i=2,#input do
-      local mask = torch.lt(input[i], self.output)
-      self.minIdx:maskedFill(mask, i)
-      self.output:maskedCopy(mask, input[i][mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:lt(input[i], self.output)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.minIdx:maskedFill(self.maskByteTensor, i)
+      self.minVals:maskedSelect(input[i], self.maskByteTensor)
+      self.output:maskedCopy(self.maskByteTensor, self.minVals)
    end
    return self.output
 end

 function CMinTable:updateGradInput(input, gradOutput)
    for i=1,#input do
-      self.gradInput[i] = torch.Tensor()
+      self.gradInput[i] = self.gradInput[i] or input[i].new()
       self.gradInput[i]:resizeAs(input[i]):fill(0.0)
-      local mask = torch.eq(self.minIdx, i)
-      self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:eq(self.minIdx, i)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
+      self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
    end

    for i=#input+1, #self.gradInput do
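The patch replaces the per-call `torch.gt`/`torch.lt`/`tensor[mask]` allocations with reusable buffers, and copies the comparison result into a `ByteTensor` (or `CudaByteTensor` on GPU), since `maskedFill`/`maskedCopy` expect a byte mask. Below is a minimal usage sketch, not part of the patch, showing the call pattern the modules support; it assumes a working torch/nn install, and the CUDA branch would be exercised by additionally requiring `cunn` and converting the module and inputs with `:cuda()`:

```lua
require 'nn'

-- Elementwise max over a table of same-sized tensors.
local m = nn.CMaxTable()
local a = torch.Tensor{1, 5, 2}
local b = torch.Tensor{4, 3, 6}

local out = m:forward({a, b})
print(out)  -- 4, 5, 6: the larger value at each position

-- The gradient is routed only to whichever input held the max,
-- via the maxIdx/mask bookkeeping in the patch above.
local grads = m:backward({a, b}, torch.Tensor{1, 1, 1})
print(grads[1])  -- 0, 1, 0
print(grads[2])  -- 1, 0, 1
```

`nn.CMinTable` follows the same pattern with the comparison reversed.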