From d100ef327a02c4a44ed08045376930ed75fd1559 Mon Sep 17 00:00:00 2001
From: nicholas-leonard
Date: Fri, 3 Jun 2016 13:44:28 -0400
Subject: [PATCH] MultiCudaTensor

---
 MultiCudaTensor.lua | 272 ++++++++++++++++++++++++++++++++++++++++++++
 init.lua            |   3 +-
 test/test.lua       | 109 ++++++++++++++++++
 3 files changed, 383 insertions(+), 1 deletion(-)
 create mode 100644 MultiCudaTensor.lua

diff --git a/MultiCudaTensor.lua b/MultiCudaTensor.lua
new file mode 100644
index 0000000..2339d07
--- /dev/null
+++ b/MultiCudaTensor.lua
@@ -0,0 +1,272 @@
+------------------------------------------------------------------------
+--[[ MultiCudaTensor ]]--
+-- This experimental tensor is used by the NCEModule in dpnn to
+-- distribute a weight/gradWeight matrix over 2 GPUs: each shard in
+-- self.tensors lives on its own device, and the shards are virtually
+-- concatenated along dimension self.catdim.
+-- The MCT implements only the small subset of use cases that the
+-- NCEModule requires.
+------------------------------------------------------------------------
+local MCT = torch.class("torch.MultiCudaTensor")
+
+-- temporary buffers for inter-device copies; each is indexed by device id
+local buffers1, buffers2 = {}, {}
+
+function MCT:__init(catdim, tensors)
+   assert(torch.type(catdim) == 'number')
+   self.catdim = catdim
+   self.tensors = tensors or {}
+   self:size() -- caches the virtual (concatenated) size in self._size
+end
+
+function MCT:size(dim)
+   if not self._size then
+      if #self.tensors == 0 then
+         self._size = {}
+      else
+         self._size = self.tensors[1]:size():totable()
+         for i=2,#self.tensors do
+            self._size[self.catdim] = self._size[self.catdim] + self.tensors[i]:size(self.catdim)
+         end
+      end
+   end
+   if dim then
+      return self._size[dim]
+   end
+   return torch.LongStorage(self._size)
+end
+
+function MCT:zero()
+   for i,tensor in ipairs(self.tensors) do
+      cutorch.withDevice(tensor:getDevice(), function()
+         tensor:zero()
+      end)
+   end
+   return self
+end
+
+function MCT:t()
+   assert(#self._size == 2, "t() expects a 2D MultiCudaTensor")
+   return self:transpose(1,2)
+end
+
+function MCT:transpose(dim1, dim2)
+   -- the concatenation dimension follows the transposition
+   local dim = self.catdim
+   if dim1 == self.catdim then
+      dim = dim2
+   elseif dim2 == self.catdim then
+      dim = dim1
+   end
+   local tensors = {}
+   for i,tensor in ipairs(self.tensors) do
+      cutorch.withDevice(tensor:getDevice(), function()
+         tensors[i] = tensor:transpose(dim1, dim2)
+      end)
+   end
+   return self.new(dim, tensors)
+end
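+
+-- A hedged usage sketch, mirroring test/test.lua below; `inputsize`,
+-- `outputsize` and the uniform initialization are illustrative
+-- placeholders, not part of this class:
+--
+--   local tensors = {
+--      cutorch.withDevice(1, function()
+--         return torch.CudaTensor(inputsize, outputsize/2):uniform(0,1)
+--      end),
+--      cutorch.withDevice(2, function()
+--         return torch.CudaTensor(inputsize, outputsize/2):uniform(0,1)
+--      end),
+--   }
+--   local mweight = torch.MultiCudaTensor(2, tensors) -- concatenated on dim 2
+--   print(mweight:size()) -- inputsize x outputsize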
+
+-- used by NCEModule as: self.weight.index(self._weight, self.weight, 1, self.sampleidx:view(-1))
+function MCT.index(res, src, dim, indices)
+   -- we only support this one specific use case
+   assert(torch.type(res) == 'torch.CudaTensor')
+   assert(torch.type(src) == 'torch.MultiCudaTensor')
+   assert(torch.type(dim) == 'number')
+   assert(dim == 1)
+   assert(torch.type(indices) == 'torch.CudaTensor' or torch.type(indices) == 'torch.LongTensor')
+   assert(indices:dim() == 1)
+   assert(src.catdim ~= dim)
+
+   local size = src:size()
+   size[dim] = indices:size(1)
+   res:resize(size)
+
+   -- narrow res along catdim: one slice per device shard
+   local start = 1
+   for i,srctensor in ipairs(src.tensors) do
+      local device = srctensor:getDevice()
+      local res_ = res:narrow(src.catdim, start, srctensor:size(src.catdim))
+      local res__ = res_
+
+      cutorch.withDevice(device, function()
+         -- when the result slice lives on another device, index into a
+         -- local buffer and copy it back afterwards
+         if device ~= res_:getDevice() then
+            buffers2[device] = buffers2[device] or res_.new()
+            buffers2[device]:resizeAs(res_)
+            res__ = buffers2[device]
+         end
+
+         -- CudaTensor indices must also reside on the local device
+         if torch.type(indices) == 'torch.CudaTensor' and indices:getDevice() ~= device then
+            buffers1[device] = buffers1[device] or indices.new()
+            buffers1[device]:resizeAs(indices):copy(indices)
+            res__:index(srctensor, dim, buffers1[device])
+         else
+            res__:index(srctensor, dim, indices)
+         end
+      end)
+
+      if device ~= res:getDevice() then
+         res_:copy(res__)
+      end
+
+      start = start + res_:size(src.catdim)
+   end
+   return res
+end
+
+-- used by NCEModule as: self.gradWeight:indexAdd(1, sampleidx, _gradWeight)
+function MCT:indexAdd(dim, indices, src)
+   assert(torch.type(src) == 'torch.CudaTensor')
+   assert(torch.type(dim) == 'number')
+   assert(dim == 1)
+   assert(self.catdim ~= dim)
+   assert(torch.type(indices) == 'torch.CudaTensor' or torch.type(indices) == 'torch.LongTensor')
+
+   local start = 1
+   for i,tensor in ipairs(self.tensors) do
+      local device = tensor:getDevice()
+      local src_ = src:narrow(self.catdim, start, tensor:size(self.catdim))
+      local src__ = src_
+
+      cutorch.withDevice(device, function()
+         -- unlike index, the src slice must be copied to the local device
+         -- because indexAdd accumulates into existing values
+         if device ~= src:getDevice() then
+            buffers2[device] = buffers2[device] or src.new()
+            buffers2[device]:resizeAs(src_):copy(src_)
+            src__ = buffers2[device]
+         end
+
+         if torch.type(indices) == 'torch.CudaTensor' and indices:getDevice() ~= device then
+            buffers1[device] = buffers1[device] or indices.new()
+            buffers1[device]:resizeAs(indices):copy(indices)
+            tensor:indexAdd(dim, buffers1[device], src__)
+         else
+            tensor:indexAdd(dim, indices, src__)
+         end
+      end)
+
+      start = start + src_:size(self.catdim)
+   end
+
+   return self
+end
+
+function MCT:add(value, src)
+   if not src then
+      src = value
+      value = 1
+   end
+   assert(torch.type(src) == 'torch.MultiCudaTensor')
+   assert(torch.type(value) == 'number')
+
+   for i,srctensor in ipairs(src.tensors) do
+      local dstdevice = self.tensors[i]:getDevice()
+      local srcdevice = srctensor:getDevice()
+      assert(dstdevice == srcdevice)
+      cutorch.withDevice(srcdevice, function()
+         self.tensors[i]:add(value, srctensor)
+      end)
+   end
+   return self
+end
+
+-- used by NCEModule as: self.weight.addmm(self.linout, 0, self.linout, 1, input, self.weight:t())
+-- res = (v1 * M) + (v2 * mat1 * mat2)
+function MCT.addmm(res, v1, M, v2, mat1, mat2)
+   -- we only support this one specific use case
+   assert(torch.type(mat2) == 'torch.MultiCudaTensor')
+   assert(mat2.catdim == 1)
+   assert(torch.type(mat1) == 'torch.CudaTensor')
+   assert(torch.type(res) == 'torch.CudaTensor')
+   assert(torch.type(M) == 'torch.CudaTensor' and torch.pointer(M) == torch.pointer(res))
+   res:mul(v1)
+
+   local start = 1
+   for i,mat2_ in ipairs(mat2.tensors) do
+      local mat1_ = mat1:narrow(2, start, mat2_:size(1))
+      local device = mat2_:getDevice()
+
+      cutorch.withDevice(device, function()
+         if device ~= mat1_:getDevice() then
+            buffers2[device] = buffers2[device] or mat1_.new()
+            buffers2[device]:resizeAs(mat1_):copy(mat1_)
+            mat1_ = buffers2[device]
+         end
+
+         -- partial product for this shard, computed on its own device
+         buffers1[device] = buffers1[device] or res.new()
+         buffers1[device]:resizeAs(res)
+         buffers1[device]:mm(mat1_, mat2_)
+      end)
+
+      -- accumulate the partial product into res on res' own device
+      local resdevice = res:getDevice()
+      if device == resdevice then
+         res:add(v2, buffers1[device])
+      else
+         buffers1[resdevice] = buffers1[resdevice] or cutorch.withDevice(resdevice, function() return res.new() end)
+         buffers1[resdevice]:resizeAs(res):copy(buffers1[device])
+         res:add(v2, buffers1[resdevice])
+      end
+
+      start = start + mat2_:size(1)
+   end
+
+   assert(start-1 == mat2:size(1))
+   return res
+end
+
+-- supports the dpnn idiom: gradParam.new():resizeAs(gradParam):copy(gradParam)
+function MCT:resizeAs(src)
+   self.catdim = src.catdim
+   for i,tensor in ipairs(src.tensors) do
+      self.tensors[i] = self.tensors[i] or tensor.new()
+      cutorch.withDevice(tensor:getDevice(), function() self.tensors[i]:resizeAs(tensor) end)
+   end
+   return self
+end
+
+function MCT:copy(src)
+   for i,tensor in ipairs(src.tensors) do
+      self.tensors[i]:copy(tensor)
+   end
+   return self
+end
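+
+-- Serialization format: the object's metadata is written first as a plain
+-- table (with the tensors excluded), then the shard count, then each shard
+-- preceded by its device id, so that :read() can restore every shard onto
+-- the device it was saved from.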
+function MCT:write(file)
+   -- write all values in the object as a table
+   local object = {}
+   local tensors = self.tensors
+   self.tensors = nil
+   for k, v in pairs(self) do
+      object[k] = v
+   end
+
+   file:writeObject(object)
+   file:writeObject(#tensors)
+
+   for i,tensor in ipairs(tensors) do
+      file:writeObject(tensor:getDevice())
+      file:writeObject(tensor)
+   end
+
+   self.tensors = tensors
+end
+
+function MCT:read(file)
+   local object = file:readObject()
+   for k, v in pairs(object) do
+      self[k] = v
+   end
+
+   self.tensors = {}
+
+   local N = file:readObject()
+
+   for i=1,N do
+      local device = file:readObject()
+      self.tensors[i] = cutorch.withDevice(device, function() return file:readObject() end)
+   end
+end
+
+function MCT:clone()
+   -- clone via a serialization round-trip to preserve per-device placement
+   local f = torch.MemoryFile("rw"):binary()
+   f:writeObject(self)
+   f:seek(1)
+   local clone = f:readObject()
+   f:close()
+   return clone
+end
diff --git a/init.lua b/init.lua
index 3c95a54..6df056c 100644
--- a/init.lua
+++ b/init.lua
@@ -3,7 +3,7 @@ require 'paths'
 require 'sys'
 ffi = require 'ffi'
 
-torchx = {Tensor={}}
+torchx = {Tensor={}, version=1}
 
 torch.include('torchx', 'extend.lua')
 
@@ -18,6 +18,7 @@ torch.include('torchx', 'dkjson.lua')
 torch.include('torchx', 'recursivetensor.lua')
 torch.include('torchx', 'Queue.lua')
 torch.include('torchx', 'AliasMultinomial.lua')
+torch.include('torchx', 'MultiCudaTensor.lua')
 
 torch.include('torchx', 'test.lua')
 
diff --git a/test/test.lua b/test/test.lua
index 77cf54e..d9b7dcd 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -131,6 +131,115 @@ function torchxtest.AliasMultinomial()
    mytester:assertTensorEq(probs, counts, 0.001)
 end
 
+function torchxtest.MultiCudaTensor()
+   -- skip this test silently when cutorch or a second GPU is unavailable
+   if not pcall(function() require 'cutorch' end) then
+      return
+   end
+   if cutorch.getDeviceCount() < 2 then
+      return
+   end
+
+   local origdevice = cutorch.getDevice()
+
+   local inputsize, outputsize = 200, 100
+   local weight = torch.CudaTensor(inputsize, outputsize):uniform(0,1)
+   -- split the columns of weight between devices 1 and 2
+   local tensors = {
+      cutorch.withDevice(1, function() return weight[{{},{1, outputsize/2}}]:clone() end),
+      cutorch.withDevice(2, function() return weight[{{},{(outputsize/2)+1, outputsize}}]:clone() end)
+   }
+   local mweight = torch.MultiCudaTensor(2, tensors)
+   mytester:assert(mweight.catdim == 2)
+
+   -- test size
+   mytester:assertTableEq(mweight:size():totable(), {inputsize, outputsize}, 0.000001)
+   mytester:assert(mweight:size(1) == inputsize)
+   mytester:assert(mweight:size(2) == outputsize)
+
+   -- test transpose
+   local mwt = mweight:t()
+   mytester:assert(mwt.catdim == 1)
+   mytester:assertTableEq(mwt:size():totable(), {outputsize, inputsize}, 0.000001)
+
+   -- test index
+   local nindex = 3
+   local res = torch.CudaTensor()
+   local indices = torch.LongTensor(nindex):random(1,inputsize):cuda()
+   mweight.index(res, mweight, 1, indices)
+
+   local res2 = torch.CudaTensor()
+   weight.index(res2, weight, 1, indices)
+
+   mytester:assert(res:getDevice() == res2:getDevice())
+   mytester:assertTensorEq(res, res2, 0.00001)
+
+   mytester:assertTensorEq(weight[{{},{1, outputsize/2}}]:float(), mweight.tensors[1]:float(), 0.00001)
+   mytester:assertTensorEq(weight[{{},{(outputsize/2)+1, outputsize}}]:float(), mweight.tensors[2]:float(), 0.00001)
+
+   -- test indexAdd
+   local src = torch.CudaTensor(nindex, outputsize):fill(1)
+
+   weight:indexAdd(1, indices, src)
+   mweight:indexAdd(1, indices, src)
+
+   mytester:assertTensorEq(weight[{{},{1, outputsize/2}}]:float(), mweight.tensors[1]:float(), 0.00001)
+   mytester:assertTensorEq(weight[{{},{(outputsize/2)+1, outputsize}}]:float(), mweight.tensors[2]:float(), 0.00001)
+
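+   -- note: nn.Module:updateParameters does params[i]:add(-lr, gradParams[i]),
+   -- which is why MCT:add(value, src) must accept a MultiCudaTensor src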
+   -- test add (updateParameters)
+   mweight:add(1, mweight)
+   weight:add(1, weight)
+
+   mytester:assertTensorEq(weight[{{},{1, outputsize/2}}]:float(), mweight.tensors[1]:float(), 0.00001)
+   mytester:assertTensorEq(weight[{{},{(outputsize/2)+1, outputsize}}]:float(), mweight.tensors[2]:float(), 0.00001)
+
+   -- test addmm
+   local input = torch.CudaTensor(5, outputsize):uniform(0,1)
+   local output = torch.CudaTensor(5, inputsize):zero()
+   mweight.addmm(output, 0, output, 1, input, mweight:t())
+
+   local output2 = output:clone():zero()
+   weight.addmm(output2, 0, output2, 1, input, weight:t())
+
+   mytester:assertTensorEq(output, output2, 0.0001)
+
+   -- test zero
+   mweight:zero()
+   for i=1,2 do
+      cutorch.withDevice(i, function()
+         mytester:assert(mweight.tensors[i]:sum() == 0)
+      end)
+   end
+
+   -- test clone
+   local mw2 = mweight:clone()
+   mytester:assert(mw2.tensors[1]:getDevice() == mweight.tensors[1]:getDevice())
+   mytester:assert(mw2.tensors[2]:getDevice() == mweight.tensors[2]:getDevice())
+   cutorch.withDevice(1, function() mytester:assertTensorEq(mw2.tensors[1], mweight.tensors[1], 0.000001) end)
+   cutorch.withDevice(2, function() mytester:assertTensorEq(mw2.tensors[2], mweight.tensors[2], 0.000001) end)
+
+   -- test resizeAs
+   mw2.tensors[1]:resize(mw2.tensors[1]:size(1)/2, mw2.tensors[1]:size(2))
+   mw2.tensors[2]:resize(mw2.tensors[2]:size(1)/2, mw2.tensors[2]:size(2))
+
+   mw2:resizeAs(mweight)
+
+   mytester:assertTableEq(mw2.tensors[1]:size():totable(), mweight.tensors[1]:size():totable(), 0.000001)
+   mytester:assertTableEq(mw2.tensors[2]:size():totable(), mweight.tensors[2]:size():totable(), 0.000001)
+   cutorch.withDevice(1, function() mytester:assertTensorEq(mw2.tensors[1], mweight.tensors[1], 0.000001) end)
+   cutorch.withDevice(2, function() mytester:assertTensorEq(mw2.tensors[2], mweight.tensors[2], 0.000001) end)
+
+   -- test copy
+   cutorch.withDevice(1, function() mweight.tensors[1]:uniform(0,1) end)
+   cutorch.withDevice(2, function() mweight.tensors[2]:uniform(0,1) end)
+
+   mw2:copy(mweight)
+   cutorch.withDevice(1, function() mytester:assertTensorEq(mw2.tensors[1], mweight.tensors[1], 0.000001) end)
+   cutorch.withDevice(2, function() mytester:assertTensorEq(mw2.tensors[2], mweight.tensors[2], 0.000001) end)
+
+   -- the test must not change the current device
+   mytester:assert(cutorch.getDevice() == origdevice)
+end
+
 function torchx.test(tests)
    local oldtype = torch.getdefaulttensortype()
    torch.setdefaulttensortype('torch.FloatTensor')