Skip to content

Commit

Permalink
allow user to specify a clone function when using sharedClone
Browse files Browse the repository at this point in the history
  • Loading branch information
albanD committed Sep 9, 2016
1 parent ead91a2 commit 5b2eb7f
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions Module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
moduleTree = obj
obj = nil
isTable = false
elseif obj.dpnn_sharedClone then
-- allow to use a custom sharedClone method on one module
moduleTree = obj
obj = nil
isTable = false
elseif scdone[torch.pointer(obj)] then
moduleTree = scdone[torch.pointer(obj)]
else
Expand Down Expand Up @@ -142,8 +147,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
if scdone[torch.pointer(original)] then
for k,param in pairs(moduleTree) do
if torch.isTypeOf(param,'nn.Module') then
-- AbstractRecurrent instances branch here with stepClone = true
clone[k] = param
if param.dpnn_sharedClone then
-- Call the custom sharedClone
clone[k] = param:dpnn_sharedClone()
else
-- AbstractRecurrent instances branch here with stepClone = true
clone[k] = param
end
original[k] = param
elseif torch.isTensor(param) then
if param.storage then
Expand Down

0 comments on commit 5b2eb7f

Please sign in to comment.