Skip to content

Commit

Permalink
add opencl to MaskZero.lua
Browse files Browse the repository at this point in the history
  • Loading branch information
hughperkins committed Aug 19, 2016
1 parent 88cbe99 commit 11e3183
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
6 changes: 5 additions & 1 deletion MaskZero.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ function MaskZero:updateOutput(input)
local vectorDim = rmi:dim()
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
self.zeroMask = self.zeroMask or (
(torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
or torch.ByteTensor()
)
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)

-- forward through decorated module
Expand Down

0 comments on commit 11e3183

Please sign in to comment.