diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..567609b --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/MaskZero.lua b/MaskZero.lua index d2ea82b..bb4d1e5 100644 --- a/MaskZero.lua +++ b/MaskZero.lua @@ -68,7 +68,11 @@ function MaskZero:updateOutput(input) local vectorDim = rmi:dim() self._zeroMask = self._zeroMask or rmi.new() self._zeroMask:norm(rmi, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) + self.zeroMask = self.zeroMask or ( + (torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() + or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor() + or torch.ByteTensor() + ) self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) -- forward through decorated module