From fc5c20c7c1b94dc41fbb2c26a7f7fe888c2312f7 Mon Sep 17 00:00:00 2001
From: nicholas-leonard
Date: Fri, 3 Jun 2016 17:03:56 -0400
Subject: [PATCH] multigpu runs

---
 Container.lua | 22 ++++++++++++++++++++++
 GPU.lua       | 26 ++++++++++++++++++++++++++
 Module.lua    | 39 +++++++++++++++++++++++++++++++--------
 test/test.lua |  5 +++--
 4 files changed, 82 insertions(+), 10 deletions(-)

diff --git a/Container.lua b/Container.lua
index 2c6d2b7..bbf9af0 100644
--- a/Container.lua
+++ b/Container.lua
@@ -27,3 +27,25 @@ function Container:sparseParameters()
    end
    return params, gradParams, scales, size
 end
+
+function Container:parameters()
+   local function tinsert(to, from)
+      if torch.type(from) == 'table' then -- we change this line so that it works with torch.MultiCudaTensor
+         for i=1,#from do
+            tinsert(to,from[i])
+         end
+      else
+         table.insert(to,from)
+      end
+   end
+   local w = {}
+   local gw = {}
+   for i=1,#self.modules do
+      local mw,mgw = self.modules[i]:parameters()
+      if mw then
+         tinsert(w,mw)
+         tinsert(gw,mgw)
+      end
+   end
+   return w,gw
+end
diff --git a/GPU.lua b/GPU.lua
index 6b1234c..560227f 100644
--- a/GPU.lua
+++ b/GPU.lua
@@ -80,3 +80,29 @@ function GPU:fromBatch(...)
       return parent.fromBatch(self, unpack(args))
    end
 end
+
+-- set the device of the decorated module
+function GPU:setDevice(device)
+   self.device = device or self.device
+
+   local function recursiveModuleDevice(obj)
+      if type(obj) == 'table' and not (torch.isTypeOf(obj, 'nn.GPU') or torch.type(obj) == 'torch.MultiCudaTensor') then
+         for k,v in pairs(obj) do
+            obj[k] = recursiveModuleDevice(v)
+         end
+      elseif torch.type(obj):match('torch.Cuda.*Tensor') then
+         if obj:getDevice() ~= self.device then
+            obj = obj:clone() -- this will reallocate it to self.device
+            local newdevice = obj:getDevice()
+            -- when nElement() == 0 newdevice is 0
+            assert(newdevice == self.device or newdevice == 0)
+         end
+      end
+      assert(obj ~= nil)
+      return obj
+   end
+
+   assert(self.modules[1])
+   self.modules[1] = cutorch.withDevice(self.device, function() return recursiveModuleDevice(self.modules[1]) end)
+   return self
+end
diff --git a/Module.lua b/Module.lua
index 6e95d12..7fc5173 100644
--- a/Module.lua
+++ b/Module.lua
@@ -69,7 +69,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             if param then
                params[paramName] = param
                obj[paramName] = nil
-               if param:storage() then
+               if torch.isTensor(param) and param.storage and param:storage() then
                   pointers[torch.pointer(param:storage():data())] = true
                end
             end
@@ -82,7 +82,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             if gradParam then
                params[paramName] = gradParam
                obj[paramName] = nil
-               if gradParam:storage() then
+               if torch.isTensor(gradParam) and gradParam.storage and gradParam:storage() then
                   pointers[torch.pointer(gradParam:storage():data())] = true
                end
             end
@@ -144,8 +144,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             clone[k] = param
             original[k] = param
          elseif torch.isTensor(param) then
-            clone[k] = param.new():set(param)
-            original[k] = param
+            if param.storage then
+               clone[k] = param.new():set(param)
+               original[k] = param
+            else -- for torch.MultiCudaTensor
+               clone[k] = param
+               original[k] = param
+            end
          elseif type(param) == 'table' then
             recursiveSet(clone[k], original[k], param)
          end
@@ -397,7 +402,7 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
    local norm = 0
    if moduleLocal and self.modules then
       for i,module in ipairs(self.modules) do
-         norm = norm + math.pow(module:gradParamClip(maxOutNorm, maxInNorm), 2)
+         norm = norm + math.pow(module:gradParamClip(cutoffNorm, moduleLocal), 2)
       end
       norm = math.sqrt(norm)
    else
@@ -406,13 +411,25 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
         return norm
      end
      for k,gradParam in pairs(gradParams) do -- pairs for sparse params
-        norm = norm + math.pow(gradParam:norm(),2)
+        if torch.type(gradParam) == 'torch.CudaTensor' then
+           cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+              norm = norm + math.pow(gradParam:norm(),2)
+           end)
+        else
+           norm = norm + math.pow(gradParam:norm(),2)
+        end
      end
      norm = math.sqrt(norm)
      if norm > cutoffNorm then
         -- rescale gradParams to obtain desired cutoffNorm
         for k,gradParam in pairs(gradParams) do
-           gradParam:mul(cutoffNorm/norm)
+           if torch.type(gradParam) == 'torch.CudaTensor' then
+              cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+                 gradParam:mul(cutoffNorm/norm)
+              end)
+           else
+              gradParam:mul(cutoffNorm/norm)
+           end
         end
      end
   end
@@ -455,7 +472,13 @@ function Module:momentumGradParameters()
       end
       self.momGradParams = {}
       for i,gradParam in pairs(gradParams) do
-         self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+            end)
+         else
+            self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         end
       end
    end
    return self.momGradParams
diff --git a/test/test.lua b/test/test.lua
index ec25887..c29e900 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -773,8 +773,9 @@ function dpnntest.ReinforceNormal()
 end
 
 function dpnntest.ReinforceGamma()
-   require 'randomkit'
-   require 'cephes'
+   if not pcall(function() require 'randomkit'; require 'cephes' end) then
+      return
+   end
    local input = torch.rand(500,1000):fill(250) -- shapes
    local gradOutput = torch.Tensor() -- will be ignored
    local reward = torch.randn(500)