diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 9b1355f3..9c7d336b 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,6 +1,8 @@ # Run this command to always ignore these in local `git blame`: # git config blame.ignoreRevsFile .git-blame-ignore-revs +# Run formatter +70fd432667fb431e08ba52728734108d822a1922 # Run formatter 21038a047c023330876feb9259cd5c92add3ca81 # Run formatter after bracket alignment removal diff --git a/benchmark/iteration.jl b/benchmark/iteration.jl index 31819d63..6e5cd9b5 100644 --- a/benchmark/iteration.jl +++ b/benchmark/iteration.jl @@ -43,7 +43,8 @@ for (setname, set) in (("tups", tups), ("SAs", SAs)) # seems to lead to slow benchmarks. FIs_suite["couple_same"] = @benchmarkable StochasticAD.couple(typeof($Δs), $Δs_all) - FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine(typeof($Δs), + FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine( + typeof($Δs), $Δs_all) end end diff --git a/docs/make.jl b/docs/make.jl index 928716a4..2c6b6d6d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,11 +18,11 @@ pages = [ "tutorials/random_walk.md", "tutorials/game_of_life.md", "tutorials/particle_filter.md", - "tutorials/optimizations.md", + "tutorials/optimizations.md" ], "Public API" => "public_api.md", "Developer documentation" => "devdocs.md", - "Limitations" => "limitations.md", + "Limitations" => "limitations.md" ] ### Make docs diff --git a/docs/src/devdocs.md b/docs/src/devdocs.md index 4d16e636..c1cd1e7c 100644 --- a/docs/src/devdocs.md +++ b/docs/src/devdocs.md @@ -99,5 +99,6 @@ nothing # hide ## Distribution-specific customization of differentiation algorithm ```@docs -StochasticAD.randst +randst +InversionMethodDerivativeCoupling ``` \ No newline at end of file diff --git a/src/StochasticAD.jl b/src/StochasticAD.jl index e9f8f64a..d1dce669 100644 --- a/src/StochasticAD.jl +++ b/src/StochasticAD.jl @@ -2,14 +2,16 @@ module StochasticAD ### Public API -export stochastic_triple, derivative_contribution, perturbations, smooth_triple, dual_number # For working with stochastic triples +export stochastic_triple, derivative_contribution, perturbations, smooth_triple, + dual_number, StochasticTriple # For working with stochastic triples export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality export new_weight # Particle resampling export PrunedFIsBackend, - PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend, - StrategyWrapperFIsBackend + PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend, + StrategyWrapperFIsBackend export PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs export randst +export InversionMethodDerivativeCoupling ### Imports diff --git a/src/backends/abstract_wrapper.jl b/src/backends/abstract_wrapper.jl index 440d6679..42411812 100644 --- a/src/backends/abstract_wrapper.jl +++ b/src/backends/abstract_wrapper.jl @@ -39,18 +39,22 @@ StochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs) function StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, Δs_all; + rep = nothing, kwargs...) where {V, FIs} _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all) + _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;) return reconstruct_wrapper(StochasticAD.get_any(Δs_all), - StochasticAD.couple(FIs, _Δs_all; kwargs...)) + StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...)) end function StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, Δs_all; + rep = nothing, kwargs...) where {V, FIs} _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all) + _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;) return reconstruct_wrapper(StochasticAD.get_any(Δs_all), - StochasticAD.combine(FIs, _Δs_all; kwargs...)) + StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...)) end function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, @@ -61,8 +65,10 @@ function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, StochasticAD.get_rep(FIs, _Δs_all; kwargs...)) end -function StochasticAD.scalarize(Δs::AbstractWrapperFIs; kwargs...) - return StochasticAD.structural_map(StochasticAD.scalarize(Δs.Δs; kwargs...)) do _Δs +function StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...) + _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;) + return StochasticAD.structural_map(StochasticAD.scalarize( + Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs reconstruct_wrapper(Δs, _Δs) end end @@ -98,8 +104,13 @@ function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs) StochasticAD.derivative_contribution(Δs.Δs) end -function (::Type{<:AbstractWrapperFIs{V}})(Δs::AbstractWrapperFIs) where {V} - reconstruct_wrapper(Δs, StochasticAD.similar_type(typeof(Δs.Δs), V)(Δs.Δs)) +function Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V} + reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs)) +end + +function StochasticAD.send_signal( + Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal) + reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal)) end function Base.show(io::IO, Δs::AbstractWrapperFIs) diff --git a/src/backends/dict.jl b/src/backends/dict.jl index 37aae715..7ca8e736 100644 --- a/src/backends/dict.jl +++ b/src/backends/dict.jl @@ -75,7 +75,7 @@ StochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState()) ### Convert type of a backend -function DictFIs{V}(Δs::DictFIs) where {V} +function Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V} DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state) end @@ -88,7 +88,9 @@ function StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V} sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0) end -StochasticAD.perturbations(Δs::DictFIs) = [(Δ, event.w) for (event, Δ) in pairs(Δs.dict)] +function StochasticAD.perturbations(Δs::DictFIs) + [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)] +end ### Unary propagation @@ -99,7 +101,7 @@ function StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...) mapped_weights = last.(mapped_values_and_weights) scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a), keys(Δs.dict), - mapped_weights) + mapped_weights) # TODO: should original events (with old tag) also be modified? dict = Dictionary(scaled_events, mapped_values) DictFIs(dict, Δs.state) end @@ -119,20 +121,23 @@ end function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), - out_rep = nothing) + out_rep = nothing, + kwargs...) all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs keys(Δs.dict) end distinct_keys = unique(all_keys |> Iterators.flatten) - Δs_coupled_dict = [StochasticAD.structural_map(Δs -> isassigned(Δs.dict, key) ? - Δs.dict[key] : - zero(eltype(Δs.dict)), Δs_all) + Δs_coupled_dict = [StochasticAD.structural_map( + Δs -> isassigned(Δs.dict, key) ? + Δs.dict[key] : + zero(eltype(Δs.dict)), + Δs_all) for key in distinct_keys] DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state) end function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all; - rep = StochasticAD.get_rep(FIs, Δs_all)) + rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...) Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all)) Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2 mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2) @@ -141,6 +146,7 @@ function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all; end function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing) + # TODO: use vcat here? tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2) Δ_all_allkeys = foldl(tupleify, values(Δs.dict)) Δ_all_rep = first(values(Δs.dict)) diff --git a/src/backends/pruned.jl b/src/backends/pruned.jl index 780df84d..7b7af4d8 100644 --- a/src/backends/pruned.jl +++ b/src/backends/pruned.jl @@ -17,11 +17,21 @@ struct PrunedFIsBackend <: StochasticAD.AbstractFIsBackend end State maintained by pruning backend. """ mutable struct PrunedFIsState + # tag is in place to avoid relying on mutability for uniqueness. + tag::Int32 weight::Float64 valid::Bool - PrunedFIsState(valid = true) = new(0.0, valid) + function PrunedFIsState(valid = true) + state = new(0, 0.0, valid) + state.tag = objectid(state) % typemax(Int32) + return state + end end +Base.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag +# c.f. https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51 +Base.hash(state::PrunedFIsState) = state.tag + """ PrunedFIs{V} <: StochasticAD.AbstractFIs{V} @@ -59,7 +69,7 @@ StochasticAD.create_Δs(::PrunedFIsBackend, V) = PrunedFIs{V}(PrunedFIsState(fal ### Convert type of a backend -function PrunedFIs{V}(Δs::PrunedFIs) where {V} +function Base.convert(::Type{PrunedFIs{V}}, Δs::PrunedFIs) where {V} PrunedFIs{V}(convert(V, Δs.Δ), Δs.state) end @@ -79,7 +89,11 @@ pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight -StochasticAD.perturbations(Δs::PrunedFIs) = ((pruned_value(Δs), Δs.state.weight),) +function StochasticAD.perturbations(Δs::PrunedFIs) + return ((; Δ = pruned_value(Δs), + weight = Δs.state.valid ? Δs.state.weight : zero(Δs.state.weight), + state = Δs.state),) +end ### Unary propagation @@ -94,9 +108,8 @@ StochasticAD.alltrue(f, Δs::PrunedFIs) = f(Δs.Δ) ### Coupling -function StochasticAD.get_rep(::Type{<:PrunedFIs}, Δs_all) - # The code below is a bit ridiculous, but it's faster than `first` for small structures:) - return StochasticAD.get_any(Δs_all) +function StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all) + return empty(FIs) end function get_pruned_state(Δs_all) @@ -105,7 +118,7 @@ function get_pruned_state(Δs_all) isapproxzero(Δs) && return (total_weight, new_state) candidate_state = Δs.state if !candidate_state.valid || - ((new_state !== nothing) && (candidate_state === new_state)) + ((new_state !== nothing) && (candidate_state == new_state)) return (total_weight, new_state) end w = candidate_state.weight @@ -131,14 +144,15 @@ end # for pruning, coupling amounts to getting rid of perturbed values that have been # lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid. # rep is unused. -function StochasticAD.couple(::Type{<:PrunedFIs}, Δs_all; rep = nothing, out_rep = nothing) +function StochasticAD.couple( + ::Type{<:PrunedFIs}, Δs_all; rep = nothing, out_rep = nothing, kwargs...) state = get_pruned_state(Δs_all) Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here PrunedFIs(Δ_coupled, state) end # basically couple combined with a sum. -function StochasticAD.combine(::Type{<:PrunedFIs}, Δs_all; rep = nothing) +function StochasticAD.combine(::Type{<:PrunedFIs}, Δs_all; rep = nothing, kwargs...) state = get_pruned_state(Δs_all) Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all)) PrunedFIs(Δ_combined, state) @@ -151,7 +165,7 @@ function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing) end function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V} - Δs.state === state ? pruned_value(Δs) : zero(V) + Δs.state == state ? pruned_value(Δs) : zero(V) end ### Miscellaneous diff --git a/src/backends/pruned_aggressive.jl b/src/backends/pruned_aggressive.jl index c74f933e..25745dbe 100644 --- a/src/backends/pruned_aggressive.jl +++ b/src/backends/pruned_aggressive.jl @@ -78,7 +78,7 @@ end ### Convert type of a backend -function PrunedFIsAggressive{V}(Δs::PrunedFIsAggressive) where {V} +function Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V} PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state) end @@ -96,7 +96,9 @@ function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive) pruned_value(Δs) * Δs.state.weight end -StochasticAD.perturbations(Δs::PrunedFIsAggressive) = ((pruned_value(Δs), Δs.state.weight),) +function StochasticAD.perturbations(Δs::PrunedFIsAggressive) + ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),) +end ### Unary propagation @@ -120,7 +122,7 @@ end # lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid. function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), - out_rep = nothing) + out_rep = nothing, kwargs...) state = rep.state Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here PrunedFIsAggressive(Δ_coupled, state.active_tag, state) @@ -128,7 +130,7 @@ end # basically couple combined with a sum. function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all; - rep = StochasticAD.get_rep(FIs, Δs_all)) + rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...) state = rep.state Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all)) PrunedFIsAggressive(Δ_combined, state.active_tag, state) diff --git a/src/backends/smoothed.jl b/src/backends/smoothed.jl index 34baa781..fb54ab79 100644 --- a/src/backends/smoothed.jl +++ b/src/backends/smoothed.jl @@ -46,10 +46,9 @@ StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0) ### Convert type of a backend -function (::Type{<:SmoothedFIs{V}})(Δs::SmoothedFIs) where {V} - SmoothedFIs{V}(Δs.δ) +function Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V} + SmoothedFIs{V}(Δs.δ)::FIs end -(::Type{SmoothedFIs{V}})(Δs::SmoothedFIs) where {V} = SmoothedFIs{V}(Δs.δ) ### Getting information about perturbations @@ -70,11 +69,12 @@ StochasticAD.alltrue(f, Δs::SmoothedFIs) = true StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all) -function StochasticAD.couple(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep) +function StochasticAD.couple( + ::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...) SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all)) end -function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing) +function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...) V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all))) Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all)) SmoothedFIs{V_out}(Δ_combined) diff --git a/src/backends/strategy_wrapper.jl b/src/backends/strategy_wrapper.jl index ab94636d..6724e121 100644 --- a/src/backends/strategy_wrapper.jl +++ b/src/backends/strategy_wrapper.jl @@ -7,7 +7,7 @@ export StrategyWrapperFIsBackend, StrategyWrapperFIs struct StrategyWrapperFIsBackend{ B <: StochasticAD.AbstractFIsBackend, - S <: StochasticAD.AbstractPerturbationStrategy, + S <: StochasticAD.AbstractPerturbationStrategy } <: StochasticAD.AbstractFIsBackend backend::B @@ -17,7 +17,7 @@ end struct StrategyWrapperFIs{ V, FIs <: StochasticAD.AbstractFIs{V}, - S <: StochasticAD.AbstractPerturbationStrategy, + S <: StochasticAD.AbstractPerturbationStrategy } <: AbstractWrapperFIs{V, FIs} Δs::FIs @@ -38,7 +38,8 @@ function AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapp return StrategyWrapperFIs(Δs, wrapper_Δs.strategy) end -function AbstractWrapperFIsModule.reconstruct_wrapper(::Type{ +function AbstractWrapperFIsModule.reconstruct_wrapper( + ::Type{ <:StrategyWrapperFIs{V, FIs, S}, }, Δs) where {V, FIs, S} diff --git a/src/discrete_randomness.jl b/src/discrete_randomness.jl index 101f5a8c..ca893c17 100644 --- a/src/discrete_randomness.jl +++ b/src/discrete_randomness.jl @@ -31,13 +31,37 @@ _get_support(::Bernoulli) = (0, 1) # the map below looks a bit silly, but it gives us a collection of the categories with the same structure as probs(d). _get_support(d::Categorical) = map((val, prob) -> val, 1:ncategories(d), probs(d)) -# Derivative coupling approaches, determining which weighted perturbations to consider +## Derivative couplings +# Derivative coupling approaches, determining which weighted perturbations to consider abstract type AbstractDerivativeCoupling end -struct InversionMethodDerivativeCoupling end -## Strategies for precisely which perturbations to form given a derivative coupling +""" + InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true)) + +Specifies an inversion method coupling for generating perturbations from a univariate distribution. +Valid choices of `mode` are `Val(:positive_weight)`, `Val(:always_right)`, and `Val(:always_left)`. +# Example +```jldoctest +julia> using StochasticAD, Distributions, Random; Random.seed!(4321); + +julia> function X(p) + return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right))) + end +X (generic function with 1 method) + +julia> stochastic_triple(X, 0.5) +StochasticTriple of Int64: +0 + 0ε + (1 with probability -2.0ε) +``` +""" +Base.@kwdef struct InversionMethodDerivativeCoupling{M, HZP} + mode::M = Val(:positive_weight) + handle_zeroprob::HZP = Val(true) +end + +# Strategies for precisely which perturbations to form given a derivative coupling struct SingleSidedStrategy <: AbstractPerturbationStrategy end struct TwoSidedStrategy <: AbstractPerturbationStrategy end struct SmoothedStraightThroughStrategy <: AbstractPerturbationStrategy end @@ -46,17 +70,22 @@ struct IgnoreDiscreteStrategy <: AbstractPerturbationStrategy end new_Δs_strategy(Δs) = SingleSidedStrategy() +# Derivative coupling high-level interface + """ δtoΔs(d, val, δ, Δs::AbstractFIs) Given the parameter `val` of a distribution `d` and an infinitesimal change `δ`, return the discrete change in the output, with a similar representation to `Δs`. """ -δtoΔs(d, val, δ, Δs, coupling) = δtoΔs(d, val, δ, Δs, coupling, new_Δs_strategy(Δs)) -δtoΔs(d, val, δ, Δs, coupling, ::SingleSidedStrategy) = _δtoΔs(d, val, δ, Δs, coupling) -function δtoΔs(d, val, δ, Δs, coupling, ::TwoSidedStrategy) - Δs1 = _δtoΔs(d, val, δ, Δs, coupling) - Δs2 = _δtoΔs(d, val, -δ, Δs, coupling) +δtoΔs(d, val, δ, Δs, derivative_coupling) = δtoΔs( + d, val, δ, Δs, derivative_coupling, new_Δs_strategy(Δs)) +function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SingleSidedStrategy) + _δtoΔs(d, val, δ, Δs, derivative_coupling) +end +function δtoΔs(d, val, δ, Δs, derivative_coupling, ::TwoSidedStrategy) + Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling) + Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling) return combine((scale(Δs1, 0.5), scale(Δs2, -0.5))) end # TODO: implement this ST for other distributions and couplings, if meaningful? @@ -64,35 +93,39 @@ function δtoΔs(d::Union{Bernoulli, Binomial}, val, δ, Δs, - coupling::InversionMethodDerivativeCoupling, + derivative_coupling::InversionMethodDerivativeCoupling, ::StraightThroughStrategy) p = succprob(d) - Δs1 = _δtoΔs(d, val, δ, Δs, coupling) - Δs2 = _δtoΔs(d, val, -δ, Δs, coupling) + Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling) + Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling) return combine((scale(Δs1, 1 - p), scale(Δs2, -p))) end -δtoΔs(d, val::V, δ, Δs, coupling, ::IgnoreDiscreteStrategy) where {V} = similar_empty(Δs, V) +function δtoΔs(d, val::V, δ, Δs, derivative_coupling, ::IgnoreDiscreteStrategy) where {V} + similar_empty(Δs, V) +end # Implement straight through strategy, works for all distrs, but does something that is only # meaningful for smoothed backends (using one(val)) -function δtoΔs(d, val, δ, Δs, coupling, ::SmoothedStraightThroughStrategy) +function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SmoothedStraightThroughStrategy) p = _get_parameter(d) δout = ForwardDiff.derivative(a -> mean(_reconstruct(d, p + a * δ)), 0.0) return similar_new(Δs, one(val), δout) end -## Stochastic derivative rules for discrete distributions +# Derivative coupling low-level implementations function _δtoΔs(d::Geometric, val::V, δ::Real, Δs::AbstractFIs, - ::InversionMethodDerivativeCoupling) where {V <: Signed} + derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed} p = succprob(d) - if δ > 0 + if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) || + (derivative_coupling.mode isa Val{:always_right}) return val > 0 ? similar_new(Δs, -one(V), δ * val / p / (1 - p)) : similar_empty(Δs, V) - elseif δ < 0 + elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) || + (derivative_coupling.mode isa Val{:always_left}) return similar_new(Δs, one(V), -δ * (val + 1) / p) else return similar_empty(Δs, V) @@ -103,11 +136,13 @@ function _δtoΔs(d::Bernoulli, val::V, δ::Real, Δs::AbstractFIs, - ::InversionMethodDerivativeCoupling) where {V <: Signed} + derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed} p = succprob(d) - if δ > 0 + if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) || + (derivative_coupling.mode isa Val{:always_right}) return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p)) - elseif δ < 0 + elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) || + (derivative_coupling.mode isa Val{:always_left}) return isone(val) ? similar_new(Δs, -one(V), -δ / p) : similar_empty(Δs, V) else return similar_empty(Δs, V) @@ -118,13 +153,15 @@ function _δtoΔs(d::Binomial, val::V, δ::Real, Δs::AbstractFIs, - ::InversionMethodDerivativeCoupling) where {V <: Signed} + derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed} p = succprob(d) n = ntrials(d) - if δ > 0 + if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) || + (derivative_coupling.mode isa Val{:always_right}) return val == n ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ * (n - val) / (1 - p)) - elseif δ < 0 + elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) || + (derivative_coupling.mode isa Val{:always_left}) return !iszero(val) ? similar_new(Δs, -one(V), -δ * val / p) : similar_empty(Δs, V) else return similar_empty(Δs, V) @@ -135,11 +172,13 @@ function _δtoΔs(d::Poisson, val::V, δ::Real, Δs::AbstractFIs, - ::InversionMethodDerivativeCoupling) where {V <: Signed} + derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed} p = mean(d) # rate - if δ > 0 + if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) || + (derivative_coupling.mode isa Val{:always_right}) return similar_new(Δs, 1, δ) - elseif δ < 0 + elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) || + (derivative_coupling.mode isa Val{:always_left}) return val > 0 ? similar_new(Δs, -1, -δ * val / p) : similar_empty(Δs, V) else return similar_empty(Δs, V) @@ -150,36 +189,48 @@ function _δtoΔs(d::Categorical, val::V, δs, Δs::AbstractFIs, - ::InversionMethodDerivativeCoupling) where {V <: Signed} + derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed} p = params(d)[1] - left_sum = sum(δs[1:(val - 1)], init = zero(V)) - right_sum = -sum(δs[(val + 1):end], init = zero(V)) - - if left_sum > 0 - stop = rand() * left_sum - upto = zero(eltype(δs)) # The "upto" logic handles an edge case of probability 0 events that have non-zero derivative. - # It's a lot of logic to handle an edge case, but hopefully it's optimized away. - local left_nonzero - for i in (val - 1):-1:1 - if !iszero(p[i]) || ((upto += δs[i]) > stop) - left_nonzero = i - break + left_sum = sum(δs[1:(val - 1)], init = zero(eltype(δs))) + right_sum = -sum(δs[(val + 1):end], init = zero(eltype(δs))) + + if (derivative_coupling.mode isa Val{:positive_weight} && left_sum > 0) || + (derivative_coupling.mode isa Val{:always_left} && !iszero(left_sum)) + # compute left_nonzero + if derivative_coupling.handle_zeroprob isa Val{true} + stop = rand() * left_sum + upto = zero(eltype(δs)) # The "upto" logic handles an edge case of probability 0 events that have non-zero derivative. + # It's a lot of logic to handle an edge case, but hopefully it's optimized away. + left_nonzero = val + for i in (val - 1):-1:1 + if !iszero(p[i]) || ((upto += δs[i]) > stop) + left_nonzero = i + break + end end + else + left_nonzero = val - 1 end Δs_left = similar_new(Δs, left_nonzero - val, left_sum / p[val]) else Δs_left = similar_empty(Δs, typeof(val)) end - if right_sum < 0 - stop = -rand() * right_sum - upto = zero(eltype(δs)) - local right_nonzero - for i in (val + 1):length(p) - if !iszero(p[i]) || ((upto += δs[i]) > stop) - right_nonzero = i - break + if (derivative_coupling.mode isa Val{:positive_weight} && right_sum < 0) || + (derivative_coupling.mode isa Val{:always_right} && !iszero(right_sum)) + # compute right_nonzero + if derivative_coupling.handle_zeroprob isa Val{true} + stop = -rand() * right_sum + upto = zero(eltype(δs)) + right_nonzero = val + for i in (val + 1):length(p) + if !iszero(p[i]) || ((upto += δs[i]) > stop) + right_nonzero = i + break + end end + else + right_nonzero = val + 1 end Δs_right = similar_new(Δs, right_nonzero - val, -right_sum / p[val]) else @@ -189,6 +240,52 @@ function _δtoΔs(d::Categorical, return combine((Δs_left, Δs_right); rep = Δs) end +## Propagation couplings + +abstract type AbstractPropagationCoupling end + +""" + InversionMethodPropagationCoupling + +Specifies an inversion method coupling for propagating perturbations. +""" +struct InversionMethodPropagationCoupling <: AbstractPropagationCoupling end + +function _map_func(d, val, Δ, ::InversionMethodPropagationCoupling) + # construct alternative distribution + p = _get_parameter(d) + alt_d = _reconstruct(d, p + Δ) + # compute bounds on original ω + low = cdf(d, val - 1) + high = cdf(d, val) + # sample alternative value + alt_val = quantile(alt_d, rand(RNG) * (high - low) + low) + return convert(Signed, alt_val - val) +end + +function _map_enumeration(d, val, Δ, ::InversionMethodPropagationCoupling) + # construct alternative distribution + p = _get_parameter(d) + alt_d = _reconstruct(d, p + Δ) + # compute bounds on original ω + low = cdf(d, val - 1) + high = cdf(d, val) + if _has_finite_support(alt_d) + map(_get_support(alt_d)) do alt_val + # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high) + alt_low = cdf(alt_d, alt_val - 1) + alt_high = cdf(alt_d, alt_val) + prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) / + (high - low) + return (alt_val - val, prob_alt) + end + else + error("enumeration not supported for distribution $d. Does $d have finite support?") + end +end + +## Overloading of random sampling + # Define randst interface """ @@ -196,7 +293,8 @@ end When no keyword arguments are provided, `randst` behaves identically to `rand(rng, d)` in both ordinary computation and for stochastic triple dispatches. However, `randst` also allows the user to provide various keyword arguments -for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of `d`. +for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of `d`: a couple +common ones are `derivative_coupling` and `propagation_coupling`. For developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload `randst`, and redirect `rand` to your `randst` method. If you do not, it suffices to just overload `rand`. @@ -213,43 +311,20 @@ for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson] end @eval function randst(rng::AbstractRNG, d_st::$dist{StochasticTriple{T, V, FIs}}; - perturbation_map_kwargs = (;), - coupling = InversionMethodDerivativeCoupling()) where {T, V, FIs} + Δ_kwargs = (;), + derivative_coupling = InversionMethodDerivativeCoupling(), + propagation_coupling = InversionMethodPropagationCoupling()) where {T, V, FIs} st = _get_parameter(d_st) d = _reconstruct(d_st, st.value) val = convert(Signed, rand(rng, d)) - Δs1 = δtoΔs(d, val, st.δ, st.Δs, coupling) + Δs1 = δtoΔs(d, val, st.δ, st.Δs, derivative_coupling) - low = cdf(d, val - 1) - high = cdf(d, val) - - get_alt_d(Δ) = _reconstruct(d_st, st.value + Δ) - function map_func(Δ) - alt_d = get_alt_d(Δ) - alt_val = quantile(alt_d, rand(RNG) * (high - low) + low) - convert(Signed, alt_val - val) - end - function enumeration(Δ, _) - alt_d = get_alt_d(Δ) - if _has_finite_support(alt_d) - map(_get_support(alt_d)) do alt_val - # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high) - alt_low = cdf(alt_d, alt_val - 1) - alt_high = cdf(alt_d, alt_val) - prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) / - (high - low) - return (alt_val - val, prob_alt) - end - else - error("enumeration not supported for distribution $d. Does $d have finite support?") - end - end - Δs2 = map(map_func, + Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling), st.Δs; - enumeration, - deriv = δ -> smoothed_delta(d, val, δ, coupling), + enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling), + deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling), out_rep = val, - perturbation_map_kwargs...) + Δ_kwargs...) StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this end @@ -264,9 +339,9 @@ end function randst(rng::AbstractRNG, d_st::Categorical{<:StochasticTriple{T}, <:AbstractVector{<:StochasticTriple{T, V}}}; - perturbation_map_kwargs = (;), - coupling = InversionMethodDerivativeCoupling()) where {T, - V} + Δ_kwargs = (;), + derivative_coupling = InversionMethodDerivativeCoupling(), + propagation_coupling = InversionMethodPropagationCoupling()) where {T, V} sts = _get_parameter(d_st) # stochastic triple for each probability p = map(st -> st.value, sts) # try to keep the same type. e.g. static array -> static array. TODO: avoid allocations d = _reconstruct(d_st, p) @@ -275,40 +350,19 @@ function randst(rng::AbstractRNG, Δs_all = map(st -> st.Δs, sts) Δs_rep = get_rep(Δs_all) - Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, coupling) + Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, derivative_coupling) - low = cdf(d, val - 1) - high = cdf(d, val) Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here - - get_alt_d(Δ) = _reconstruct(d, p .+ Δ) - function map_func(Δ) - alt_d = get_alt_d(Δ) - alt_val = quantile(alt_d, rand(RNG) * (high - low) + low) - convert(Signed, alt_val - val) - end - function enumeration(Δ, _) - alt_d = get_alt_d(Δ) - if _has_finite_support(alt_d) - map(_get_support(alt_d)) do alt_val - # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high) - alt_low = cdf(alt_d, alt_val - 1) - alt_high = cdf(alt_d, alt_val) - prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) / (high - low) - return (alt_val - val, prob_alt) - end - else - error("enumeration not supported for distribution $d. Does $d have finite support?") - end - end - Δs2 = map(map_func, + Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling), Δs_coupled; - enumeration, - deriv = δ -> smoothed_delta(d, val, δ, coupling), + enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling), + deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling), out_rep = val, - perturbation_map_kwargs...) + Δ_kwargs...) + + Δs = combine((Δs2, Δs1); rep = Δs1, out_rep = val, Δ_kwargs...) - StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs_rep)) + StochasticTriple{T}(val, zero(val), Δs) end ## Handling finite perturbation to Binomial number of trials diff --git a/src/finite_infinitesimals.jl b/src/finite_infinitesimals.jl index 1c053433..ebb5410a 100644 --- a/src/finite_infinitesimals.jl +++ b/src/finite_infinitesimals.jl @@ -65,3 +65,13 @@ function get_any(Δs_all) end abstract type AbstractPerturbationStrategy end + +abstract type AbstractPerturbationSignal end + +function send_signal end + +# Ignore signals by default since they do not change semantics. +function StochasticAD.send_signal( + Δs::StochasticAD.AbstractFIs, ::StochasticAD.AbstractPerturbationSignal) + return Δs +end diff --git a/src/prelude.jl b/src/prelude.jl index 91df8db9..3fbb6fa5 100644 --- a/src/prelude.jl +++ b/src/prelude.jl @@ -10,7 +10,7 @@ const BINARY_PREDICATES = [ ==, !=, <=, - >=, + >= ] const UNARY_TYPEFUNCS_NOWRAP = [Base.rtoldefault] @@ -20,7 +20,7 @@ const UNARY_TYPEFUNCS_WRAP = [ floatmin, floatmax, zero, - one, + one ] const RNG_TYPEFUNCS_WRAP = [rand, randn, randexp] diff --git a/src/smoothing.jl b/src/smoothing.jl index dd1db1fc..26cb318e 100644 --- a/src/smoothing.jl +++ b/src/smoothing.jl @@ -30,9 +30,9 @@ end # Smoothed rules for univariate single-parameter distributions. -function smoothed_delta(d, val, δ, coupling) +function smoothed_delta(d, val, δ, derivative_coupling) Δs_empty = SmoothedFIs{typeof(val)}(0.0) - return derivative_contribution(δtoΔs(d, val, δ, Δs_empty, coupling)) + return derivative_contribution(δtoΔs(d, val, δ, Δs_empty, derivative_coupling)) end for (dist, i, field) in [ @@ -40,7 +40,7 @@ for (dist, i, field) in [ (:Bernoulli, :1, :p), (:Binomial, :2, :p), (:Poisson, :1, :λ), - (:Categorical, :1, :p), + (:Categorical, :1, :p) ] # i = index of parameter p # dual overloading @eval function Base.rand(rng::AbstractRNG, @@ -49,7 +49,7 @@ for (dist, i, field) in [ end @eval function randst(rng::AbstractRNG, d_dual::$dist{<:ForwardDiff.Dual{T}}; - coupling = InversionMethodDerivativeCoupling()) where {T} + derivative_coupling = InversionMethodDerivativeCoupling()) where {T} dual = params(d_dual)[$i] # dual could represent an array of duals or a single one; map handles both cases. p = map(value, dual) @@ -59,7 +59,8 @@ for (dist, i, field) in [ d = $dist(params(d_dual)[1:($i - 1)]..., p, params(d_dual)[($i + 1):end]...) val = convert(Signed, rand(rng, d)) - partials = ForwardDiff.Partials(map(δ -> smoothed_delta(d, val, δ, coupling), δs)) + partials = ForwardDiff.Partials(map( + δ -> smoothed_delta(d, val, δ, derivative_coupling), δs)) ForwardDiff.Dual{T}(val, partials) end # frule @@ -68,9 +69,9 @@ for (dist, i, field) in [ return frule(Δargs, randst, rng, d) end @eval function ChainRulesCore.frule((_, _, Δd), ::typeof(randst), rng::AbstractRNG, - d::$dist; coupling = InversionMethodDerivativeCoupling()) + d::$dist; derivative_coupling = InversionMethodDerivativeCoupling()) val = convert(Signed, rand(rng, d)) - Δval = smoothed_delta(d, val, Δd, coupling) + Δval = smoothed_delta(d, val, Δd, derivative_coupling) return (val, Δval) end # rrule @@ -80,18 +81,18 @@ for (dist, i, field) in [ @eval function ChainRulesCore.rrule(::typeof(randst), rng::AbstractRNG, d::$dist; - coupling = InversionMethodDerivativeCoupling()) + derivative_coupling = InversionMethodDerivativeCoupling()) val = convert(Signed, rand(rng, d)) function rand_pullback(∇out) p = params(d)[$i] if p isa Real - Δp = smoothed_delta(d, val, one(val), coupling) + Δp = smoothed_delta(d, val, one(val), derivative_coupling) else # TODO: this rule is O(length(p)^2), whereas we should be able to do O(length(p)) by reversing through δtoΔs. I = eachindex(p) V = eltype(p) onehot(i) = map(j -> j == i ? one(V) : zero(V), I) - Δp = map(i -> smoothed_delta(d, val, onehot(i), coupling), I) + Δp = map(i -> smoothed_delta(d, val, onehot(i), derivative_coupling), I) end # rrule_via_ad approach below not used because slow. # Δp = rrule_via_ad(config, smoothed_delta, d, val, map(one, p))[2](∇out)[4] diff --git a/src/stochastic_triple.jl b/src/stochastic_triple.jl index faa17874..66b68cbc 100644 --- a/src/stochastic_triple.jl +++ b/src/stochastic_triple.jl @@ -59,6 +59,21 @@ Return the finite perturbation(s) of `st`, in a format dependent on the [backend perturbations(x::Real) = () perturbations(st::StochasticTriple) = perturbations(st.Δs) +""" + send_signal(st::StochasticTriple, signal::AbstractPerturbationSignal) + send_signal(Δs::StochasticAD.AbstractFIs, signal::AbstractPerturbationSignal) + +Send a certain signal to a stochastic triple's perturbation collection `st.Δs` (or to a `Δs` directly), +which the backend may process as it wishes. Semantically, unbiasedness should not be affected by the +sending of the signal. The new version of the first argument (`st` or `Δs`) after signal processing is +returned. +""" +send_signal(st::Real, ::AbstractPerturbationSignal) = st +function send_signal(st::StochasticTriple{T}, signal::AbstractPerturbationSignal) where {T} + new_Δs = send_signal(st.Δs, signal) + return StochasticTriple{T}(st.value, st.δ, new_Δs) +end + """ derivative_contribution(st::StochasticTriple) @@ -117,7 +132,7 @@ end function StochasticTriple{T}(value::A, δ::B, Δs::FIs) where {T, A, B, C, FIs <: AbstractFIs{C}} V = promote_type(A, B, C) - StochasticTriple{T}(convert(V, value), convert(V, δ), similar_type(FIs, V)(Δs)) + StochasticTriple{T}(convert(V, value), convert(V, δ), convert(similar_type(FIs, V), Δs)) end ### Conversion rules @@ -127,7 +142,7 @@ end function Base.convert(::Type{StochasticTriple{T1, V, FIs}}, x::StochasticTriple{T2}) where {T1, T2, V, FIs} (T1 !== T2) && throw(ArgumentError("Tags of combined stochastic triples do not match.")) - StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), FIs(x.Δs)) + StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), convert(FIs, x.Δs)) end # TODO: ForwardDiff's promotion rules are a little more complicated, see https://github.com/JuliaDiff/ForwardDiff.jl/issues/322 @@ -187,8 +202,7 @@ end struct Tag{F, V} end -function stochastic_triple_direction(f, p::V, direction; - backend = PrunedFIsBackend()) where {V} +function stochastic_triple_direction(f, p::V, direction; backend) where {V} Δs = create_Δs(backend, Int) # TODO: necessity of hardcoding some type here suggests interface improvements sts = structural_map(p, direction) do p_i, direction_i StochasticTriple{Tag{typeof(f), V}}(p_i, direction_i, @@ -224,9 +238,10 @@ StochasticTriple of Int64: 0 + 0ε + (1 with probability 2.0ε) ``` """ -function stochastic_triple(f, p; direction = nothing, kwargs...) +function stochastic_triple( + f, p; direction = nothing, backend::AbstractFIsBackend = PrunedFIsBackend()) if direction !== nothing - return stochastic_triple_direction(f, p, direction; kwargs...) + return stochastic_triple_direction(f, p, direction; backend) end counter = begin c = 0 @@ -240,7 +255,7 @@ function stochastic_triple(f, p; direction = nothing, kwargs...) direction = structural_map(indices, p) do i, p_i i == perturbed_index ? one(p_i) : zero(p_i) end - stochastic_triple_direction(f, p, direction; kwargs...) + stochastic_triple_direction(f, p, direction; backend) end return structural_map(map_func, indices) end diff --git a/test/resampling.jl b/test/resampling.jl index 5a871135..389820d2 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -16,7 +16,7 @@ Random.seed!(seed) T = 3 d = 2 A(θ, a = 0.01) = [exp(-a)*cos(θ[]) exp(-a)*sin(θ[]) - -exp(-a)*sin(θ[]) exp(-a)*cos(θ[])] + -exp(-a)*sin(θ[]) exp(-a)*cos(θ[])] obs(x, θ) = MvNormal(x, 0.01 * collect(I(d))) dyn(x, θ) = MvNormal(A(θ) * x, 0.02 * collect(I(d))) x0 = [2.0, 0.0] # start value of the simulation diff --git a/test/triples.jl b/test/triples.jl index d2dfdcd3..79b8ac28 100644 --- a/test/triples.jl +++ b/test/triples.jl @@ -10,12 +10,12 @@ using Zygote const backends = [ PrunedFIsBackend(), PrunedFIsAggressiveBackend(), - DictFIsBackend(), + DictFIsBackend() ] const backends_smoothed = [ SmoothedFIsBackend(), - StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy()), + StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy()) ] @testset "Distributions w.r.t. continuous parameter" begin @@ -36,7 +36,7 @@ const backends_smoothed = [ (p -> Categorical([0, p^2, 0, 0, 1 - p^2])), # check that 0's are skipped over (p -> Categorical([1.0, exp(p)] ./ (1.0 + exp(p)))), # test fix for #38 (floating point comparisons in Categorical logic) (p -> Binomial(3, p)), - (p -> Binomial(20, p)), + (p -> Binomial(20, p)) ] p_ranges = [(0.2, 0.8) for _ in 1:8] out_ranges = [0:1, 0:MAX, 0:MAX, 1:2, 1:5, 1:2, 0:3, 0:20] @@ -68,7 +68,8 @@ const backends_smoothed = [ if backend == :smoothing_autodiff batched_full_func(p) = mean([full_func(p) for i in 1:nsamples]) # The array input used for ForwardDiff below is a trick to test multiple partials - triple_deriv_forward = mean(ForwardDiff.gradient(arr -> batched_full_func(sum(arr)), + triple_deriv_forward = mean(ForwardDiff.gradient( + arr -> batched_full_func(sum(arr)), [2 * p, -p])) triple_deriv_backward = Zygote.gradient(batched_full_func, p)[1] @test isapprox(triple_deriv_forward, exact_deriv, rtol = rtol) @@ -143,7 +144,7 @@ end end array_index_mean(p) = sum([p / 2, p / 2, (1 - p)] .* arr) triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend) - for i in 1:50000) + for i in 1:50000) exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p) @test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2) # Don't run subsequent tests with smoothing backend @@ -156,7 +157,7 @@ end end array_index2_mean(p) = sum([p / 2 * p, p / 2 * p, (1 - p) * p] .* arr) triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend) - for i in 1:50000) + for i in 1:50000) exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p) @test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2) # Test case where triple and alternate array value are coupled @@ -167,7 +168,7 @@ end end array_index3_mean(p) = -5 * (1 - p) + 1 * p triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend) - for i in 1:50000) + for i in 1:50000) exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p) @test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2) end @@ -181,7 +182,8 @@ end stochastic_ad_grad = derivative_estimate(f, x; backend) stochastic_ad_grad2 = derivative_contribution.(stochastic_triple(f, x; backend)) - stochastic_ad_grad_firsttwo = derivative_estimate(f, x; direction = [1.0, 1.0, 0.0], + stochastic_ad_grad_firsttwo = derivative_estimate( + f, x; direction = [1.0, 1.0, 0.0], backend) fd_grad = ForwardDiff.gradient(f, x) @test stochastic_ad_grad ≈ fd_grad @@ -362,13 +364,15 @@ end @test w2_scaled ≈ 2.0 * w2 # Test map_Δs with filter state if !is_smoothed_backend - Δs1_plus_Δs0 = StochasticAD.map_Δs((Δ, state) -> Δ + - StochasticAD.filter_state(Δs0, + Δs1_plus_Δs0 = StochasticAD.map_Δs( + (Δ, state) -> Δ + + StochasticAD.filter_state(Δs0, state), Δs1) @test derivative_contribution(Δs1_plus_Δs0) ≈ Δ * 3.0 - Δs1_plus_mapped = StochasticAD.map_Δs((Δ, state) -> Δ + - StochasticAD.filter_state(Δs1, + Δs1_plus_mapped = StochasticAD.map_Δs( + (Δ, state) -> Δ + + StochasticAD.filter_state(Δs1, state), Δs1_map) @test derivative_contribution(Δs1_plus_mapped) ≈ Δ * 3.0 + Δ^2 * 3.0 @@ -435,9 +439,11 @@ end end # For a simple sum, this should be equivalent to the combine behaviour. if check_combine && !is_smoothed_backend - @test isapprox(mean(derivative_contribution(get_Δs_coupled(; - do_combine = true)) - for i in 1:1000), expected_contribution; + @test isapprox( + mean(derivative_contribution(get_Δs_coupled(; + do_combine = true)) + for i in 1:1000), + expected_contribution; rtol = 5e-2) end # Check scalarize @@ -480,7 +486,8 @@ end NB: since the implementation of perturbations can be backend-specific, the below property need not hold in general, but does for the current non-smoothed backends. =# - @test collect(perturbations(st)) == [(1, 2.0)] + p = only(perturbations(st)) + @test p.Δ == 1 && p.weight == 2.0 @test derivative_contribution(st) == 3.0 else # Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution. @@ -529,7 +536,8 @@ end @test x_st isa StochasticAD.StochasticTriple{0, typeof(x)} @test StochasticAD.value(x_st) == x @test StochasticAD.delta(x_st) ≈ δ - @test collect(perturbations(x_st)) == [(Δ, 1.0)] + p = only(perturbations(x_st)) + @test p.Δ == Δ && p.weight == 1.0 end end diff --git a/tutorials/particle_filter/benchmark.jl b/tutorials/particle_filter/benchmark.jl index 5d922254..57058d32 100644 --- a/tutorials/particle_filter/benchmark.jl +++ b/tutorials/particle_filter/benchmark.jl @@ -13,7 +13,8 @@ secs = 10 suite = BenchmarkGroup() suite["scaling"] = BenchmarkGroup(["grads"]) -suite["scaling"]["primal"] = @benchmarkable ParticleFilterCore.log_likelihood(particle_filter, +suite["scaling"]["primal"] = @benchmarkable ParticleFilterCore.log_likelihood( + particle_filter, θtrue) suite["scaling"]["forward"] = @benchmarkable ParticleFilterCore.forw_grad(θtrue, particle_filter) diff --git a/tutorials/particle_filter/model.jl b/tutorials/particle_filter/model.jl index 9184938f..1d1ff36f 100644 --- a/tutorials/particle_filter/model.jl +++ b/tutorials/particle_filter/model.jl @@ -53,7 +53,8 @@ function generate_system(d, T) ### initialize filters m = 1000 # number of particles - kalman_filter = ParticleFilterCore.KalmanFilter(d, stochastic_model, H_Kalman, R_Kalman, + kalman_filter = ParticleFilterCore.KalmanFilter( + d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys) particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys, ParticleFilterCore.sample_stratified) diff --git a/tutorials/random_walk/compare_score.jl b/tutorials/random_walk/compare_score.jl index c5bd9904..48d5e69a 100644 --- a/tutorials/random_walk/compare_score.jl +++ b/tutorials/random_walk/compare_score.jl @@ -12,17 +12,18 @@ begin stds_score_baseline = Float64[] @showprogress for (n, p) in zip(RandomWalkCore.n_range, RandomWalkCore.p_range) std_triple = std(derivative_estimate(p -> RandomWalkCore.fX(p, n), p) - for i in 1:(RandomWalkCore.nsamples)) - std_smoothed = std(derivative(p -> RandomWalkCore.fX(p, - n; - hardcode_leftright_step = true), - p) - for i in 1:(RandomWalkCore.nsamples)) + for i in 1:(RandomWalkCore.nsamples)) + std_smoothed = std(derivative( + p -> RandomWalkCore.fX(p, + n; + hardcode_leftright_step = true), + p) + for i in 1:(RandomWalkCore.nsamples)) std_score = std(RandomWalkCore.score_fX_deriv(p, n, 0.0) - for i in 1:(RandomWalkCore.nsamples)) + for i in 1:(RandomWalkCore.nsamples)) avg = mean(RandomWalkCore.fX(p, n) for i in 1:10000) std_score_baseline = std(RandomWalkCore.score_fX_deriv(p, n, avg) - for i in 1:(RandomWalkCore.nsamples)) + for i in 1:(RandomWalkCore.nsamples)) push!(stds_triple, std_triple) push!(stds_score, std_score) push!(stds_score_baseline, std_score_baseline)