Skip to content

Commit

Permalink
torchx extends torch.Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 11, 2015
1 parent fd0b644 commit 75c4cb1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
3 changes: 2 additions & 1 deletion concat.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ function torch.concat(result, tensors, dim, index)
tensors = result
result = tensors[index].new()
end
result = tensors[index].new()

assert(type(tensors) == 'table', "expecting table at arg 2")

Expand Down Expand Up @@ -44,3 +43,5 @@ function torch.concat(result, tensors, dim, index)
end
return result
end

torchx.Tensor.concat = torch.concat
26 changes: 23 additions & 3 deletions find.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@

function torch.find(tensor, val)
assert(tensor:dim() == 1)

function torch.find(res, tensor, val, asTable)
if not torch.isTensor(tensor) then
asTable = val
val = tensor
tensor = res
if not asTable then
res = torch.LongTensor()
end
end
assert(tensor:dim() == 1, "torch.find only supports 1D tensors (for now)")
local i = 1
local indice = {}
tensor:apply(function(x)
Expand All @@ -9,5 +18,16 @@ function torch.find(tensor, val)
end
i = i + 1
end)
return torch.LongTensor(indice)
if asTable then
return indice
end
res:resize(#indice)
i = 0
res:apply(function()
i = i + 1
return indice[i]
end)
return res
end

torchx.Tensor.find = torch.find
14 changes: 13 additions & 1 deletion init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ require 'paths'
require 'sys'
ffi = require 'ffi'

torchx = {}
torchx = {Tensor={}}


torch.include('torchx', 'extend.lua')
torch.include('torchx', 'md5.lua')
torch.include('torchx', 'treemax.lua')
torch.include('torchx', 'find.lua')
Expand All @@ -15,4 +17,14 @@ torch.include('torchx', 'indexdir.lua')

torch.include('torchx', 'test.lua')

local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
local Tensor = torchx.Tensor
torchx.Tensor = nil

torchx.extend(types, Tensor, true)

function torchx:cuda()
torchx:extend({'Cuda'}, Tensor, true)
end

return torchx
11 changes: 11 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ function torchxtest.find()
local tensor = torch.Tensor{1,2,3,4,5,6,0.6,0,2}
local indice = torch.find(tensor, 2)
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")

indice = indice.new()
indice:find(tensor, 2)
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")

local indiceTbl = torch.find(tensor, 2, true)
mytester:assert(torch.type(indiceTbl) == 'table', "find asTable type error")
mytester:assertTensorEq(indice, torch.LongTensor(indiceTbl), 0.00001, "find asTable value err")
end

function torchxtest.remap()
Expand Down Expand Up @@ -86,6 +94,9 @@ function torchxtest.concat()
res2:narrow(3,5,6):copy(tensors[2])
res2:narrow(3,11,8):copy(tensors[3])
mytester:assertTensorEq(res,res2,0.00001)

res:zero():concat(tensors, 3)
mytester:assertTensorEq(res,res2,0.00001)
end

function torchx.test(tests)
Expand Down

0 comments on commit 75c4cb1

Please sign in to comment.