diff --git a/concat.lua b/concat.lua index 56688d1..5e69ad8 100644 --- a/concat.lua +++ b/concat.lua @@ -8,7 +8,6 @@ function torch.concat(result, tensors, dim, index) tensors = result result = tensors[index].new() end - result = tensors[index].new() assert(type(tensors) == 'table', "expecting table at arg 2") @@ -44,3 +43,5 @@ function torch.concat(result, tensors, dim, index) end return result end + +torchx.Tensor.concat = torch.concat diff --git a/find.lua b/find.lua index 30345b7..a8ee55f 100644 --- a/find.lua +++ b/find.lua @@ -1,6 +1,15 @@ -function torch.find(tensor, val) - assert(tensor:dim() == 1) + +function torch.find(res, tensor, val, asTable) + if not torch.isTensor(tensor) then + asTable = val + val = tensor + tensor = res + if not asTable then + res = torch.LongTensor() + end + end + assert(tensor:dim() == 1, "torch.find only supports 1D tensors (for now)") local i = 1 local indice = {} tensor:apply(function(x) @@ -9,5 +18,16 @@ function torch.find(tensor, val) end i = i + 1 end) - return torch.LongTensor(indice) + if asTable then + return indice + end + res:resize(#indice) + i = 0 + res:apply(function() + i = i + 1 + return indice[i] + end) + return res end + +torchx.Tensor.find = torch.find diff --git a/init.lua b/init.lua index 01342b4..daac808 100644 --- a/init.lua +++ b/init.lua @@ -3,8 +3,10 @@ require 'paths' require 'sys' ffi = require 'ffi' -torchx = {} +torchx = {Tensor={}} + +torch.include('torchx', 'extend.lua') torch.include('torchx', 'md5.lua') torch.include('torchx', 'treemax.lua') torch.include('torchx', 'find.lua') @@ -15,4 +17,14 @@ torch.include('torchx', 'indexdir.lua') torch.include('torchx', 'test.lua') +local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} +local Tensor = torchx.Tensor +torchx.Tensor = nil + +torchx.extend(types, Tensor, true) + +function torchx:cuda() + torchx:extend({'Cuda'}, Tensor, true) +end +return torchx diff --git a/test/test.lua b/test/test.lua index 4d396c7..daebd9f 100644 --- a/test/test.lua +++ b/test/test.lua @@ -26,6 +26,14 @@ function torchxtest.find() local tensor = torch.Tensor{1,2,3,4,5,6,0.6,0,2} local indice = torch.find(tensor, 2) mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err") + + indice = indice.new() + indice:find(tensor, 2) + mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err") + + local indiceTbl = torch.find(tensor, 2, true) + mytester:assert(torch.type(indiceTbl) == 'table', "find asTable type error") + mytester:assertTensorEq(indice, torch.LongTensor(indiceTbl), 0.00001, "find asTable value err") end function torchxtest.remap() @@ -86,6 +94,9 @@ function torchxtest.concat() res2:narrow(3,5,6):copy(tensors[2]) res2:narrow(3,11,8):copy(tensors[3]) mytester:assertTensorEq(res,res2,0.00001) + + res:zero():concat(tensors, 3) + mytester:assertTensorEq(res,res2,0.00001) end function torchx.test(tests)