Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/02_bandits/02_stochastic_bandit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ end
Construct a new StochasticBandit object from a vector of probability
distribution objects.
""" ->
# Forward the arm distributions straight to the parametric inner constructor.
StochasticBandit{D <: UnivariateDistribution}(arms::Vector{D}) = StochasticBandit{D}(arms)

@doc """
Construct a new StochasticBandit object from a vector of probability
distribution objects and a time period integer.
""" ->
# NOTE(review): the time-period argument `t` is accepted but never used —
# presumably because a stochastic bandit is stationary, so the construction
# does not depend on the period. Confirm this is intentional rather than a
# forgotten use of `t`.
function StochasticBandit{D <: UnivariateDistribution}(arms::Vector{D}, t::Integer)
    return StochasticBandit{D}(arms)
end
Expand Down
11 changes: 9 additions & 2 deletions src/03_learners/02_mle_learner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,26 @@ function learn!(

if nᵢ == 1
learner.oldMs[a] = r
learner.Ss[a] = 0.0
learner.Ss[a] = learner.σ₀
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these bug fixes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not clear that I should really be doing this. We're already abusing the learner.σ₀ notation a bit, since we aren't "really" using a prior as the notation would suggest. This just ensures that the standard error is never actually exactly zero. For a Bernoulli DGP, we may have an estimated standard deviation of zero for a while until we finally observe a success. So this is basically like an Agresti-Coull estimate, with this change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. This seems like we should potentially be using something like NaN instead. Or at least not calling this MLE anymore.

learner.μs[a] = r
else
learner.newMs[a] = learner.oldMs[a] + (r - learner.oldMs[a]) / nᵢ
learner.Ss[a] += (r - learner.oldMs[a]) * (r - learner.newMs[a])
learner.oldMs[a] = learner.newMs[a]
learner.μs[a] = learner.newMs[a]
learner.σs[a] = sqrt(learner.Ss[a] / (nᵢ - 1))
learner.σs[a] = sqrt(learner.Ss[a] / (nᵢ - 1) / nᵢ)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this a bug as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and it was a really fun one to track down. Doing anything with MLELearner seems like it's been broken for a while (learner.σs was the estimated standard deviation of the data, not the standard error of the mean).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, I think there's some confusion: this field was always supposed to be the standard deviation, not the standard error. So I think other places in the code are the problem rather than this line.

end

return
end

@doc """
Draw a sample from the posterior for arm a.
""" ->
function Base.rand(learner::MLELearner, a::Integer)
    # Sampling distribution for arm `a`: normal with the learner's current
    # mean estimate and scale estimate for that arm.
    posterior = Normal(learner.μs[a], learner.σs[a])
    return rand(posterior)
end

function Base.show(io::IO, learner::MLELearner)
    # Render as a constructor-like call with default %f float formatting.
    prior_mean, prior_std = learner.μ₀, learner.σ₀
    @printf(io, "MLELearner(%f, %f)", prior_mean, prior_std)
end
121 changes: 121 additions & 0 deletions src/03_learners/06_james_stein_learner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
@doc """
A JamesSteinLearner object stores the online estimated mean and variance of all
arms. Arms with zero counts use a default mean and standard deviation.
""" ->
immutable JamesSteinLearner <: Learner
    ns::Vector{Int64}      # per-arm observation counts
    oldMs::Vector{Float64} # previous running means (Welford recurrence)
    newMs::Vector{Float64} # updated running means (Welford recurrence)
    Ss::Vector{Float64}    # running sums of squared deviations (Welford's S)
    ys::Vector{Float64}    # per-arm raw mean estimates fed into the shrinkage
    ss::Vector{Float64}    # per-arm squared standard errors (Ss / (n-1) / n in learn!)
    μs::Vector{Float64}    # shrunken (James-Stein) mean estimates
    σs::Vector{Float64}    # standard deviations of the shrunken estimates
    μ₀::Float64            # default mean reported for unobserved arms
    σ₀::Float64            # default standard deviation for unobserved arms
    # NOTE(review): K is fixed at construction (the constructor passes Int64(1))
    # and — since this struct is immutable — is never updated by initialize!.
    # learn! divides by K - 3, which is -2 with K = 1. Confirm K is meant to
    # track the actual number of arms.
    K::Int64
end

@doc """
Create an JamesSteinLearner object specifying only a default mean and standard
deviation.
""" ->
function JamesSteinLearner(μ₀::Real, σ₀::Real)
    # All per-arm buffers start empty; initialize! sizes them for K arms.
    float_buffers = [Array(Float64, 0) for i in 1:7]
    return JamesSteinLearner(
        Array(Int64, 0),
        float_buffers...,
        Float64(μ₀),
        Float64(σ₀),
        Int64(1)
    )
end

@doc """
Return the counts for each arm.
""" ->
function counts(learner::JamesSteinLearner)
    # Per-arm observation counts.
    return learner.ns
end

@doc """
Return the means for each arm.
""" ->
function means(learner::JamesSteinLearner)
    # Shrunken per-arm mean estimates.
    return learner.μs
end

@doc """
Return the standard deviations for each arm.
""" ->
function stds(learner::JamesSteinLearner)
    # Per-arm standard deviation estimates.
    return learner.σs
end

@doc """
Reset the JamesSteinLearner object for K arms.
""" ->
function initialize!(learner::JamesSteinLearner, K::Integer)
    # Size every per-arm buffer for K arms.
    resize!(learner.ns, K)
    resize!(learner.oldMs, K)
    resize!(learner.newMs, K)
    resize!(learner.Ss, K)
    resize!(learner.ys, K)
    resize!(learner.ss, K)
    resize!(learner.μs, K)
    resize!(learner.σs, K)

    # Unobserved arms report the default mean and scale. oldMs, newMs, and
    # Ss are deliberately left unfilled: learn! writes them before reading
    # them for each arm.
    fill!(learner.ns, 0)
    fill!(learner.ys, learner.μ₀)
    fill!(learner.μs, learner.μ₀)
    # NOTE(review): learn! treats ss as a squared standard error
    # (Ss / (nᵢ - 1) / nᵢ), yet it is seeded here with σ₀, a standard
    # deviation — confirm whether σ₀^2 was intended.
    fill!(learner.ss, learner.σ₀)
    fill!(learner.σs, learner.σ₀)

    # NOTE(review): learner.K is not updated here (and cannot be — the
    # struct is immutable), so it keeps its constructed value even though
    # learn! uses it as if it were the number of arms.
    return
end

@doc """
Learn about arm a on trial t from reward r.
""" ->
function learn!(
    learner::JamesSteinLearner,
    context::Context,
    a::Integer,
    r::Real,
)
    # Record one more observation for arm a.
    learner.ns[a] += 1
    nᵢ = learner.ns[a]

    if nᵢ == 1
        # First observation for this arm: seed the Welford recurrence with r
        # and fall back on the default scale, since no variance estimate
        # exists yet.
        learner.oldMs[a] = r
        learner.Ss[a] = learner.σ₀
        learner.ys[a] = r
        learner.μs[a] = r
    else
        # Welford's online update of the running mean and the running sum of
        # squared deviations for arm a.
        learner.newMs[a] = learner.oldMs[a] + (r - learner.oldMs[a]) / nᵢ
        learner.Ss[a] += (r - learner.oldMs[a]) * (r - learner.newMs[a])
        learner.oldMs[a] = learner.newMs[a]
        learner.ys[a] = learner.newMs[a]
        # Squared standard error of the mean for arm a.
        learner.ss[a] = learner.Ss[a] / (nᵢ - 1) / nᵢ

        # James-Stein shrinkage of every arm's raw mean toward the grand
        # mean: each new observation therefore moves the estimates (and
        # scales) for all arms, not just arm a.
        y̅ = mean(learner.ys)
        # NOTE(review): K - 3 is negative unless learner.K ≥ 4, and K is
        # never raised above its constructed value of 1 — confirm K is meant
        # to be the number of arms before relying on these estimates.
        φs = min(1.0, learner.ss ./ (sumabs2(learner.ys - y̅) ./ (learner.K - 3)))
        learner.μs .= y̅ .+ (1 .- φs) .* (learner.ys .- y̅)
        learner.σs .= sqrt(
            (1 .- φs) .* learner.ss .+
            φs .* learner.ss ./ learner.K .+
            2 .* φs.^2 .* (learner.ys .- y̅).^2 ./ (learner.K .- 3)
        )
    end

    return
end

@doc """
Draw a sample from the posterior for arm a.
""" ->
function Base.rand(learner::JamesSteinLearner, a::Integer)
    # Sampling distribution for arm `a`: normal with the shrunken mean and
    # its estimated scale.
    posterior = Normal(learner.μs[a], learner.σs[a])
    return rand(posterior)
end

function Base.show(io::IO, learner::JamesSteinLearner)
    # Render as a constructor-like call with default %f float formatting.
    default_mean, default_std = learner.μ₀, learner.σ₀
    @printf(io, "JamesSteinLearner(%f, %f)", default_mean, default_std)
end
5 changes: 3 additions & 2 deletions src/Bandits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ module Bandits
include(joinpath("07_distributions", "03_nonstationary_multivariate_distribution.jl"))
include(joinpath("07_distributions", "04_nonstationary_contextual_distribution.jl"))
include(joinpath("07_distributions", "05_nonstationary_logistic_contextual_distribution.jl"))
include(joinpath("07_distributions", "06_nonstationary_gaussian_distribution.jl"))
include(joinpath("07_distributions", "06_nonstationary_gaussian_distribution.jl"))
include(joinpath("07_distributions", "07_nonstationary_1dgaussianprocess_distribution.jl"))

# Bandit
Expand All @@ -35,14 +35,15 @@ module Bandits

# Learners
export Learner, MLELearner, BetaLearner, BootstrapLearner,
BootstrapMLELearner, EBBetaLearner, DiscLearner
BootstrapMLELearner, EBBetaLearner, DiscLearner, JamesSteinLearner
export initialize!, counts, means, stds, learn!, preferred_arm
include(joinpath("03_learners", "01_learner.jl"))
include(joinpath("03_learners", "02_mle_learner.jl"))
include(joinpath("03_learners", "03_beta_learner.jl"))
include(joinpath("03_learners", "04_bootstrap_learner.jl"))
include(joinpath("03_learners", "05_eb_beta_learner.jl"))
include(joinpath("03_learners", "05_disc_learner.jl"))
include(joinpath("03_learners", "06_james_stein_learner.jl"))

# Algorithms
export
Expand Down