Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Leonard committed Jul 20, 2017
1 parent f1945f5 commit 3371428
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ArgMax.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ local ArgMax, parent = torch.class("nn.ArgMax", "nn.Module")
function ArgMax:__init(dim, nInputDim, asLong)
parent.__init(self)
self.dim = dim or 1
self.nInputDim = nInputDim or 9999
self.nInputDim = nInputDim or 1
self.asLong = (asLong == nil) and true or asLong
if self.asLong then
self.output = torch.LongTensor()
Expand All @@ -21,7 +21,7 @@ function ArgMax:updateOutput(input)
self._indices = self._indices or
(torch.type(input) == 'torch.CudaTensor' and (torch.CudaLongTensor and torch.CudaLongTensor() or torch.CudaTensor()) or torch.LongTensor())
local dim = (input:dim() > self.nInputDim) and (self.dim + 1) or self.dim

torch.max(self._value, self._indices, input, dim)
if input:dim() > 1 then
local idx = self._indices:select(dim, 1)
Expand Down
2 changes: 1 addition & 1 deletion Criterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ function Criterion:setZeroMask(zeroMask)
end

function Criterion:clearState()
return nn.utils.clear(self, 'gradInput')
return nn.utils.clear(self, 'gradInput')
end

0 comments on commit 3371428

Please sign in to comment.