Skip to content

Commit

Permalink
recursive tensor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 2, 2016
1 parent e9eae23 commit 37b5a89
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ torch.include('torchx', 'group.lua')
torch.include('torchx', 'concat.lua')
torch.include('torchx', 'indexdir.lua')
torch.include('torchx', 'dkjson.lua')
torch.include('torchx', 'recursivetensor.lua')

torch.include('torchx', 'test.lua')

Expand Down
181 changes: 181 additions & 0 deletions recursivetensor.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@

function torchx.recursiveResizeAs(t1,t2)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
t1[key], t2[key] = torchx.recursiveResizeAs(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
t1:resizeAs(t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

function torchx.recursiveSet(t1,t2)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
t1[key], t2[key] = torchx.recursiveSet(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
t1:set(t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

function torchx.recursiveCopy(t1,t2)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
t1[key], t2[key] = torchx.recursiveCopy(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
t1:resizeAs(t2):copy(t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

function torchx.recursiveAdd(t1, t2)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
t1[key], t2[key] = torchx.recursiveAdd(t1[key], t2[key])
end
elseif torch.isTensor(t1) and torch.isTensor(t2) then
t1:add(t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

function torchx.recursiveTensorEq(t1, t2)
if torch.type(t2) == 'table' then
local isEqual = true
if torch.type(t1) ~= 'table' then
return false
end
for key,_ in pairs(t2) do
isEqual = isEqual and torchx.recursiveTensorEq(t1[key], t2[key])
end
return isEqual
elseif torch.isTensor(t2) and torch.isTensor(t2) then
local diff = t1-t2
local err = diff:abs():max()
return err < 0.00001
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
end

function torchx.recursiveNormal(t2)
if torch.type(t2) == 'table' then
for key,_ in pairs(t2) do
t2[key] = torchx.recursiveNormal(t2[key])
end
elseif torch.isTensor(t2) then
t2:normal()
else
error("expecting tensor or table thereof. Got "
..torch.type(t2).." instead")
end
return t2
end

function torchx.recursiveFill(t2, val)
if torch.type(t2) == 'table' then
for key,_ in pairs(t2) do
t2[key] = torchx.recursiveFill(t2[key], val)
end
elseif torch.isTensor(t2) then
t2:fill(val)
else
error("expecting tensor or table thereof. Got "
..torch.type(t2).." instead")
end
return t2
end

function torchx.recursiveType(param, type_str)
if torch.type(param) == 'table' then
for i = 1, #param do
param[i] = torchx.recursiveType(param[i], type_str)
end
else
if torch.typename(param) and
torch.typename(param):find('torch%..+Tensor') then
param = param:type(type_str)
end
end
return param
end

function torchx.recursiveSum(t2)
local sum = 0
if torch.type(t2) == 'table' then
for key,_ in pairs(t2) do
sum = sum + torchx.recursiveSum(t2[key], val)
end
elseif torch.isTensor(t2) then
return t2:sum()
else
error("expecting tensor or table thereof. Got "
..torch.type(t2).." instead")
end
return sum
end

function torchx.recursiveNew(t2)
if torch.type(t2) == 'table' then
local t1 = {}
for key,_ in pairs(t2) do
t1[key] = torchx.recursiveNew(t2[key])
end
return t1
elseif torch.isTensor(t2) then
return t2.new()
else
error("expecting tensor or table thereof. Got "
..torch.type(t2).." instead")
end
end

function torchx.recursiveIndex(res, src, dim, indices)
if torch.type(src) == 'table' then
res = (torch.type(res) == 'table') and res or {res}
for key,_ in pairs(src) do
res[key], res[key] = torchx.recursiveIndex(res[key], src[key], dim, indices)
end
elseif torch.isTensor(src) then
res = torch.isTensor(res) or src.new()
res:index(src, dim, indices)
else
error("expecting nested tensors or tables. Got "..
torch.type(res).." and "..torch.type(src).." instead")
end
return res, src
end

-- get the batch size (i.e. size of first dim for a nested tensor)
function torchx.recursiveBatchSize(input)
if torch.type(input) == 'table' then
return torchx.recursiveBatchSize(input[1])
else
assert(torch.isTensor(input))
return input:size(1)
end
end

0 comments on commit 37b5a89

Please sign in to comment.