Skip to content

Commit

Permalink
find returns a table
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 14, 2015
1 parent 2d970a4 commit 03b498d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 19 deletions.
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@ Example:

<a name='torch.find'/>
### [res] torch.find(tensor, val) ###
Finds all indices of a given value `val` in Tensor `tensor`. Returns a `torch.LongTensor` of these indices.
Finds all indices of a given value `val` in Tensor `tensor`.
Returns a table of these indices.

Example:
```lua
> res = torch.find(torch.Tensor{1,2,3,1,1,2}, 1)
> print(res)
1
4
5
[torch.LongTensor of dimension 3]
> print(unpack(res))
1, 4, 5
```

<a name='torch.group'/>
Expand Down
15 changes: 2 additions & 13 deletions find.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


function torch.find(res, tensor, val, asTable)
function torch.find(tensor, val)
if not torch.isTensor(tensor) then
asTable = val
val = tensor
Expand All @@ -9,7 +9,6 @@ function torch.find(res, tensor, val, asTable)
res = torch.LongTensor()
end
end
assert(tensor:dim() == 1, "torch.find only supports 1D tensors (for now)")
local i = 1
local indice = {}
tensor:apply(function(x)
Expand All @@ -18,16 +17,6 @@ function torch.find(res, tensor, val, asTable)
end
i = i + 1
end)
if asTable then
return indice
end
res:resize(#indice)
i = 0
res:apply(function()
i = i + 1
return indice[i]
end)
return res
return indice
end

torchx.Tensor.find = torch.find

0 comments on commit 03b498d

Please sign in to comment.