@@ -69,7 +69,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
6969 if param then
7070 params [paramName ] = param
7171 obj [paramName ] = nil
72- if param :storage () then
72+ if torch . isTensor ( param ) and param . storage and param :storage () then
7373 pointers [torch .pointer (param :storage ():data ())] = true
7474 end
7575 end
@@ -82,7 +82,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
8282 if gradParam then
8383 params [paramName ] = gradParam
8484 obj [paramName ] = nil
85- if gradParam :storage () then
85+ if torch . isTensor ( gradParam ) and gradParam . storage and gradParam :storage () then
8686 pointers [torch .pointer (gradParam :storage ():data ())] = true
8787 end
8888 end
@@ -144,8 +144,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
144144 clone [k ] = param
145145 original [k ] = param
146146 elseif torch .isTensor (param ) then
147- clone [k ] = param .new ():set (param )
148- original [k ] = param
147+ if param .storage then
148+ clone [k ] = param .new ():set (param )
149+ original [k ] = param
150+ else -- for torch.MultiCudaTensor
151+ clone [k ] = param
152+ original [k ] = param
153+ end
149154 elseif type (param ) == ' table' then
150155 recursiveSet (clone [k ], original [k ], param )
151156 end
@@ -397,7 +402,7 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
397402 local norm = 0
398403 if moduleLocal and self .modules then
399404 for i ,module in ipairs (self .modules ) do
400- norm = norm + math.pow (module :gradParamClip (maxOutNorm , maxInNorm ), 2 )
405+ norm = norm + math.pow (module :gradParamClip (cutoffNorm , moduleLocal ), 2 )
401406 end
402407 norm = math.sqrt (norm )
403408 else
@@ -406,13 +411,25 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
406411 return norm
407412 end
408413 for k ,gradParam in pairs (gradParams ) do -- pairs for sparse params
409- norm = norm + math.pow (gradParam :norm (),2 )
414+ if torch .type (gradParam ) == ' torch.CudaTensor' then
415+ cutorch .withDevice (gradParam :getDevice (), function () -- support multi-device models
416+ norm = norm + math.pow (gradParam :norm (),2 )
417+ end )
418+ else
419+ norm = norm + math.pow (gradParam :norm (),2 )
420+ end
410421 end
411422 norm = math.sqrt (norm )
412423 if norm > cutoffNorm then
413424 -- rescale gradParams to obtain desired cutoffNorm
414425 for k ,gradParam in pairs (gradParams ) do
415- gradParam :mul (cutoffNorm / norm )
426+ if torch .type (gradParam ) == ' torch.CudaTensor' then
427+ cutorch .withDevice (gradParam :getDevice (), function () -- support multi-device models
428+ gradParam :mul (cutoffNorm / norm )
429+ end )
430+ else
431+ gradParam :mul (cutoffNorm / norm )
432+ end
416433 end
417434 end
418435 end
@@ -455,7 +472,13 @@ function Module:momentumGradParameters()
455472 end
456473 self .momGradParams = {}
457474 for i ,gradParam in pairs (gradParams ) do
458- self .momGradParams [i ] = gradParam .new ():resizeAs (gradParam ):copy (gradParam )
475+ if torch .type (gradParam ) == ' torch.CudaTensor' then
476+ cutorch .withDevice (gradParam :getDevice (), function () -- support multi-device models
477+ self .momGradParams [i ] = gradParam .new ():resizeAs (gradParam ):copy (gradParam )
478+ end )
479+ else
480+ self .momGradParams [i ] = gradParam .new ():resizeAs (gradParam ):copy (gradParam )
481+ end
459482 end
460483 end
461484 return self .momGradParams
0 commit comments