Skip to content

Commit

Permalink
torch.find works with 2D tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 14, 2015
1 parent 03b498d commit 3c4969a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 24 deletions.
26 changes: 21 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,31 @@ Example:
```

<a name='torch.find'/>
### [res] torch.find(tensor, val) ###
### [res] torch.find(tensor, val, [dim]) ###
Finds all indices of a given value `val` in Tensor `tensor`.
Returns a table of these indices.
Returns a table of these indices by traversing the tensor one row
at a time. When `dim=2`, the only valid value for dim other than `nil` (the default),
the function expects a matrix and returns the row-wise indices of each found
value `val` in the row.

Example:
1D example:
```lua
> res = torch.find(torch.Tensor{1,2,3,1,1,2}, 1)
> print(unpack(res))
1, 4, 5
> unpack(res)
1 4 5
```

2D example:
```
> tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}}
> unpack(torch.find(tensor, 2))
2 10
> unpack(torch.find(tensor:t(), 2))
3 10
> unpack(torch.find(tensor, 2, 2))
{2} {5}
> unpack(torch.find(tensor:t(), 2, 2))
{ } {1} { } { } {2}
```

<a name='torch.group'/>
Expand Down
32 changes: 22 additions & 10 deletions find.lua
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@


function torch.find(tensor, val)
if not torch.isTensor(tensor) then
asTable = val
val = tensor
tensor = res
if not asTable then
res = torch.LongTensor()
end
end
function torch.find(tensor, val, dim)
local i = 1
local indice = {}
tensor:apply(function(x)
if dim then
assert(tensor:dim() == 2, "torch.find dim arg only supports matrices for now")
assert(dim == 2, "torch.find only supports dim=2 for now")

local colSize, rowSize = tensor:size(1), tensor:size(2)
local rowIndice = {}
tensor:apply(function(x)
if x == val then
table.insert(rowIndice, i)
end
if i == rowSize then
i = 1
table.insert(indice, rowIndice)
rowIndice = {}
else
i = i + 1
end
end)
else
tensor:apply(function(x)
if x == val then
table.insert(indice, i)
end
i = i + 1
end)
end
return indice
end

20 changes: 11 additions & 9 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ end

function torchxtest.find()
local tensor = torch.Tensor{1,2,3,4,5,6,0.6,0,2}
local indice = torch.find(tensor, 2)
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")

indice = indice.new()
indice:find(tensor, 2)
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")
local indice = torch.LongTensor(torch.find(tensor, 2))
mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find (1D) err")

local indiceTbl = torch.find(tensor, 2, true)
mytester:assert(torch.type(indiceTbl) == 'table', "find asTable type error")
mytester:assertTensorEq(indice, torch.LongTensor(indiceTbl), 0.00001, "find asTable value err")
local tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}}
local indice = torch.find(tensor, 2)
mytester:assertTableEq(indice, {2,10}, 0.00001, "find (2D) err")
local indice = torch.find(tensor:t(), 2)
mytester:assertTableEq(indice, {3,10}, 0.00001, "find (2D transpose) err A")
local indice = torch.find(tensor:t(), 5)
mytester:assertTableEq(indice, {2,9}, 0.00001, "find (2D transpose) err B")
local indice = torch.find(tensor, 2, 2)
mytester:assertTableEq(indice, {{2},{5}}, 0.00001, "find (2D row-wise) err")
end

function torchxtest.remap()
Expand Down

0 comments on commit 3c4969a

Please sign in to comment.