Skip to content

Commit

Permalink
torch.remap
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Nov 30, 2014
1 parent aee8954 commit ee11de2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ torchx = {}

torch.include('torchx', 'treemax.lua')
torch.include('torchx', 'find.lua')
torch.include('torchx', 'remap.lua')

torch.include('torchx', 'test.lua')
30 changes: 30 additions & 0 deletions remap.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

-- recursive map
function torch.remap(t1, t2, f, p1, p2)
if torch.type(f) ~= 'function' then
error"Expecting function at argument 3"
end
if torch.type(t1) == 'table' then
t2 = t2 or {}
for i=1,#t1 do
t1[i], t2[i] = torch.remap(t1[i], t2[i], f, p1, p2)
end
elseif torch.type(t2) == 'table' then
t1 = t1 or {}
for i=1,#t2 do
t1[i], t2[i] = torch.remap(t1[i], t2[i], f, p1, p2)
end
elseif torch.isTensor(t1) or torch.isTensor(t2) then
if not t1 then
t1 = p1 and p1.new() or t2.new()
elseif not t2 then
t2 = (p2 and p2.new()) or (p1 and p1.new()) or t1.new()
end
f(t1, t2)
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
end
return t1, t2
end

22 changes: 22 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ function torchxtest.find()
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")
end

function torchxtest.remap()
local a, b, c, d = torch.randn(3,4), torch.randn(3,4), torch.randn(2,4), torch.randn(1)
local e, f, g, h = torch.randn(3,4), torch.randn(3,4), torch.randn(2,4), torch.randn(1)
local t1 = {a:clone(), {b:clone(), c:clone(), {d:clone()}}}
local t2 = {e:clone(), {f:clone(), g:clone(), {h:clone()}}}
torch.remap(t1, t2, function(x, y) x:add(y) end)
mytester:assertTensorEq(a:add(e), t1[1], 0.000001, "error remap a add")
mytester:assertTensorEq(b:add(f), t1[2][1], 0.000001, "error remap b add")
mytester:assertTensorEq(c:add(g), t1[2][2], 0.000001, "error remap c add")
mytester:assertTensorEq(d:add(h), t1[2][3][1], 0.000001, "error remap d add")
local __, t3 = torch.remap(t2, nil, function(x, y) y:resizeAs(x):copy(x) end)
mytester:assertTensorEq(e, t3[1], 0.000001, "error remap e copy")
mytester:assertTensorEq(f, t3[2][1], 0.000001, "error remap f copy")
mytester:assertTensorEq(g, t3[2][2], 0.000001, "error remap g copy")
mytester:assertTensorEq(h, t3[2][3][1], 0.000001, "error remap h copy")
local t4, __ = torch.remap(nil, t2, function(x, y) x:resize(y:size()):copy(y) end, torch.LongTensor())
mytester:assert(torch.type(t4[1]) == 'torch.LongTensor', "error remap e copy")
mytester:assert(torch.type(t4[2][1]) == 'torch.LongTensor', "error remap f copy")
mytester:assert(torch.type(t4[2][2]) == 'torch.LongTensor', "error remap g copy")
mytester:assert(torch.type(t4[2][3][1]) == 'torch.LongTensor', "error remap h copy")
end

function torchx.test(tests)
local oldtype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
Expand Down

0 comments on commit ee11de2

Please sign in to comment.