Skip to content

Commit

Permalink
MultiCudaTensor ++
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Jun 3, 2016
1 parent d100ef3 commit 5f51119
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
15 changes: 11 additions & 4 deletions MultiCudaTensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ local MCT = torch.class("torch.MultiCudaTensor")
local buffers1, buffers2 = {}, {}

function MCT:__init(catdim, tensors)
assert(torch.type(catdim) == 'number')
self.catdim = catdim
self.catdim = catdim or -1
self.tensors = tensors or {}
self:size()
end

function MCT:size(dim)
Expand Down Expand Up @@ -43,7 +41,7 @@ function MCT:zero()
end

function MCT:t()
assert(#self._size == 2)
assert(self:size():size() == 2)
return self:transpose(1,2)
end

Expand Down Expand Up @@ -166,6 +164,14 @@ function MCT:add(value, src)
return self
end

-- momGradParams[i]:mul(momFactor)
function MCT:mul(value)
for i,tensor in ipairs(self.tensors) do
cutorch.withDevice(tensor:getDevice(), function() tensor:mul(value) end)
end
return self
end

-- self.weight.addmm(self.linout, 0, self.linout, 1, input, self.weight:t())
-- res = (v1 * M) + (v2 * mat1 * mat2)
function MCT.addmm(res, v1, M, v2, mat1, mat2)
Expand Down Expand Up @@ -224,6 +230,7 @@ function MCT:copy(src)
for i,tensor in ipairs(src.tensors) do
self.tensors[i]:copy(tensor)
end
return self
end

function MCT:write(file)
Expand Down
7 changes: 7 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ function torchxtest.MultiCudaTensor()
mytester:assertTensorEq(weight[{{},{1, outputsize/2}}]:float(), mweight.tensors[1]:float(), 0.00001)
mytester:assertTensorEq(weight[{{},{(outputsize/2)+1, outputsize}}]:float(), mweight.tensors[2]:float(), 0.00001)

-- test mul (updateGradParameters)
mweight:mul(2)
weight:mul(2)

mytester:assertTensorEq(weight[{{},{1, outputsize/2}}]:float(), mweight.tensors[1]:float(), 0.00001)
mytester:assertTensorEq(weight[{{},{(outputsize/2)+1, outputsize}}]:float(), mweight.tensors[2]:float(), 0.00001)

-- test addmm
local input = torch.CudaTensor(5, outputsize):uniform(0,1)
local output = torch.CudaTensor(5, inputsize):zero()
Expand Down

0 comments on commit 5f51119

Please sign in to comment.