multigpu runs

nicholas-leonard committed Jun 3, 2016
1 parent 62696d9 commit fc5c20c
Showing 4 changed files with 82 additions and 10 deletions.
22 changes: 22 additions & 0 deletions Container.lua
@@ -27,3 +27,25 @@ function Container:sparseParameters()
    end
    return params, gradParams, scales, size
 end
+
+function Container:parameters()
+   local function tinsert(to, from)
+      if torch.type(from) == 'table' then -- we change this line so that it works with torch.MultiCudaTensor
+         for i=1,#from do
+            tinsert(to,from[i])
+         end
+      else
+         table.insert(to,from)
+      end
+   end
+   local w = {}
+   local gw = {}
+   for i=1,#self.modules do
+      local mw,mgw = self.modules[i]:parameters()
+      if mw then
+         tinsert(w,mw)
+         tinsert(gw,mgw)
+      end
+   end
+   return w,gw
+end
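Why tinsert now checks torch.type rather than Lua's type: torch classes, including torch.MultiCudaTensor, are implemented as Lua tables, so type() cannot tell such a parameter object apart from a plain table of parameters, and the stock nn.Container:parameters would flatten it into its fields. A minimal sketch of the distinction, using nn.Linear only as a stand-in for any torch class:

-- type() vs torch.type() for a torch class implemented as a Lua table
require 'nn'

local obj = nn.Linear(2, 2)
print(type(obj))        --> table
print(torch.type(obj))  --> nn.Linear

-- With the old check `type(from) == 'table'`, a torch.MultiCudaTensor would be
-- recursed into; with `torch.type(from) == 'table'` it is inserted into the
-- parameter list as a single entry.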
26 changes: 26 additions & 0 deletions GPU.lua
@@ -80,3 +80,29 @@ function GPU:fromBatch(...)
       return parent.fromBatch(self, unpack(args))
    end
 end
+
+-- set the device of the decorated module
+function GPU:setDevice(device)
+   self.device = device or self.device
+
+   local function recursiveModuleDevice(obj)
+      if type(obj) == 'table' and not (torch.isTypeOf(obj, 'nn.GPU') or torch.type(obj) == 'torch.MultiCudaTensor') then
+         for k,v in pairs(obj) do
+            obj[k] = recursiveModuleDevice(v)
+         end
+      elseif torch.type(obj):match('torch.Cuda.*Tensor') then
+         if obj:getDevice() ~= self.device then
+            obj = obj:clone() -- this will reallocate it to self.device
+            local newdevice = obj:getDevice()
+            -- when nElement() == 0 newdevice is 0
+            assert(newdevice == self.device or newdevice == 0)
+         end
+      end
+      assert(obj ~= nil)
+      return obj
+   end
+
+   assert(self.modules[1])
+   self.modules[1] = cutorch.withDevice(self.device, function() return recursiveModuleDevice(self.modules[1]) end)
+   return self
+end
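GPU:setDevice walks every field of the decorated module and re-clones any CUDA tensor that sits on a different device, while leaving nested nn.GPU wrappers and torch.MultiCudaTensor parameters untouched (cloning the latter would defeat their multi-device layout). A hedged usage sketch, assuming at least two GPUs and an ordinary nn.Linear as the decorated module:

require 'cunn'
require 'dpnn'

-- decorate a layer so its forward/backward run on GPU 2
local layer = nn.GPU(nn.Linear(100, 100), 2):cuda()

-- later, migrate the decorated Linear (weight, bias, buffers) onto GPU 1
layer:setDevice(1)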
39 changes: 31 additions & 8 deletions Module.lua
@@ -69,7 +69,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             if param then
                params[paramName] = param
                obj[paramName] = nil
-               if param:storage() then
+               if torch.isTensor(param) and param.storage and param:storage() then
                   pointers[torch.pointer(param:storage():data())] = true
                end
             end
@@ -82,7 +82,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             if gradParam then
                params[paramName] = gradParam
                obj[paramName] = nil
-               if gradParam:storage() then
+               if torch.isTensor(gradParam) and gradParam.storage and gradParam:storage() then
                   pointers[torch.pointer(gradParam:storage():data())] = true
                end
             end
@@ -144,8 +144,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
            clone[k] = param
            original[k] = param
         elseif torch.isTensor(param) then
-            clone[k] = param.new():set(param)
-            original[k] = param
+            if param.storage then
+               clone[k] = param.new():set(param)
+               original[k] = param
+            else -- for torch.MultiCudaTensor
+               clone[k] = param
+               original[k] = param
+            end
         elseif type(param) == 'table' then
            recursiveSet(clone[k], original[k], param)
         end
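Both sharedClone hunks have the same cause: a torch.MultiCudaTensor spreads its data across devices and (as the param.storage guards assume) does not expose a single backing storage, so it is shared by passing the object itself rather than by tracking storage pointers or building a new():set() view. For ordinary tensors, new():set() is what makes the clone share memory with the original, as this small sketch shows:

-- new():set() creates a fresh tensor header over the same storage
require 'torch'

local t = torch.Tensor(3):fill(1)
local view = t.new():set(t)
view[1] = 7
print(t[1]) --> 7, because both tensors share one storage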
@@ -397,7 +402,7 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
    local norm = 0
    if moduleLocal and self.modules then
       for i,module in ipairs(self.modules) do
-         norm = norm + math.pow(module:gradParamClip(maxOutNorm, maxInNorm), 2)
+         norm = norm + math.pow(module:gradParamClip(cutoffNorm, moduleLocal), 2)
       end
       norm = math.sqrt(norm)
    else
@@ -406,13 +411,25 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
          return norm
       end
       for k,gradParam in pairs(gradParams) do -- pairs for sparse params
-         norm = norm + math.pow(gradParam:norm(),2)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               norm = norm + math.pow(gradParam:norm(),2)
+            end)
+         else
+            norm = norm + math.pow(gradParam:norm(),2)
+         end
       end
       norm = math.sqrt(norm)
       if norm > cutoffNorm then
          -- rescale gradParams to obtain desired cutoffNorm
          for k,gradParam in pairs(gradParams) do
-            gradParam:mul(cutoffNorm/norm)
+            if torch.type(gradParam) == 'torch.CudaTensor' then
+               cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+                  gradParam:mul(cutoffNorm/norm)
+               end)
+            else
+               gradParam:mul(cutoffNorm/norm)
+            end
          end
       end
    end
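Two things change in gradParamClip. The recursive call now forwards cutoffNorm and moduleLocal (it previously referenced maxOutNorm and maxInNorm, which are not defined in this function). And for multi-device models, :norm() and :mul() on a CudaTensor are wrapped in cutorch.withDevice so they execute with that tensor's own device made current; the same guard is applied to momentumGradParameters in the next hunk. A small sketch of the pattern, assuming cutorch and two visible GPUs:

require 'cutorch'

-- compute a tensor's norm with its own device made current, then restore
local function normOnOwnDevice(t)
   if torch.type(t) == 'torch.CudaTensor' then
      return cutorch.withDevice(t:getDevice(), function() return t:norm() end)
   end
   return t:norm()
end

cutorch.setDevice(2)
local t = torch.CudaTensor(10):fill(2)  -- lives on device 2
cutorch.setDevice(1)
print(normOnOwnDevice(t))               -- norm runs with device 2 current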
@@ -455,7 +472,13 @@ function Module:momentumGradParameters()
       end
       self.momGradParams = {}
       for i,gradParam in pairs(gradParams) do
-         self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+            end)
+         else
+            self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         end
       end
    end
    return self.momGradParams
5 changes: 3 additions & 2 deletions test/test.lua
@@ -773,8 +773,9 @@ function dpnntest.ReinforceNormal()
 end

 function dpnntest.ReinforceGamma()
-   require 'randomkit'
-   require 'cephes'
+   if not pcall(function() require 'randomkit'; require 'cephes' end) then
+      return
+   end
    local input = torch.rand(500,1000):fill(250) -- shapes
    local gradOutput = torch.Tensor() -- will be ignored
    local reward = torch.randn(500)
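The ReinforceGamma test now skips itself when the optional randomkit and cephes packages are not installed, instead of aborting the whole suite at require time. The guard on its own, as a sketch:

-- pcall turns a missing package into a boolean rather than an error
local hasDeps = pcall(function() require 'randomkit'; require 'cephes' end)
if not hasDeps then
   print('skipping ReinforceGamma: randomkit/cephes not installed')
   return
end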
