Skip to content

Commit

Permalink
Queue + recursive functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Mar 11, 2016
1 parent 37b5a89 commit a01bd5c
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
28 changes: 28 additions & 0 deletions Queue.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
local Queue = torch.class("torchx.Queue")

function Queue:__init()
self.first = 0
self.last = -1
self.list = {}
end

function Queue:put(value)
local first = self.first - 1
self.first = first
self.list[first] = value
end

function Queue:empty()
return self.first > self.last
end

function Queue:get()
local last = self.last
if self:empty() then
error("Queue is empty")
end
local value = self.list[last]
self.list[last] = nil
self.last = last - 1
return value
end
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ torch.include('torchx', 'concat.lua')
torch.include('torchx', 'indexdir.lua')
torch.include('torchx', 'dkjson.lua')
torch.include('torchx', 'recursivetensor.lua')
torch.include('torchx', 'Queue.lua')

torch.include('torchx', 'test.lua')

Expand Down
38 changes: 35 additions & 3 deletions recursivetensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,16 @@ function torchx.recursiveIndex(res, src, dim, indices)
if torch.type(src) == 'table' then
res = (torch.type(res) == 'table') and res or {res}
for key,_ in pairs(src) do
res[key], res[key] = torchx.recursiveIndex(res[key], src[key], dim, indices)
res[key] = torchx.recursiveIndex(res[key], src[key], dim, indices)
end
elseif torch.isTensor(src) then
res = torch.isTensor(res) or src.new()
res = torch.isTensor(res) and res or src.new()
res:index(src, dim, indices)
else
error("expecting nested tensors or tables. Got "..
torch.type(res).." and "..torch.type(src).." instead")
end
return res, src
return res
end

-- get the batch size (i.e. size of first dim for a nested tensor)
Expand All @@ -179,3 +179,35 @@ function torchx.recursiveBatchSize(input)
return input:size(1)
end
end

function torchx.recursiveSize(input, excludedim)
local res
if torch.type(input) == 'table' then
res = {}
for k,v in pairs(input) do
res[k] = torchx.recursiveSize(v, excludedim)
end
else
assert(torch.isTensor(input))
res = input:size():totable()
if excludedim then
table.remove(res, excludedim)
end
end
return res
end

function torchx.recursiveSub(src, start, stop)
local res
if torch.type(src) == 'table' then
res = {}
for key,_ in pairs(src) do
res[key] = torchx.recursiveSub(src[key], start, stop)
end
elseif torch.isTensor(src) then
res = src:sub(start, stop)
else
error("expecting nested tensors or tables. Got "..torch.type(src).." instead")
end
return res
end

0 comments on commit a01bd5c

Please sign in to comment.