
Commit 489b0f2

CategoricalEntropy unit tested
1 parent 1f7f0d9 commit 489b0f2

3 files changed: +86 -0 lines changed

CategoricalEntropy.lua

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
------------------------------------------------------------------------
--[[ CategoricalEntropy ]]--
-- Maximize the entropy of a categorical distribution (e.g. softmax).
-- H(X) = E[-log(p(X))] = -sum(p(x)log(p(x)))
-- where X = 1,...,N and N is the number of categories.
-- A batch with an entropy below minEntropy will be maximized.
-- d H(X=x)        p(x)
-- -------- = - -------- - log(p(x)) = -1 - log(p(x))
--  d p(x)        p(x)
------------------------------------------------------------------------
local CE, parent = torch.class("nn.CategoricalEntropy", "nn.Module")

function CE:__init(scale, minEntropy)
   parent.__init(self)
   self.scale = scale or 1
   self.minEntropy = minEntropy

   -- estimate p(X) using the batch as a prior
   self.module = nn.Sequential()
   self.module:add(nn.Sum(1)) -- sum categorical probabilities over the batch
   self._mul = nn.MulConstant(1)
   self.module:add(self._mul) -- normalize them to sum to one (i.e. probabilities)

   -- compute entropy H(X)
   local concat = nn.ConcatTable()
   concat:add(nn.Identity()) -- p(X)
   local seq = nn.Sequential()
   seq:add(nn.AddConstant(0.000001)) -- prevent log(0) = nan errors
   seq:add(nn.Log())
   concat:add(seq)
   self.module:add(concat) -- { p(X), log(p(X)) }
   self.module:add(nn.CMulTable()) -- p(x)log(p(x))
   self.module:add(nn.Sum()) -- sum(p(x)log(p(x)))
   self.module:add(nn.MulConstant(-1)) -- H(X)

   self.modules = {self.module}

   self.minusOne = torch.Tensor{-self.scale} -- gradient descent on maximization
   self.sizeAverage = true
end

function CE:updateOutput(input)
   assert(input:dim() == 2, "CategoricalEntropy only works with batches")
   self.output:set(input)
   return self.output
end

function CE:updateGradInput(input, gradOutput, scale)
   assert(input:dim() == 2, "CategoricalEntropy only works with batches")
   self.gradInput:resizeAs(input):copy(gradOutput)

   self._mul.constant_scalar = 1/input:sum() -- normalize to sum to one
   self.entropy = self.module:updateOutput(input)[1]
   if (not self.minEntropy) or (self.entropy < self.minEntropy) then
      local gradEntropy = self.module:updateGradInput(input, self.minusOne, scale)
      if self.sizeAverage then
         gradEntropy:div(input:size(1))
      end
      self.gradInput:add(gradEntropy)
   end

   return self.gradInput
end
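
As a usage sketch (not part of this commit; the mlp, layer sizes, and tensors below are hypothetical), the module is meant to sit inline in a network: forward passes the input through unchanged, while backward adds the scaled, size-averaged entropy-maximizing gradient on top of gradOutput:

require 'dpnn'

-- hypothetical example model; Linear(10,5) and the batch are illustrative
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 5))
mlp:add(nn.SoftMax())
mlp:add(nn.CategoricalEntropy(0.1)) -- identity on forward; entropy gradient on backward

local x = torch.randn(4, 10)  -- batch of 4 samples
local y = mlp:forward(x)      -- y is simply the softmax output
mlp:backward(x, torch.zeros(4, 5)) -- adds 0.1-scaled gradients that increase H(X)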

init.lua

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ torch.include('dpnn', 'Clip.lua')
 torch.include('dpnn', 'SpatialUniformCrop.lua')
 torch.include('dpnn', 'SpatialGlimpse.lua')
 torch.include('dpnn', 'ArgMax.lua')
+torch.include('dpnn', 'CategoricalEntropy.lua')

 -- REINFORCE
 torch.include('dpnn', 'Reinforce.lua')
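
Once this include is registered, requiring the package should expose the new class; a minimal smoke check (assuming a working dpnn install):

require 'dpnn'
assert(nn.CategoricalEntropy, "class should be registered, not nil")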

test/test.lua

Lines changed: 22 additions & 0 deletions
@@ -705,6 +705,28 @@ function dpnntest.ArgMax()
    mytester:assertTensorEq(gradInput, input:clone():zero(), 0.000001, "ArgMax gradInput not asLong err")
 end

+function dpnntest.CategoricalEntropy()
+   local inputSize = 5
+   local batchSize = 10
+   local minEntropy = 12
+   local input_ = torch.randn(batchSize, inputSize)
+   local input = nn.SoftMax():updateOutput(input_)
+   local gradOutput = torch.Tensor(batchSize, inputSize):zero()
+   local ce = nn.CategoricalEntropy()
+   local output = ce:forward(input)
+   mytester:assertTensorEq(input, output, 0.0000001, "CategoricalEntropy output err")
+   local gradInput = ce:backward(input, gradOutput)
+   local output2 = input:sum(1)[1]
+   output2:div(output2:sum())
+   local log2 = torch.log(output2 + 0.000001)
+   local entropy2 = -output2:cmul(log2):sum()
+   mytester:assert(math.abs(ce.entropy - entropy2) < 0.000001, "CategoricalEntropy entropy err")
+   local gradEntropy2 = log2:add(1) -- -1*(-1 - log(p(x))) = 1 + log(p(x))
+   gradEntropy2:div(input:sum())
+   local gradInput2 = gradEntropy2:div(batchSize):view(1,inputSize):expandAs(input)
+   mytester:assertTensorEq(gradInput2, gradInput, 0.000001, "CategoricalEntropy gradInput err")
+end
+
 function dpnnbigtest.Reinforce()
    -- let us try to reinforce an mlp to learn a simple distribution
    local n = 10
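
The unit test compares the module's gradient against the hand-derived formula. A complementary sanity check (a sketch, not in this commit) is a finite-difference probe of H(p) = -sum(p(x)log(p(x))) itself, which should recover dH/dp(x) = -1 - log(p(x)):

-- illustrative finite-difference check, not part of the test suite
local p = torch.Tensor{0.1, 0.2, 0.3, 0.4}
local function H(q) return -torch.cmul(q, torch.log(q)):sum() end
local eps = 1e-6
for i = 1, p:size(1) do
   local q = p:clone(); q[i] = q[i] + eps
   local numeric = (H(q) - H(p)) / eps    -- numerical partial derivative
   local analytic = -1 - math.log(p[i])   -- formula from the module's header comment
   assert(math.abs(numeric - analytic) < 1e-4)
end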
