MultiCudaTensor
nicholas-leonard committed Jun 3, 2016
1 parent 7eeb6ae commit d100ef3
Showing 3 changed files with 383 additions and 1 deletion.
272 changes: 272 additions & 0 deletions MultiCudaTensor.lua
@@ -0,0 +1,272 @@
------------------------------------------------------------------------
--[[ MultiCudaTensor ]]--
-- This experimental tensor is used by the NCEModule in dpnn to
-- distribute the weight/gradWeight over two GPUs.
-- The MCT only implements the small subset of use-cases that the
-- NCEModule requires.
------------------------------------------------------------------------
local MCT = torch.class("torch.MultiCudaTensor")

-- each buffer is indexed by device
local buffers1, buffers2 = {}, {}

function MCT:__init(catdim, tensors)
   assert(torch.type(catdim) == 'number')
   self.catdim = catdim
   self.tensors = tensors or {}
   self:size()
end
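
-- Illustrative usage (not part of the original file; assumes at least two GPUs
-- and hypothetical sizes). Two 100x50 CudaTensors concatenated along dim 2
-- behave like a single 100x100 weight matrix split across devices:
--
--   local t1 = cutorch.withDevice(1, function() return torch.CudaTensor(100, 50) end)
--   local t2 = cutorch.withDevice(2, function() return torch.CudaTensor(100, 50) end)
--   local weight = torch.MultiCudaTensor(2, {t1, t2})
--   print(weight:size()) -- LongStorage{100, 100}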

function MCT:size(dim)
   if not self._size then
      if #self.tensors == 0 then
         self._size = {}
      else
         self._size = self.tensors[1]:size():totable()
         for i=2,#self.tensors do
            self._size[self.catdim] = self._size[self.catdim] + self.tensors[i]:size(self.catdim)
         end
      end
   end
   if dim then
      return self._size[dim]
   end
   return torch.LongStorage(self._size)
end

function MCT:zero()
   for i,tensor in ipairs(self.tensors) do
      cutorch.withDevice(tensor:getDevice(), function()
         tensor:zero()
      end)
   end
   return self
end

function MCT:t()
   assert(#self._size == 2)
   return self:transpose(1,2)
end

function MCT:transpose(dim1, dim2)
   local dim = self.catdim
   if dim1 == self.catdim then
      dim = dim2
   elseif dim2 == self.catdim then
      dim = dim1
   end
   local tensors = {}
   for i,tensor in ipairs(self.tensors) do
      cutorch.withDevice(tensor:getDevice(), function()
         tensors[i] = tensor:transpose(dim1, dim2)
      end)
   end
   local result = self.new(dim, tensors)
   return result
end

-- self.weight.index(self._weight, self.weight, 1, self.sampleidx:view(-1))
function MCT.index(res, src, dim, indices)
   -- we only support a specific use-case
   assert(torch.type(res) == 'torch.CudaTensor')
   assert(torch.type(src) == 'torch.MultiCudaTensor')
   assert(torch.type(dim) == 'number')
   assert(dim == 1)
   assert(torch.type(indices) == 'torch.CudaTensor' or torch.type(indices) == 'torch.LongTensor')
   assert(indices:dim() == 1)
   assert(src.catdim ~= dim)

   local size = src:size()
   size[dim] = indices:size(1)
   res:resize(size)

   -- index each constituent tensor on its own device, writing the result into
   -- the corresponding narrow of res (via a device-local buffer when necessary)
   local start = 1
   for i,srctensor in ipairs(src.tensors) do
      local device = srctensor:getDevice()
      local res_ = res:narrow(src.catdim, start, srctensor:size(src.catdim))
      local res__ = res_

      cutorch.withDevice(device, function()
         if device ~= res_:getDevice() then
            buffers2[device] = buffers2[device] or res_.new()
            buffers2[device]:resizeAs(res_):copy(res_)
            res__ = buffers2[device]
         end

         if torch.type(indices) == 'torch.CudaTensor' and indices:getDevice() ~= device then
            buffers1[device] = buffers1[device] or indices.new()
            buffers1[device]:resizeAs(indices):copy(indices)
            res__:index(srctensor, dim, buffers1[device])
         else
            res__:index(srctensor, dim, indices)
         end
      end)

      if device ~= res:getDevice() then
         res_:copy(res__)
      end

      start = start + res_:size(src.catdim)
   end
   return res
end
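
-- Illustrative call (not part of the original; names and sizes are hypothetical).
-- Given `weight` from the constructor example above (100x100, split along dim 2)
-- and a 1D tensor of row indices, this gathers the selected rows into a plain
-- CudaTensor, each device contributing its own columns:
--
--   local sampleidx = torch.CudaTensor{1, 3, 7}
--   local rows = torch.CudaTensor()
--   weight.index(rows, weight, 1, sampleidx) -- rows is 3 x 100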

-- self.gradWeight:indexAdd(1, sampleidx, _gradWeight)
function MCT:indexAdd(dim, indices, src)
   assert(torch.type(src) == 'torch.CudaTensor')
   assert(torch.type(dim) == 'number')
   assert(dim == 1)
   assert(self.catdim ~= dim)
   assert(torch.type(indices) == 'torch.CudaTensor' or torch.type(indices) == 'torch.LongTensor')

   -- accumulate the corresponding narrow of src into each constituent tensor
   -- on its own device (via a device-local buffer when necessary)
   local start = 1
   for i,tensor in ipairs(self.tensors) do
      local device = tensor:getDevice()
      local src_ = src:narrow(self.catdim, start, tensor:size(self.catdim))
      local src__ = src_

      cutorch.withDevice(device, function()
         if device ~= src:getDevice() then
            buffers2[device] = buffers2[device] or src.new()
            buffers2[device]:resizeAs(src_):copy(src_)
            src__ = buffers2[device]
         end

         if torch.type(indices) == 'torch.CudaTensor' and indices:getDevice() ~= device then
            buffers1[device] = buffers1[device] or indices.new()
            buffers1[device]:resizeAs(indices):copy(indices)
            tensor:indexAdd(dim, buffers1[device], src__)
         else
            tensor:indexAdd(dim, indices, src__)
         end
      end)

      start = start + src_:size(self.catdim)
   end

   return self
end
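
-- Illustrative call (not part of the original; names and sizes are hypothetical).
-- `gradWeight` is assumed to be a 100x100 MultiCudaTensor built like `weight`
-- in the constructor example above. This scatter-adds a 3x100 gradient into
-- rows 1, 3 and 7, each half of the columns accumulating on its own GPU:
--
--   local grad = torch.CudaTensor(3, 100):fill(0.1)
--   gradWeight:indexAdd(1, torch.CudaTensor{1, 3, 7}, grad)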

function MCT:add(value, src)
   if not src then
      src = value
      value = 1
   end
   assert(torch.type(src) == 'torch.MultiCudaTensor')
   assert(torch.type(value) == 'number')

   for i,srctensor in ipairs(src.tensors) do
      local dstdevice = self.tensors[i]:getDevice()
      local srcdevice = srctensor:getDevice()
      assert(dstdevice == srcdevice)
      cutorch.withDevice(srcdevice, function()
         self.tensors[i]:add(value, srctensor)
      end)
   end
   return self
end

-- self.weight.addmm(self.linout, 0, self.linout, 1, input, self.weight:t())
-- res = (v1 * M) + (v2 * mat1 * mat2)
function MCT.addmm(res, v1, M, v2, mat1, mat2)
   -- we only support a specific use-case
   assert(mat2.catdim == 1)
   assert(torch.type(mat2) == 'torch.MultiCudaTensor')
   assert(torch.type(mat1) == 'torch.CudaTensor')
   assert(torch.type(M) == 'torch.CudaTensor' and torch.pointer(M) == torch.pointer(res))
   assert(torch.type(res) == 'torch.CudaTensor')
   res:mul(v1)

   -- compute one partial mat1_ * mat2_ product per device, then accumulate the
   -- partial products into res (copying across devices when necessary)
   local start = 1
   local lastres = res
   for i,mat2_ in ipairs(mat2.tensors) do
      local mat1_ = mat1:narrow(2, start, mat2_:size(1))
      local device = mat2_:getDevice()

      cutorch.withDevice(device, function()
         if device ~= mat1_:getDevice() then
            buffers2[device] = buffers2[device] or mat1_.new()
            buffers2[device]:resizeAs(mat1_):copy(mat1_)
            mat1_ = buffers2[device]
         end

         buffers1[device] = buffers1[device] or lastres.new()
         buffers1[device]:resizeAs(res)
         buffers1[device]:mm(mat1_, mat2_)
      end)

      local resdevice = res:getDevice()
      if device == resdevice then
         res:add(v2, buffers1[device])
      else
         -- buffers1[resdevice] may not have been allocated yet
         buffers1[resdevice] = buffers1[resdevice] or res.new()
         buffers1[resdevice]:resizeAs(res):copy(buffers1[device])
         res:add(v2, buffers1[resdevice])
      end

      start = start + mat2_:size(1)
   end

   assert(start-1 == mat2:size(1))
   return res
end
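
-- Illustrative call (not part of the original; names, devices and sizes are
-- hypothetical). An 80x100 weight split 80x50 per GPU along dim 2 is transposed
-- (so the split falls along dim 1) and used to project a batch of 8 inputs:
--
--   local w1 = cutorch.withDevice(1, function() return torch.CudaTensor(80, 50) end)
--   local w2 = cutorch.withDevice(2, function() return torch.CudaTensor(80, 50) end)
--   local weight = torch.MultiCudaTensor(2, {w1, w2}) -- 80 x 100
--   local input  = torch.CudaTensor(8, 100):uniform()
--   local linout = torch.CudaTensor(8, 80)
--   weight.addmm(linout, 0, linout, 1, input, weight:t()) -- linout = input * weight:t()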

-- gradParam.new():resizeAs(gradParam):copy(gradParam)
function MCT:resizeAs(src)
   self.catdim = src.catdim
   for i,tensor in ipairs(src.tensors) do
      self.tensors[i] = self.tensors[i] or tensor.new()
      cutorch.withDevice(tensor:getDevice(), function() self.tensors[i]:resizeAs(tensor) end)
   end
   return self
end

function MCT:copy(src)
   for i,tensor in ipairs(src.tensors) do
      self.tensors[i]:copy(tensor)
   end
   return self
end

function MCT:write(file)
   -- Write all values in the object as a table.
   local object = {}
   local tensors = self.tensors
   self.tensors = nil
   for k, v in pairs(self) do
      object[k] = v
   end

   file:writeObject(object)
   file:writeObject(#tensors)

   for i,tensor in ipairs(tensors) do
      file:writeObject(tensor:getDevice())
      file:writeObject(tensor)
   end

   self.tensors = tensors
end

function MCT:read(file)
   local object = file:readObject()
   for k, v in pairs(object) do
      self[k] = v
   end

   self.tensors = {}

   local N = file:readObject()

   for i=1,N do
      local device = file:readObject()
      self.tensors[i] = cutorch.withDevice(device, function() return file:readObject() end)
   end
end

function MCT:clone()
   local f = torch.MemoryFile("rw"):binary()
   f:writeObject(self)
   f:seek(1)
   local clone = f:readObject()
   f:close()
   return clone
end
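
-- Illustrative round trip (not part of the original; the filename and `weight`
-- are hypothetical). torch.save/torch.load go through the custom MCT:write and
-- MCT:read above, so each constituent tensor is restored on the device it was
-- saved from:
--
--   torch.save('weight.t7', weight)
--   local weight2 = torch.load('weight.t7')
--   local weight3 = weight:clone() -- same mechanism via an in-memory file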
3 changes: 2 additions & 1 deletion init.lua
@@ -3,7 +3,7 @@ require 'paths'
require 'sys'
ffi = require 'ffi'

torchx = {Tensor={}}
torchx = {Tensor={}, version=1}


torch.include('torchx', 'extend.lua')
@@ -18,6 +18,7 @@ torch.include('torchx', 'dkjson.lua')
torch.include('torchx', 'recursivetensor.lua')
torch.include('torchx', 'Queue.lua')
torch.include('torchx', 'AliasMultinomial.lua')
torch.include('torchx', 'MultiCudaTensor.lua')

torch.include('torchx', 'test.lua')

