-- MOReinforceNormal.lua
--
-- Created by IntelliJ IDEA.
-- User: petrfiala
-- Date: 17/02/2016
-- Time: 14:09
--
------------------------------------------------------------------------
--[[ MOReinforceNormal ]] --
-- Ref A. http://incompleteideas.net/sutton/williams-92.pdf
-- Inputs are the means (mu) of a multivariate normal distribution.
-- Outputs are samples drawn from that distribution.
-- The standard deviation is provided either as a constructor argument
-- or as the second element of a {mean, stdev} input table.
-- Uses the REINFORCE algorithm (ref. A sec 6. p.237-239), which is
-- implemented through the nn.Module:reinforce(r,b) interface.
-- gradOutputs are ignored (REINFORCE algorithm).
------------------------------------------------------------------------
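--
-- A minimal usage sketch (hypothetical, assuming the MOReinforce parent
-- class and the dpnn reinforce interface are on the package path):
--
--   require 'nn'
--   local input = torch.randn(4, 10)
--   local actor = nn.Sequential()
--      :add(nn.Linear(10, 2))        -- predicts the means
--      :add(MOReinforceNormal(0.1))  -- samples with a fixed stdev of 0.1
--   actor:training()
--   local action = actor:forward(input)
--   -- after computing a reward for the sampled action, broadcast it
--   -- through the reinforce interface; gradOutput itself is ignored :
--   actor:reinforce(reward)
--   actor:backward(input, action:clone():zero())
--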
local MOReinforceNormal, parent = torch.class("MOReinforceNormal", "MOReinforce")

function MOReinforceNormal:__init(stdev, stochastic)
   parent.__init(self, stochastic)
   self.stdev = stdev
   if not stdev then
      self.gradInput = { torch.Tensor(), torch.Tensor() }
   end
   -- TODO: avoid module-level global state.
   -- reinforce_step counts stochastic forward passes during training;
   -- it is passed to rewardAs in updateGradInput (and decremented there)
   -- so that each backward pass sees the reward of its own step.
   reinforce_step = 0
end
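
-- Construction variants (as implemented above) : MOReinforceNormal(0.1)
-- fixes the standard deviation, while MOReinforceNormal() expects
-- {mean, stdev} input tables and then also computes a gradient w.r.t.
-- the stdev in updateGradInput.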
function MOReinforceNormal:updateOutput(input)
   local mean, stdev = input, self.stdev
   if torch.type(input) == 'table' then
      -- input is {mean, stdev}
      assert(#input == 2)
      mean, stdev = unpack(input)
   end
   assert(stdev)
   self.output:resizeAs(mean)

   if self.stochastic or self.train ~= false then
      -- sample : output = mean + stdev * eps, with eps ~ N(0, I)
      self.output:normal()
      -- multiply by standard deviations
      if torch.type(stdev) == 'number' then
         self.output:mul(stdev)
      elseif torch.isTensor(stdev) then
         if stdev:dim() == mean:dim() then
            assert(stdev:isSameSizeAs(mean))
            self.output:cmul(stdev)
         else
            -- stdev lacks the batch dimension : expand it over the batch
            assert(stdev:dim() + 1 == mean:dim())
            self._stdev = self._stdev or stdev.new()
            self._stdev:view(stdev, 1, table.unpack(stdev:size():totable()))
            self.__stdev = self.__stdev or stdev.new()
            self.__stdev:expandAs(self._stdev, mean)
            self.output:cmul(self.__stdev)
         end
      else
         error "unsupported stdev type"
      end

      if self.train then
         reinforce_step = reinforce_step + 1
      elseif self.train == nil then
         error "self.train must be set (call model:training() or model:evaluate())"
      end

      -- re-center the samples at the means
      self.output:add(mean)
   else
      -- evaluation : use the maximum a posteriori (MAP) estimate, i.e. the mean
      self.output:copy(mean)
   end
   return self.output
end
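
------------------------------------------------------------------------
-- Both derivatives below follow from the log-density of the normal
-- distribution (a standard identity, stated here for reference) :
--   ln f(x,u,s) = -(x - u)^2 / (2 s^2) - ln(s) - 0.5 ln(2 pi)
------------------------------------------------------------------------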
function MOReinforceNormal:updateGradInput(input, gradOutput)
   -- Note that gradOutput is ignored (REINFORCE)
   -- f : normal probability density function
   -- x : the sampled values (self.output)
   -- u : mean (mu) (mean)
   -- s : standard deviation (sigma) (stdev)
   local mean, stdev = input, self.stdev
   local gradMean, gradStdev = self.gradInput, nil
   if torch.type(input) == 'table' then
      mean, stdev = unpack(input)
      gradMean, gradStdev = unpack(self.gradInput)
   end
   assert(stdev)

   -- Derivative of log normal w.r.t. mean :
   -- d ln(f(x,u,s))   (x - u)
   -- -------------- = -------
   --      d u           s^2
   gradMean:resizeAs(mean)
   -- (x - u)
   gradMean:copy(self.output):add(-1, mean)
   -- divide by squared standard deviations
   if torch.type(stdev) == 'number' then
      gradMean:div(stdev ^ 2)
   else
      if stdev:dim() == mean:dim() then
         gradMean:cdiv(stdev):cdiv(stdev)
      else
         gradMean:cdiv(self.__stdev):cdiv(self.__stdev)
      end
   end
   -- multiply by the reward of the current step
   gradMean:cmul(self:rewardAs(mean, reinforce_step))
   -- multiply by -1 ( gradient descent on mean )
   gradMean:mul(-1)

   -- Derivative of log normal w.r.t. stdev :
   -- d ln(f(x,u,s))   (x - u)^2 - s^2
   -- -------------- = ---------------
   --      d s               s^3
   if gradStdev then
      gradStdev:resizeAs(stdev)
      -- (x - u)^2
      gradStdev:copy(self.output):add(-1, mean):pow(2)
      -- subtract s^2
      self._stdev2 = self._stdev2 or stdev.new()
      self._stdev2:resizeAs(stdev):copy(stdev):cmul(stdev)
      gradStdev:add(-1, self._stdev2)
      -- divide by s^3 (plus a small epsilon for numerical stability)
      self._stdev2:cmul(stdev):add(1e-8)
      gradStdev:cdiv(self._stdev2)
      -- multiply by the reward of the same step as gradMean above
      gradStdev:cmul(self:rewardAs(stdev, reinforce_step))
      -- multiply by -1 ( gradient descent on stdev )
      gradStdev:mul(-1)
   end

   -- step back so the next backward pass sees the previous step's reward
   reinforce_step = reinforce_step - 1
   return self.gradInput
end