Skip to content

Commit

Permalink
MultiCudaTensor++
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Jun 3, 2016
1 parent 5f51119 commit 164493f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
23 changes: 23 additions & 0 deletions MultiCudaTensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ function MCT:size(dim)
return torch.LongStorage(self._size)
end

function MCT:dim()
return self:size():size()
end

function MCT:zero()
for i,tensor in ipairs(self.tensors) do
cutorch.withDevice(tensor:getDevice(), function()
Expand Down Expand Up @@ -277,3 +281,22 @@ function MCT:clone()
f:close()
return clone
end

function MCT:uniform(lower, upper)
for i,tensor in ipairs(self.tensors) do
cutorch.withDevice(tensor:getDevice(), function() tensor:uniform(lower, upper) end)
end
return self
end

-- math.pow(gradParam:norm(),2)
function MCT:norm(...)
assert(#{...} == 0)
local norm = 0
for i,tensor in ipairs(self.tensors) do
norm = norm + cutorch.withDevice(tensor:getDevice(), function() return math.pow(tensor:norm(),2) end)
end
return math.sqrt(norm)
end

assert(not MCT.storage, "If you ever define storage, you will need to modify Module.sharedClone in dpnn.Module")
14 changes: 14 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ function torchxtest.MultiCudaTensor()
mytester:assert(mweight:size(1) == inputsize)
mytester:assert(mweight:size(2) == outputsize)

-- test dim
mytester:assert(mweight:dim() == 2)

-- test transpose
local mwt = mweight:t()
mytester:assert(mwt.catdim == 1)
Expand Down Expand Up @@ -210,6 +213,11 @@ function torchxtest.MultiCudaTensor()

mytester:assertTensorEq(output, output2, 0.0001)

-- test norm
local norm = mweight:norm()
local norm2 = weight:norm()
mytester:assert(math.abs(norm - norm2) < 0.0001)

-- test zero
mweight:zero()
for i=1,2 do
Expand Down Expand Up @@ -244,7 +252,13 @@ function torchxtest.MultiCudaTensor()
cutorch.withDevice(1, function() mytester:assertTensorEq(mw2.tensors[1], mweight.tensors[1], 0.000001) end)
cutorch.withDevice(2, function() mytester:assertTensorEq(mw2.tensors[2], mweight.tensors[2], 0.000001) end)

-- test uniform
mw2:uniform(-2, -1)
cutorch.withDevice(1, function() mytester:assert(mw2.tensors[1]:min() >= -2 and mw2.tensors[1]:max() <= -1) end)
cutorch.withDevice(2, function() mytester:assert(mw2.tensors[2]:min() >= -2 and mw2.tensors[2]:max() <= -1) end)

mytester:assert(cutorch.getDevice() == origdevice)

end

function torchx.test(tests)
Expand Down

0 comments on commit 164493f

Please sign in to comment.