Add functions for Empirical Bayes #39
base: master
Changes from all commits: 1967c23, 81b14b1, 22f1fe9, de2e9aa, 76dfe71, 6b856f2
```diff
@@ -77,19 +77,26 @@ function learn!(
     if nᵢ == 1
         learner.oldMs[a] = r
-        learner.Ss[a] = 0.0
+        learner.Ss[a] = learner.σ₀
         learner.μs[a] = r
     else
         learner.newMs[a] = learner.oldMs[a] + (r - learner.oldMs[a]) / nᵢ
         learner.Ss[a] += (r - learner.oldMs[a]) * (r - learner.newMs[a])
         learner.oldMs[a] = learner.newMs[a]
         learner.μs[a] = learner.newMs[a]
-        learner.σs[a] = sqrt(learner.Ss[a] / (nᵢ - 1))
+        learner.σs[a] = sqrt(learner.Ss[a] / (nᵢ - 1) / nᵢ)
```
> **Collaborator:** Was this a bug as well?
>
> **Author:** Yes, and it was a really fun one to track down. Doing anything with `MLELearner` seems like it's been broken for a while (`learner.σs` was the estimated standard deviation of the data, not the standard error of the mean).
>
> **Collaborator:** Wait, I think there's some confusion: this field was always supposed to be the standard deviation, not the standard error. So I think other places in the code are the problem rather than this line.
```diff
     end

     return
 end

 @doc """
 Draw a sample from the posterior for arm a.
 """ ->
 function Base.rand(learner::MLELearner, a::Integer)
     return rand(Normal(learner.μs[a], learner.σs[a]))
 end

 function Base.show(io::IO, learner::MLELearner)
     @printf(io, "MLELearner(%f, %f)", learner.μ₀, learner.σ₀)
 end
```
New file, 121 added lines:

```julia
@doc """
A JamesSteinLearner object stores the online estimated mean and variance of all
arms. Arms with zero counts use a default mean and standard deviation.
""" ->
immutable JamesSteinLearner <: Learner
    ns::Vector{Int64}
    oldMs::Vector{Float64}
    newMs::Vector{Float64}
    Ss::Vector{Float64}
    ys::Vector{Float64}
    ss::Vector{Float64}
    μs::Vector{Float64}
    σs::Vector{Float64}
    μ₀::Float64
    σ₀::Float64
    K::Int64
end

@doc """
Create a JamesSteinLearner object specifying only a default mean and standard
deviation.
""" ->
function JamesSteinLearner(μ₀::Real, σ₀::Real)
    return JamesSteinLearner(
        Array(Int64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Array(Float64, 0),
        Float64(μ₀),
        Float64(σ₀),
        Int64(1)
    )
end

@doc """
Return the counts for each arm.
""" ->
counts(learner::JamesSteinLearner) = learner.ns

@doc """
Return the means for each arm.
""" ->
means(learner::JamesSteinLearner) = learner.μs

@doc """
Return the standard deviations for each arm.
""" ->
stds(learner::JamesSteinLearner) = learner.σs

@doc """
Reset the JamesSteinLearner object for K arms.
""" ->
function initialize!(learner::JamesSteinLearner, K::Integer)
    resize!(learner.ns, K)
    resize!(learner.oldMs, K)
    resize!(learner.newMs, K)
    resize!(learner.Ss, K)
    resize!(learner.ys, K)
    resize!(learner.ss, K)
    resize!(learner.μs, K)
    resize!(learner.σs, K)

    fill!(learner.ns, 0)
    fill!(learner.ys, learner.μ₀)
    fill!(learner.μs, learner.μ₀)
    fill!(learner.ss, learner.σ₀)
    fill!(learner.σs, learner.σ₀)

    return
end

@doc """
Learn about arm a on trial t from reward r.
""" ->
function learn!(
    learner::JamesSteinLearner,
    context::Context,
    a::Integer,
    r::Real,
)
    learner.ns[a] += 1
    nᵢ = learner.ns[a]

    if nᵢ == 1
        learner.oldMs[a] = r
        learner.Ss[a] = learner.σ₀
        learner.ys[a] = r
        learner.μs[a] = r
    else
        learner.newMs[a] = learner.oldMs[a] + (r - learner.oldMs[a]) / nᵢ
        learner.Ss[a] += (r - learner.oldMs[a]) * (r - learner.newMs[a])
        learner.oldMs[a] = learner.newMs[a]
        learner.ys[a] = learner.newMs[a]
        learner.ss[a] = learner.Ss[a] / (nᵢ - 1) / nᵢ
        y̅ = mean(learner.ys)
```
> **Collaborator:** Probably not a problem, but I want to confirm that these steps change the inferred means for all arms, not just the observed arm. I don't think we assume invariance anywhere, but just confirming.
>
> **Author:** Yeah, that's correct. Every data point changes the predictions (slightly, due to the shrinkage toward an updated global mean) for all arms.
```julia
        φs = min(1.0, learner.ss ./ (sumabs2(learner.ys - y̅) ./ (learner.K - 3)))
        learner.μs .= y̅ .+ (1 .- φs) .* (learner.ys .- y̅)
        learner.σs .= sqrt(
            (1 .- φs) .* learner.ss .+
            φs .* learner.ss ./ learner.K .+
            2 .* φs.^2 .* (learner.ys .- y̅).^2 ./ (learner.K .- 3)
        )
    end

    return
end

@doc """
Draw a sample from the posterior for arm a.
""" ->
function Base.rand(learner::JamesSteinLearner, a::Integer)
    return rand(Normal(learner.μs[a], learner.σs[a]))
end

function Base.show(io::IO, learner::JamesSteinLearner)
    @printf(io, "JamesSteinLearner(%f, %f)", learner.μ₀, learner.σ₀)
end
```
> **Collaborator:** Are these bug fixes?
>
> **Author:** Not clear that I should really be doing this. We're already abusing the `learner.σ₀` notation a bit, since we aren't "really" using a prior as the notation would suggest. This just ensures that the standard error is never actually exactly zero. For a Bernoulli DGP, we may have an estimated standard deviation of zero for a while until we finally observe a success. So with this change it's basically like an Agresti-Coull estimate.
>
> **Collaborator:** Hmm. This seems like we should potentially be using something like `NaN` instead. Or at least not calling this MLE anymore.
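The failure mode in this thread is easy to reproduce: with Bernoulli rewards, `S` stays exactly zero until the first success, so a posterior sampler would collapse to a point mass at the observed mean and never explore. Seeding `S` with σ₀ on the first observation keeps the standard error strictly positive. A hypothetical Python sketch of the effect (the name `online_sem` and the `sigma0` default are illustrative, not from the PR):

```python
import math

def online_sem(rewards, sigma0=1.0):
    """Online mean and standard error, seeding S with sigma0 on the
    first observation (the PR's `learner.Ss[a] = learner.σ₀` change)."""
    mean, S = 0.0, 0.0
    for n, r in enumerate(rewards, start=1):
        if n == 1:
            mean, S = r, sigma0  # seed S with sigma0 instead of 0.0
        else:
            old = mean
            mean += (r - old) / n
            S += (r - old) * (r - mean)
    n = len(rewards)
    return mean, (sigma0 if n < 2 else math.sqrt(S / (n - 1) / n))

# Five failures in a row: the seeded learner still reports a positive
# standard error, whereas sigma0=0.0 reproduces the degenerate case.
mean, sem = online_sem([0.0] * 5)
```

With `sigma0=0.0` the returned standard error is exactly zero after any run of identical rewards, which is the degenerate posterior the thread is worried about.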