diff --git a/README.md b/README.md
index bc28634..daee5b5 100644
--- a/README.md
+++ b/README.md
@@ -29,15 +29,31 @@ Example:
```
-### [res] torch.find(tensor, val) ###
+### [res] torch.find(tensor, val, [dim]) ###
Finds all indices of a given value `val` in Tensor `tensor`.
-Returns a table of these indices.
+Returns a table of these indices by traversing the tensor one row
+at a time. When `dim=2`, the only valid value for dim other than `nil` (the default),
+the function expects a matrix and returns the row-wise indices of each found
+value `val` in the row.
-Example:
+1D example:
```lua
> res = torch.find(torch.Tensor{1,2,3,1,1,2}, 1)
-> print(unpack(res))
- 1, 4, 5
+> unpack(res)
+1 4 5
+```
+
+2D example:
+```
+> tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}}
+> unpack(torch.find(tensor, 2))
+2 10
+> unpack(torch.find(tensor:t(), 2))
+3 10
+> unpack(torch.find(tensor, 2, 2))
+{2} {5}
+> unpack(torch.find(tensor:t(), 2, 2))
+{ } {1} { } { } {2}
```
diff --git a/find.lua b/find.lua
index 022181a..1b23469 100644
--- a/find.lua
+++ b/find.lua
@@ -1,22 +1,34 @@
-function torch.find(tensor, val)
- if not torch.isTensor(tensor) then
- asTable = val
- val = tensor
- tensor = res
- if not asTable then
- res = torch.LongTensor()
- end
- end
+function torch.find(tensor, val, dim)
local i = 1
local indice = {}
- tensor:apply(function(x)
+ if dim then
+ assert(tensor:dim() == 2, "torch.find dim arg only supports matrices for now")
+ assert(dim == 2, "torch.find only supports dim=2 for now")
+
+ local colSize, rowSize = tensor:size(1), tensor:size(2)
+ local rowIndice = {}
+ tensor:apply(function(x)
+ if x == val then
+ table.insert(rowIndice, i)
+ end
+ if i == rowSize then
+ i = 1
+ table.insert(indice, rowIndice)
+ rowIndice = {}
+ else
+ i = i + 1
+ end
+ end)
+ else
+ tensor:apply(function(x)
if x == val then
table.insert(indice, i)
end
i = i + 1
end)
+ end
return indice
end
diff --git a/test/test.lua b/test/test.lua
index daebd9f..af005d7 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -24,16 +24,18 @@ end
function torchxtest.find()
local tensor = torch.Tensor{1,2,3,4,5,6,0.6,0,2}
- local indice = torch.find(tensor, 2)
- mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")
-
- indice = indice.new()
- indice:find(tensor, 2)
- mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err")
+ local indice = torch.LongTensor(torch.find(tensor, 2))
+ mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find (1D) err")
- local indiceTbl = torch.find(tensor, 2, true)
- mytester:assert(torch.type(indiceTbl) == 'table', "find asTable type error")
- mytester:assertTensorEq(indice, torch.LongTensor(indiceTbl), 0.00001, "find asTable value err")
+ local tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}}
+ local indice = torch.find(tensor, 2)
+ mytester:assertTableEq(indice, {2,10}, 0.00001, "find (2D) err")
+ local indice = torch.find(tensor:t(), 2)
+ mytester:assertTableEq(indice, {3,10}, 0.00001, "find (2D transpose) err A")
+ local indice = torch.find(tensor:t(), 5)
+ mytester:assertTableEq(indice, {2,9}, 0.00001, "find (2D transpose) err B")
+ local indice = torch.find(tensor, 2, 2)
+ mytester:assertTableEq(indice, {{2},{5}}, 0.00001, "find (2D row-wise) err")
end
function torchxtest.remap()