Reinforce.stochastic evaluation
nicholas-leonard committed Sep 16, 2015
1 parent 28a5971 commit b04c0d3
Showing 5 changed files with 26 additions and 12 deletions.
20 changes: 14 additions & 6 deletions README.md
@@ -526,8 +526,16 @@ When `castTarget = true` (the default), the `targetModule` is cast along with th
Ref A. [Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning](http://incompleteideas.net/sutton/williams-92.pdf)

Abstract class for modules that implement the REINFORCE algorithm (ref. A).

```lua
module = nn.Reinforce([stochastic])
```

The `reinforce(reward)` method is called by a special Reward Criterion (e.g. [VRClassReward](#nn.VRClassReward)).
When `backward` is subsequently called, the reward is used to generate `gradInput`s.
When `stochastic=true`, the module is stochastic (i.e. samples from a distribution)
during both evaluation and training.
When `stochastic=false` (the default), the module is only stochastic during training.
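
A Reward Criterion normally drives this sequence for you; purely as an illustrative sketch (the batch size, probabilities and reward value below are made up, and the package providing these modules, e.g. dpnn, is assumed installed), the call order looks like:

```lua
require 'dpnn' -- assumed: provides nn.Reinforce and its subclasses

local module = nn.ReinforceBernoulli()
local input = torch.rand(4, 3) -- batch of 4 examples, 3 Bernoulli probabilities each

local output = module:forward(input) -- training mode : samples of 0s and 1s

-- what a Reward Criterion does internally :
local reward = torch.Tensor(4):fill(1) -- e.g. 1 per correctly classified example
module:reinforce(reward)
-- gradOutput is ignored by Reinforce modules; gradInput is derived from the reward
local gradInput = module:backward(input, output:clone():zero())
```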

The REINFORCE rule for a module can be summarized as follows :
```lua
            d ln(f(output,input))
gradInput = --------------------- * reward
                  d input
```
where `f` is the probability density function of the distribution from which the `output` was sampled given the `input`.

@@ -559,14 +567,14 @@ Ref A. [Simple Statistical Gradient-Following Algorithms for
Connectionist Reinforcement Learning](http://incompleteideas.net/sutton/williams-92.pdf)

```lua
module = nn.ReinforceBernoulli()
module = nn.ReinforceBernoulli([stochastic])
```

A [Reinforce](#nn.Reinforce) subclass that implements the REINFORCE algorithm
(ref. A p.230-236) for the Bernoulli probability distribution.
Inputs are Bernoulli probabilities `p`.
During training, outputs are samples drawn from this distribution.
During evaluation, outputs are the same as the inputs.
During evaluation, when `stochastic=false`, outputs are the same as the inputs.
Uses the REINFORCE algorithm (ref. A p.230-236) which is
implemented through the [reinforce](#nn.Module.reinforce) interface (`gradOutputs` are ignored).
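
The effect of the `stochastic` flag on this module can be sketched as follows (assuming the package is installed; the probabilities are illustrative):

```lua
local rb = nn.ReinforceBernoulli()      -- stochastic=false (the default)
local p = torch.Tensor{{0.1, 0.5, 0.9}} -- Bernoulli probabilities

rb:training()
local sample = rb:forward(p)   -- each element is 0 or 1

rb:evaluate()
local mean = rb:forward(p)     -- identical to p

local rbs = nn.ReinforceBernoulli(true) -- stochastic=true
rbs:evaluate()
local sample2 = rbs:forward(p) -- still samples 0s and 1s
```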

@@ -588,15 +596,15 @@ d ln(f(output,input)) d ln(f(x,p)) (x - p)
Ref A. [Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning](http://incompleteideas.net/sutton/williams-92.pdf)

```lua
module = nn.ReinforceNormal(stdev)
module = nn.ReinforceNormal(stdev, [stochastic])
```

A [Reinforce](#nn.Reinforce) subclass that implements the REINFORCE algorithm
(ref. A p.238-239) for a Normal (i.e. Gaussian) probability distribution.
Inputs are the means of the normal distribution.
The `stdev` argument specifies the standard deviation of the distribution.
During training, outputs are samples drawn from this distribution.
During evaluation, outputs are the same as the inputs, i.e. the means.
During evaluation, when `stochastic=false`, outputs are the same as the inputs, i.e. the means.
Uses the REINFORCE algorithm (ref. A p.238-239) which is
implemented through the [reinforce](#nn.Module.reinforce) interface (`gradOutputs` are ignored).
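
Analogously, a sketch of the forward behaviour (illustrative sizes; package assumed installed):

```lua
local rn = nn.ReinforceNormal(0.1) -- stdev=0.1, stochastic=false
local mu = torch.zeros(4, 2)       -- a batch of means

rn:training()
local noisy = rn:forward(mu) -- mu + 0.1 * standard normal noise

rn:evaluate()
local det = rn:forward(mu)   -- returns the means unchanged
```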

@@ -622,7 +630,7 @@ module (see [this example](https://github.com/Element-Research/rnn/blob/master/e
Ref A. [Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning](http://incompleteideas.net/sutton/williams-92.pdf)

```lua
module = nn.ReinforceCategorical()
module = nn.ReinforceCategorical([stochastic])
```

A [Reinforce](#nn.Reinforce) subclass that implements the REINFORCE algorithm
@@ -633,7 +641,7 @@ For `n` categories, both the `input` and `output` are of size `batchSize x n`.
During training, outputs are samples drawn from this distribution.
The outputs are returned in one-hot encoding i.e.
the output for each example has exactly one category having a 1, while the remainder are zero.
During evaluation, outputs are the same as the inputs, i.e. the means.
During evaluation, when `stochastic=false`, outputs are the same as the inputs, i.e. the probabilities `p`.
Uses the REINFORCE algorithm (ref. A) which is
implemented through the [reinforce](#nn.Module.reinforce) interface (`gradOutputs` are ignored).
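
A sketch of the one-hot sampling behaviour (illustrative probabilities; package assumed installed):

```lua
local rc = nn.ReinforceCategorical()
local p = torch.Tensor{{0.1, 0.6, 0.3},
                       {0.7, 0.2, 0.1}} -- batchSize x n probabilities

rc:training()
local onehot = rc:forward(p) -- each row has exactly one 1 (the sampled category)

rc:evaluate()
local probs = rc:forward(p)  -- identical to p when stochastic=false
```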

7 changes: 7 additions & 0 deletions Reinforce.lua
@@ -8,6 +8,13 @@
------------------------------------------------------------------------
local Reinforce, parent = torch.class("nn.Reinforce", "nn.Module")

function Reinforce:__init(stochastic)
   parent.__init(self)
   -- true makes it stochastic during evaluation and training
   -- false makes it stochastic only during training
   self.stochastic = stochastic
end

-- a Reward Criterion will call this
function Reinforce:reinforce(reward)
   parent.reinforce(self, reward)
2 changes: 1 addition & 1 deletion ReinforceBernoulli.lua
@@ -11,7 +11,7 @@ local ReinforceBernoulli, parent = torch.class("nn.ReinforceBernoulli", "nn.Rein

function ReinforceBernoulli:updateOutput(input)
   self.output:resizeAs(input)
   if self.train ~= false then
   if self.stochastic or self.train ~= false then
      -- sample from bernoulli with P(output=1) = input
      self._uniform = self._uniform or input.new()
      self._uniform:resizeAs(input):uniform(0,1)
2 changes: 1 addition & 1 deletion ReinforceCategorical.lua
@@ -12,7 +12,7 @@ local ReinforceCategorical, parent = torch.class("nn.ReinforceCategorical", "nn.
function ReinforceCategorical:updateOutput(input)
   self.output:resizeAs(input)
   self._index = self._index or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor())
   if self.train ~= false then
   if self.stochastic or self.train ~= false then
      -- sample from categorical with p = input
      input.multinomial(self._index, input, 1)
      -- one hot encoding
7 changes: 3 additions & 4 deletions ReinforceNormal.lua
@@ -10,17 +10,16 @@
------------------------------------------------------------------------
local ReinforceNormal, parent = torch.class("nn.ReinforceNormal", "nn.Reinforce")

function ReinforceNormal:__init(stdev)
   parent.__init(self)
function ReinforceNormal:__init(stdev, stochastic)
   parent.__init(self, stochastic)
   self.stdev = stdev
end

function ReinforceNormal:updateOutput(input)
   -- TODO : input could also be a table of mean and stdev tensors
   self.output:resizeAs(input)
   if self.train ~= false then
   if self.stochastic or self.train ~= false then
      self.output:normal()

      -- multiply by standard deviations
      if torch.type(self.stdev) == 'number' then
         self.output:mul(self.stdev)
