From 7eeb6ae5b15cd6f852d5ba6a590dfe1205a11f1a Mon Sep 17 00:00:00 2001
From: nicholas-leonard
Date: Mon, 16 May 2016 15:43:14 -0400
Subject: [PATCH] AliasMultinomial

---
 AliasMultinomial.lua | 115 +++++++++++++++++++++++++++++++++++++++++++
 init.lua             |   1 +
 test/test.lua        |  27 ++++++++++
 3 files changed, 143 insertions(+)
 create mode 100644 AliasMultinomial.lua

diff --git a/AliasMultinomial.lua b/AliasMultinomial.lua
new file mode 100644
index 0000000..2595518
--- /dev/null
+++ b/AliasMultinomial.lua
@@ -0,0 +1,115 @@
+-- ref.: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
+-- Alias method: O(K) setup, then O(1) sampling from a K-outcome discrete distribution.
+local AM = torch.class("torch.AliasMultinomial")
+
+-- probs : 1D tensor of outcome probabilities (should sum to 1)
+function AM:__init(probs)
+   self.J, self.q = self:setup(probs)
+end
+
+-- Build the alias table J (LongTensor) and acceptance-probability table q.
+function AM:setup(probs)
+   assert(probs:dim() == 1)
+   local K = probs:nElement()
+   local q = probs.new(K):zero()
+   local J = torch.LongTensor(K):zero()
+
+   -- Sort the data into the outcomes with probabilities
+   -- that are larger and smaller than 1/K.
+   local smaller, larger = {}, {}
+   for kk = 1,K do
+      local prob = probs[kk]
+      q[kk] = K*prob
+      if q[kk] < 1 then
+         table.insert(smaller, kk)
+      else
+         table.insert(larger, kk)
+      end
+   end
+
+   -- Loop through and create little binary mixtures that
+   -- appropriately allocate the larger outcomes over the
+   -- overall uniform mixture.
+   while #smaller > 0 and #larger > 0 do
+      local small = table.remove(smaller)
+      local large = table.remove(larger)
+
+      J[small] = large
+      q[large] = q[large] - (1.0 - q[small])
+
+      if q[large] < 1.0 then
+         table.insert(smaller,large)
+      else
+         table.insert(larger,large)
+      end
+   end
+   assert(q:min() >= 0)
+   if q:max() > 1 then
+      q:div(q:max())
+   end
+   assert(q:max() <= 1)
+   if J:min() <= 0 then
+      -- sometimes a large index isn't added to J.
+      -- fix it by making the probability 1 so that J isn't indexed.
+      local i = 0
+      J:apply(function(x)
+         i = i + 1
+         if x <= 0 then
+            q[i] = 1
+         end
+      end)
+   end
+   return J, q
+end
+
+-- Draw a single sample (integer in [1,K]) in O(1).
+function AM:draw()
+   local J = self.J
+   local q = self.q
+   local K = J:nElement()
+
+   -- Draw from the overall uniform mixture.
+   local kk = math.random(1,K)
+
+   -- Draw from the binary mixture, either keeping the
+   -- small one, or choosing the associated larger one.
+   if math.random() < q[kk] then
+      return kk
+   else
+      return J[kk]
+   end
+end
+
+-- Fill a LongTensor with i.i.d. samples; internal buffers are reused across calls.
+function AM:batchdraw(output)
+   assert(torch.type(output) == 'torch.LongTensor')
+   assert(output:nElement() > 0)
+   local J = self.J
+   local K = J:nElement()
+
+   self._kk = self._kk or output.new()
+   self._kk:resizeAs(output):random(1,K)
+
+   self._q = self._q or self.q.new()
+   self._q:index(self.q, 1, self._kk:view(-1))
+
+   self._mask = self._mask or torch.LongTensor()
+   self._mask:resize(self._q:size()):bernoulli(self._q)
+
+   self.__kk = self.__kk or output.new()
+   self.__kk:resize(self._kk:size()):copy(self._kk)
+   self.__kk:cmul(self._mask)
+
+   -- if mask == 0 then output[i] = J[kk[i]] else output[i] = 0
+
+   self._mask:add(-1):mul(-1) -- (1,0) -> (0,1)
+   output:view(-1):index(J, 1, self._kk:view(-1))
+   output:cmul(self._mask)
+
+   -- elseif mask == 1 then output[i] = kk[i]
+
+   output:add(self.__kk)
+
+   return output
+end
+
diff --git a/init.lua b/init.lua
index 39e9044..3c95a54 100644
--- a/init.lua
+++ b/init.lua
@@ -17,6 +17,7 @@ torch.include('torchx', 'indexdir.lua')
 torch.include('torchx', 'dkjson.lua')
 torch.include('torchx', 'recursivetensor.lua')
 torch.include('torchx', 'Queue.lua')
+torch.include('torchx', 'AliasMultinomial.lua')
 
 torch.include('torchx', 'test.lua')
 
diff --git a/test/test.lua b/test/test.lua
index c3c6b87..77cf54e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -104,6 +104,33 @@ function torchxtest.concat()
    mytester:assertTensorEq(res,res2,0.00001)
 end
 
+function torchxtest.AliasMultinomial()
+   local probs = torch.Tensor(10):uniform(0,1)
+   probs:div(probs:sum())
+
+   local a = torch.Timer()
+   local am = torch.AliasMultinomial(probs)
+   print("setup in "..a:time().real.." seconds")
+
+   a:reset()
+   am:draw()
+   print("draw in "..a:time().real.." seconds")
+
+   local output = torch.LongTensor(1000, 1000)
+   a:reset()
+   am:batchdraw(output)
+   print("batchdraw in "..a:time().real.." seconds")
+
+   local counts = torch.Tensor(10):zero()
+   output:apply(function(x)
+      counts[x] = counts[x] + 1
+   end)
+
+   counts:div(counts:sum())
+
+   mytester:assertTensorEq(probs, counts, 0.001)
+end
+
 function torchx.test(tests)
    local oldtype = torch.getdefaulttensortype()
    torch.setdefaulttensortype('torch.FloatTensor')