From 11e31832dedf53bc9256e6e3004a1d5d5a7bf6c4 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 19 Aug 2016 10:51:42 +0100 Subject: [PATCH] add opencl to MaskZero.lua --- .gitignore | 1 + MaskZero.lua | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..567609b --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/MaskZero.lua b/MaskZero.lua index d2ea82b..bb4d1e5 100644 --- a/MaskZero.lua +++ b/MaskZero.lua @@ -68,7 +68,11 @@ function MaskZero:updateOutput(input) local vectorDim = rmi:dim() self._zeroMask = self._zeroMask or rmi.new() self._zeroMask:norm(rmi, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) + self.zeroMask = self.zeroMask or ( + (torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() + or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor() + or torch.ByteTensor() + ) self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) -- forward through decorated module