Skip to content

Commit

Permalink
fix torch.group
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Oct 21, 2015
1 parent 39f2e6f commit e9eae23
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
15 changes: 10 additions & 5 deletions group.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,24 @@ function torch.group(sorted, index, tensor, samegrp, desc)
local idx = 1
local groups = {}
sorted:apply(function(val)
if not samegrp(start_val, val) or idx == sorted:size(1) then
if idx == sorted:size(1) then
idx = idx + 1
end

if not samegrp(start_val, val) then
groups[start_val] = {
idx=index:narrow(1, start_idx, idx-start_idx),
val=sorted:narrow(1, start_idx, idx-start_idx)
}
start_val = val
start_idx = idx
end

idx = idx + 1

if idx-1 == sorted:size(1) then
groups[start_val] = {
idx=index:narrow(1, start_idx, idx-start_idx),
val=sorted:narrow(1, start_idx, idx-start_idx)
}
end

end)

return groups, sorted, index
Expand Down
5 changes: 4 additions & 1 deletion test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ function torchxtest.group()
mytester:assert(groups[5].idx:size(1) == 3)
mytester:assertTensorEq(val, val2, 0.00001)
mytester:assertTensorNe(idx2, idx, 0.00001)
-- this was failing for me
local tensor = torch.Tensor{1,2,0}
local groups, val, idx = torch.group(tensor)
mytester:assert(groups[1] and groups[2] and groups[0])
end

function torchxtest.concat()
Expand All @@ -96,7 +100,6 @@ function torchxtest.concat()
res2:narrow(3,5,6):copy(tensors[2])
res2:narrow(3,11,8):copy(tensors[3])
mytester:assertTensorEq(res,res2,0.00001)

res:zero():concat(tensors, 3)
mytester:assertTensorEq(res,res2,0.00001)
end
Expand Down

0 comments on commit e9eae23

Please sign in to comment.