From a01bd5c8b4050280c5cddd70a3dab1863de7dd54 Mon Sep 17 00:00:00 2001 From: nicholas-leonard Date: Fri, 11 Mar 2016 17:47:59 -0500 Subject: [PATCH] Queue + recursive functions --- Queue.lua | 28 ++++++++++++++++++++++++++++ init.lua | 1 + recursivetensor.lua | 38 +++++++++++++++++++++++++++++++++++--- 3 files changed, 64 insertions(+), 3 deletions(-) create mode 100644 Queue.lua diff --git a/Queue.lua b/Queue.lua new file mode 100644 index 0000000..856acfd --- /dev/null +++ b/Queue.lua @@ -0,0 +1,28 @@ +local Queue = torch.class("torchx.Queue") + +function Queue:__init() + self.first = 0 + self.last = -1 + self.list = {} +end + +function Queue:put(value) + local first = self.first - 1 + self.first = first + self.list[first] = value +end + +function Queue:empty() + return self.first > self.last +end + +function Queue:get() + local last = self.last + if self:empty() then + error("Queue is empty") + end + local value = self.list[last] + self.list[last] = nil + self.last = last - 1 + return value +end diff --git a/init.lua b/init.lua index 481a84e..39e9044 100644 --- a/init.lua +++ b/init.lua @@ -16,6 +16,7 @@ torch.include('torchx', 'concat.lua') torch.include('torchx', 'indexdir.lua') torch.include('torchx', 'dkjson.lua') torch.include('torchx', 'recursivetensor.lua') +torch.include('torchx', 'Queue.lua') torch.include('torchx', 'test.lua') diff --git a/recursivetensor.lua b/recursivetensor.lua index f035bfb..c3985d2 100644 --- a/recursivetensor.lua +++ b/recursivetensor.lua @@ -158,16 +158,16 @@ function torchx.recursiveIndex(res, src, dim, indices) if torch.type(src) == 'table' then res = (torch.type(res) == 'table') and res or {res} for key,_ in pairs(src) do - res[key], res[key] = torchx.recursiveIndex(res[key], src[key], dim, indices) + res[key] = torchx.recursiveIndex(res[key], src[key], dim, indices) end elseif torch.isTensor(src) then - res = torch.isTensor(res) or src.new() + res = torch.isTensor(res) and res or src.new() res:index(src, dim, indices) else error("expecting nested tensors or tables. Got ".. torch.type(res).." and "..torch.type(src).." instead") end - return res, src + return res end -- get the batch size (i.e. size of first dim for a nested tensor) @@ -179,3 +179,35 @@ function torchx.recursiveBatchSize(input) return input:size(1) end end + +function torchx.recursiveSize(input, excludedim) + local res + if torch.type(input) == 'table' then + res = {} + for k,v in pairs(input) do + res[k] = torchx.recursiveSize(v, excludedim) + end + else + assert(torch.isTensor(input)) + res = input:size():totable() + if excludedim then + table.remove(res, excludedim) + end + end + return res +end + +function torchx.recursiveSub(src, start, stop) + local res + if torch.type(src) == 'table' then + res = {} + for key,_ in pairs(src) do + res[key] = torchx.recursiveSub(src[key], start, stop) + end + elseif torch.isTensor(src) then + res = src:sub(start, stop) + else + error("expecting nested tensors or tables. Got "..torch.type(src).." instead") + end + return res +end