diff --git a/Project.toml b/Project.toml index 6a424db1..13e5542f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PSIS" uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04" authors = ["Seth Axen and contributors"] -version = "0.2.6" +version = "0.2.7" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/core.jl b/src/core.jl index 7546aa6e..4f552a5f 100644 --- a/src/core.jl +++ b/src/core.jl @@ -152,7 +152,8 @@ end Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021]. -While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in-place. +While `psis` computes smoothed log weights out-of-place if `smooth=true`, `psis!` smooths +them in-place. # Arguments @@ -170,6 +171,8 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in # Keywords + - `smooth=true`: If `true`, the log-weights are smoothed. If `false`, only diagnostics + are computed. - `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010]. If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in [^VehtariSimpson2021]. @@ -179,7 +182,7 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in - `result`: a [`PSISResult`](@ref) object containing the results of the Pareto-smoothing. -A warning is raised if the Pareto shape parameter ``k ≥ 0.7``. See [`PSISResult`](@ref) for +A warning is raised if the Pareto shape parameter ``k > 0.7``. See [`PSISResult`](@ref) for details and [`paretoshapeplot`](@ref) for a diagnostic plot. [^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). @@ -195,16 +198,21 @@ details and [`paretoshapeplot`](@ref) for a diagnostic plot. """ psis, psis! -function psis(logr, reff=1; kwargs...) - T = float(eltype(logr)) - logw = similar(logr, T) - copyto!(logw, logr) - return psis!(logw, reff; kwargs...) +function psis(logr, reff=1; smooth::Bool=true, kwargs...) + if smooth + T = float(eltype(logr)) + logw = similar(logr, T) + copyto!(logw, logr) + else + logw = logr + end + return psis!(logw, reff; smooth=smooth, kwargs...) end function psis!( logw::AbstractVector, reff=1; + smooth::Bool=true, sorted::Bool=false, # deprecated improved::Bool=false, warn::Bool=true, @@ -222,7 +230,7 @@ function psis!( tail_inds = @view perm[2:(M + 1)] logu = logw[cutoff_ind] logw_tail = @views logw[tail_inds] - _, tail_dist = psis_tail!(logw_tail, logu, M, improved) + _, tail_dist = psis_tail!(logw_tail, logu, M; smooth=smooth, improved=improved) warn && check_pareto_shape(tail_dist) return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, tail_dist) end @@ -283,20 +291,24 @@ end tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff))) -function psis_tail!(logw, logμ, M=length(logw), improved=false) +function psis_tail!(logw, logμ, M=length(logw); improved::Bool=false, smooth::Bool=true) T = eltype(logw) logw_max = logw[M] # to improve numerical stability, we first shift the log-weights to have a maximum of 0, # equivalent to scaling the weights to have a maximum of 1. μ_scaled = exp(logμ - logw_max) - w = (logw .= exp.(logw .- logw_max)) + if smooth + # if smoothing, we can reuse storage + logw .= exp.(logw .- logw_max) + w = logw + else + w = exp.(logw .- logw_max) + end tail_dist_scaled = StatsBase.fit( GeneralizedParetoKnownMu(μ_scaled), w; sorted=true, improved=improved ) tail_dist_adjusted = prior_adjust_shape(tail_dist_scaled, M) - # undo the scaling - ξ = Distributions.shape(tail_dist_adjusted) - if isfinite(ξ) + if smooth p = uniform_probabilities(T, M) @inbounds for i in eachindex(logw, p) # undo scaling in the log-weights diff --git a/test/core.jl b/test/core.jl index 870bed05..e6442ff4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -148,6 +148,17 @@ end end @testset "keywords" begin + @testset "$f smooth" for f in (psis, psis!) + x = randn(10, 100, 4) + xcopy = copy(x) + result = f(x; smooth=false) + @test x === result.log_weights + @test xcopy == x + + result = f(x; smooth=true) + @test result.log_weights != xcopy + end + @testset "sorted=true" begin x = randn(100) perm = sortperm(x)