Skip to content

Commit

Permalink
Merge pull request #10 from nicholas-leonard/treemax
Browse files Browse the repository at this point in the history
Treemax
  • Loading branch information
nicholas-leonard committed Oct 17, 2014
2 parents 19e1f48 + 42bae26 commit 23e645d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
6 changes: 6 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
require 'torch'

torchx = {}

torch.include('torchx', 'treemax.lua')
torch.include('torchx', 'test.lua')
18 changes: 17 additions & 1 deletion test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@ local torchxtest = {}
local precision_forward = 1e-6
local precision_backward = 1e-6
local nloop = 50
local mytester

--e.g. usage: th -ltorchx -e "torchx.test{'SoftMaxTree','BlockSparse'}"

function torchxtest.treemax()
local treeSize = {3,3,2}
-- 13,14,25 18 = 3x3x2
-- 6,7,12 6 = 3x2
-- 7,5 2 = 2
local tensor = torch.Tensor{0,0,0,0,0,13, 1,1,1,1,1,10, 0,6, 2,5, 7,5}
local maxVal, maxIdx = torch.treemax(tensor, treeSize)
mytester:assert(maxVal == 7, "treemax maxVal 1")
mytester:assert(maxIdx == 17, "treemax maxIdx 1")
-- 27,14,25
local tensor = torch.Tensor{0,0,0,0,0,27, 1,1,1,1,1,10, 0,6, 2,5, 7,5}
local maxVal, maxIdx = torch.treemax(tensor, treeSize)
mytester:assert(maxVal == 27, "treemax maxVal 2")
mytester:assert(maxIdx == 6, "treemax maxIdx 2")
end

function torchx.test(tests)
local oldtype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())
jac = nn.Jacobian
mytester = torch.Tester()
mytester:add(torchxtest)
mytester:run(tests)
Expand Down
33 changes: 33 additions & 0 deletions treemax.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
local treeMaxBuffer = {}
function torch.treemax(tensor, treeSize)
assert(torch.type(treeSize) == 'table')
assert(tensor:dim() == 1)
local tmb = treeMaxBuffer[torch.type(tensor)] -- upvalue
if not tmb then
tmb = {
mean = tensor.new(),
max = tensor.new(),
idx = torch.LongTensor()
}
treeMaxBuffer[torch.type(tensor)] = tmb
end

local lvl = tensor
local maxIdx, maxVal = 1, 0
for i=1,#treeSize do
lvl = lvl:view(treeSize[i], -1)
local lvlStride = lvl:size(2)
if i < #treeSize then
tmb.mean:mean(lvl, 2)
tmb.max:max(tmb.idx, tmb.mean:select(2,1), 1)
else
tmb.max:max(tmb.idx, lvl:select(2,1), 1)
end

local lvlMax, lvlIdx = tmb.max[1], tmb.idx[1]
lvl = lvl[lvlIdx]
maxIdx = maxIdx + (lvlIdx-1)*lvlStride
maxVal = lvlMax
end
return maxVal, maxIdx
end

0 comments on commit 23e645d

Please sign in to comment.