From 3c4969a7c89ea898506701834d3fcabb7529d9b8 Mon Sep 17 00:00:00 2001 From: nicholas-leonard Date: Sat, 14 Feb 2015 16:04:18 -0500 Subject: [PATCH] torch.find works with 2D tensors --- README.md | 26 +++++++++++++++++++++----- find.lua | 32 ++++++++++++++++++++++---------- test/test.lua | 20 +++++++++++--------- 3 files changed, 54 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index bc28634..daee5b5 100644 --- a/README.md +++ b/README.md @@ -29,15 +29,31 @@ Example: ``` -### [res] torch.find(tensor, val) ### +### [res] torch.find(tensor, val, [dim]) ### Finds all indices of a given value `val` in Tensor `tensor`. -Returns a table of these indices. +Returns a table of these indices by traversing the tensor one row +at a time. When `dim=2`, the only valid value for dim other than `nil` (the default), +the function expects a matrix and returns the row-wise indices of each found +value `val` in the row. -Example: +1D example: ```lua > res = torch.find(torch.Tensor{1,2,3,1,1,2}, 1) -> print(unpack(res)) - 1, 4, 5 +> unpack(res) +1 4 5 +``` + +2D example: +``` +> tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}} +> unpack(torch.find(tensor, 2)) +2 10 +> unpack(torch.find(tensor:t(), 2)) +3 10 +> unpack(torch.find(tensor, 2, 2)) +{2} {5} +> unpack(torch.find(tensor:t(), 2, 2)) +{ } {1} { } { } {2} ``` diff --git a/find.lua b/find.lua index 022181a..1b23469 100644 --- a/find.lua +++ b/find.lua @@ -1,22 +1,34 @@ -function torch.find(tensor, val) - if not torch.isTensor(tensor) then - asTable = val - val = tensor - tensor = res - if not asTable then - res = torch.LongTensor() - end - end +function torch.find(tensor, val, dim) local i = 1 local indice = {} - tensor:apply(function(x) + if dim then + assert(tensor:dim() == 2, "torch.find dim arg only supports matrices for now") + assert(dim == 2, "torch.find only supports dim=2 for now") + + local colSize, rowSize = tensor:size(1), tensor:size(2) + local rowIndice = {} + tensor:apply(function(x) + if x == val then + table.insert(rowIndice, i) + end + if i == rowSize then + i = 1 + table.insert(indice, rowIndice) + rowIndice = {} + else + i = i + 1 + end + end) + else + tensor:apply(function(x) if x == val then table.insert(indice, i) end i = i + 1 end) + end return indice end diff --git a/test/test.lua b/test/test.lua index daebd9f..af005d7 100644 --- a/test/test.lua +++ b/test/test.lua @@ -24,16 +24,18 @@ end function torchxtest.find() local tensor = torch.Tensor{1,2,3,4,5,6,0.6,0,2} - local indice = torch.find(tensor, 2) - mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err") - - indice = indice.new() - indice:find(tensor, 2) - mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find err") + local indice = torch.LongTensor(torch.find(tensor, 2)) + mytester:assertTensorEq(indice, torch.LongTensor{2,9}, 0.00001, "find (1D) err") - local indiceTbl = torch.find(tensor, 2, true) - mytester:assert(torch.type(indiceTbl) == 'table', "find asTable type error") - mytester:assertTensorEq(indice, torch.LongTensor(indiceTbl), 0.00001, "find asTable value err") + local tensor = torch.Tensor{{1,2,3,4,5},{5,6,0.6,0,2}} + local indice = torch.find(tensor, 2) + mytester:assertTableEq(indice, {2,10}, 0.00001, "find (2D) err") + local indice = torch.find(tensor:t(), 2) + mytester:assertTableEq(indice, {3,10}, 0.00001, "find (2D transpose) err A") + local indice = torch.find(tensor:t(), 5) + mytester:assertTableEq(indice, {2,9}, 0.00001, "find (2D transpose) err B") + local indice = torch.find(tensor, 2, 2) + mytester:assertTableEq(indice, {{2},{5}}, 0.00001, "find (2D row-wise) err") end function torchxtest.remap()