Miscellaneous interface improvements #119

Merged on Mar 7, 2024 (32 commits)

Commits
06f56df
Changing perturbation format
gaurav-arya Feb 10, 2024
9ee33da
Comment edits
gaurav-arya Feb 11, 2024
b39eaa6
Add perturbation signal stub
gaurav-arya Feb 11, 2024
276184c
Include perturbation_signal in wrapper type
gaurav-arya Feb 11, 2024
0a09768
Rename to send_signal, add forwarding from stochastic triple dispatch
gaurav-arya Feb 11, 2024
87db93a
Wrap output as stochastic triple when signal sent to st
gaurav-arya Feb 11, 2024
94d624e
Export StochasticTriple
gaurav-arya Feb 15, 2024
fda6bf3
Fix send_signal in wrapper
gaurav-arya Feb 15, 2024
f0587a9
Ignore signal by default
gaurav-arya Feb 15, 2024
76c0176
Rename coupling -> derivative_coupling
gaurav-arya Feb 19, 2024
0b627d7
Use Base.convert for backends
gaurav-arya Feb 22, 2024
152cc4c
Interface nonfunctional kwarg tweaks
gaurav-arya Feb 22, 2024
c5ef43f
A big commit
gaurav-arya Feb 24, 2024
03abaa0
Slightly improve file organization
gaurav-arya Feb 28, 2024
981deb2
Allow specifying a custom propagation coupling for randst
gaurav-arya Feb 28, 2024
a1ab4ca
Force 32-bit tags
gaurav-arya Mar 6, 2024
8034a00
Add mode argument to inversion method derivative coupling
gaurav-arya Mar 6, 2024
3e7405d
Remove commented code
gaurav-arya Mar 6, 2024
f501340
Use Base.@kwdef
gaurav-arya Mar 6, 2024
972e6e9
Run formatter
gaurav-arya Mar 6, 2024
be3a249
Add formatting commit to git blame ignore
gaurav-arya Mar 6, 2024
4b600ae
Add kwargs... to some funcs
gaurav-arya Mar 6, 2024
9e6e103
Also add kwargs... for dict backend funcs
gaurav-arya Mar 6, 2024
1b3b490
Fix missing use of convert
gaurav-arya Mar 6, 2024
2094528
Add draft doc test
gaurav-arya Mar 6, 2024
f312d6c
Add a seed to doctest
gaurav-arya Mar 7, 2024
d7f35c9
Add InversionMethodDerivativeCoupling to docs
gaurav-arya Mar 7, 2024
442d698
Add missing doctest import
gaurav-arya Mar 7, 2024
f7e5937
Fix typo
gaurav-arya Mar 7, 2024
caa6d5b
temp hiding local var logic
gaurav-arya Mar 7, 2024
dcedf4e
Allow disabling handling of zero probs
gaurav-arya Mar 7, 2024
b8603bc
Fix doc example
gaurav-arya Mar 7, 2024
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
@@ -1,6 +1,8 @@
# Run this command to always ignore these in local `git blame`:
# git config blame.ignoreRevsFile .git-blame-ignore-revs

# Run formatter
70fd432667fb431e08ba52728734108d822a1922
# Run formatter
21038a047c023330876feb9259cd5c92add3ca81
# Run formatter after bracket alignment removal
3 changes: 2 additions & 1 deletion benchmark/iteration.jl
@@ -43,7 +43,8 @@ for (setname, set) in (("tups", tups), ("SAs", SAs))
# seems to lead to slow benchmarks.
FIs_suite["couple_same"] = @benchmarkable StochasticAD.couple(typeof($Δs),
$Δs_all)
-FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine(typeof($Δs),
+FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine(
+typeof($Δs),
$Δs_all)
end
end
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -18,11 +18,11 @@ pages = [
"tutorials/random_walk.md",
"tutorials/game_of_life.md",
"tutorials/particle_filter.md",
-"tutorials/optimizations.md",
+"tutorials/optimizations.md"
],
"Public API" => "public_api.md",
"Developer documentation" => "devdocs.md",
-"Limitations" => "limitations.md",
+"Limitations" => "limitations.md"
]

### Make docs
3 changes: 2 additions & 1 deletion docs/src/devdocs.md
@@ -99,5 +99,6 @@ nothing # hide
## Distribution-specific customization of differentiation algorithm

```@docs
-StochasticAD.randst
+randst
+InversionMethodDerivativeCoupling
```
8 changes: 5 additions & 3 deletions src/StochasticAD.jl
@@ -2,14 +2,16 @@ module StochasticAD

### Public API

-export stochastic_triple, derivative_contribution, perturbations, smooth_triple, dual_number # For working with stochastic triples
+export stochastic_triple, derivative_contribution, perturbations, smooth_triple,
+dual_number, StochasticTriple # For working with stochastic triples
export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality
export new_weight # Particle resampling
export PrunedFIsBackend,
-PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,
-StrategyWrapperFIsBackend
+PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,
+StrategyWrapperFIsBackend
export PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs
export randst
+export InversionMethodDerivativeCoupling

### Imports

23 changes: 17 additions & 6 deletions src/backends/abstract_wrapper.jl
@@ -39,18 +39,22 @@ StochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs)

function StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
Δs_all;
+rep = nothing,
kwargs...) where {V, FIs}
_Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
+_rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
-StochasticAD.couple(FIs, _Δs_all; kwargs...))
+StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end

function StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
Δs_all;
+rep = nothing,
kwargs...) where {V, FIs}
_Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
+_rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
-StochasticAD.combine(FIs, _Δs_all; kwargs...))
+StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end

function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
@@ -61,8 +65,10 @@ function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
StochasticAD.get_rep(FIs, _Δs_all; kwargs...))
end

-function StochasticAD.scalarize(Δs::AbstractWrapperFIs; kwargs...)
-return StochasticAD.structural_map(StochasticAD.scalarize(Δs.Δs; kwargs...)) do _Δs
+function StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...)
+_rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
+return StochasticAD.structural_map(StochasticAD.scalarize(
+Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs
reconstruct_wrapper(Δs, _Δs)
end
end
@@ -98,8 +104,13 @@ function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs)
StochasticAD.derivative_contribution(Δs.Δs)
end

-function (::Type{<:AbstractWrapperFIs{V}})(Δs::AbstractWrapperFIs) where {V}
-reconstruct_wrapper(Δs, StochasticAD.similar_type(typeof(Δs.Δs), V)(Δs.Δs))
+function Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V}
+reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs))
end
+
+function StochasticAD.send_signal(
+Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal)
+reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal))
+end

function Base.show(io::IO, Δs::AbstractWrapperFIs)
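Note on the `_rep_kwarg` lines in this file: they appear to forward `rep` to the wrapped backend only when the caller supplied one, so that the inner backend's own default for `rep` still applies otherwise. A minimal, self-contained sketch of that forwarding pattern follows. `Wrapped`, `inner_couple`, and `outer_couple` are hypothetical stand-ins for illustration, not StochasticAD functions.

```julia
# Hypothetical wrapper holding an inner value, standing in for a wrapper FIs type.
struct Wrapped
    inner::String
end

# Stand-in for an inner backend method that has its own default for `rep`.
inner_couple(x; rep = "inner default", kwargs...) = (x, rep)

# Stand-in for the wrapper method: unwrap `rep` and pass it on only if it was given.
function outer_couple(x; rep = nothing, kwargs...)
    # Either a one-entry NamedTuple or an empty one; splatting the empty one adds nothing.
    rep_kwarg = !isnothing(rep) ? (; rep = rep.inner) : (;)
    return inner_couple(x; rep_kwarg..., kwargs...)
end

without_rep = outer_couple(1)                           # inner default is used
with_rep = outer_couple(1; rep = Wrapped("unwrapped"))  # unwrapped rep is forwarded
```

Passing `rep = nothing` straight through would instead override the inner default, which is presumably what the conditional NamedTuple avoids.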
22 changes: 14 additions & 8 deletions src/backends/dict.jl
@@ -75,7 +75,7 @@ StochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState())

### Convert type of a backend

-function DictFIs{V}(Δs::DictFIs) where {V}
+function Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V}
DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state)
end

@@ -88,7 +88,9 @@ function StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V}
sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0)
end

-StochasticAD.perturbations(Δs::DictFIs) = [(Δ, event.w) for (event, Δ) in pairs(Δs.dict)]
+function StochasticAD.perturbations(Δs::DictFIs)
+[(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)]
+end

### Unary propagation

@@ -99,7 +101,7 @@ function StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...)
mapped_weights = last.(mapped_values_and_weights)
scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a),
keys(Δs.dict),
-mapped_weights)
+mapped_weights) # TODO: should original events (with old tag) also be modified?
dict = Dictionary(scaled_events, mapped_values)
DictFIs(dict, Δs.state)
end
@@ -119,20 +121,23 @@ end

function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all;
rep = StochasticAD.get_rep(FIs, Δs_all),
-out_rep = nothing)
+out_rep = nothing,
+kwargs...)
all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs
keys(Δs.dict)
end
distinct_keys = unique(all_keys |> Iterators.flatten)
-Δs_coupled_dict = [StochasticAD.structural_map(Δs -> isassigned(Δs.dict, key) ?
-Δs.dict[key] :
-zero(eltype(Δs.dict)), Δs_all)
+Δs_coupled_dict = [StochasticAD.structural_map(
+Δs -> isassigned(Δs.dict, key) ?
+Δs.dict[key] :
+zero(eltype(Δs.dict)),
+Δs_all)
for key in distinct_keys]
DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state)
end

function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;
-rep = StochasticAD.get_rep(FIs, Δs_all))
+rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all))
Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2
mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2)
@@ -141,6 +146,7 @@ function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;
end

function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing)
+# TODO: use vcat here?
tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2)
Δ_all_allkeys = foldl(tupleify, values(Δs.dict))
Δ_all_rep = first(values(Δs.dict))
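Note on the new `perturbations` format: each entry changes from a `(Δ, weight)` tuple to a named tuple with fields `Δ`, `weight`, and `state`. The sketch below shows how such records can be built and consumed by field name. It uses a plain `Dict` and a hypothetical `Event` struct rather than the package's `Dictionary` and `InfinitesimalEvent`, and assumes a Julia version that supports the `(; Δ, ...)` shorthand.

```julia
# Hypothetical event type carrying a weight, loosely modeled on `event.w` in the diff.
struct Event
    tag::Int
    w::Float64
end

# Map from event to perturbation size, standing in for `Δs.dict`.
events = Dict(Event(1, 0.5) => 2.0, Event(2, 0.25) => 4.0)

# Named-tuple records with fields Δ, weight, and state.
perts = [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(events)]

# Fields are read by name rather than by position.
total = sum(p.Δ * p.weight for p in perts)
```

Here `total` plays the role of a derivative contribution: the weighted sum of the perturbation sizes.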
34 changes: 24 additions & 10 deletions src/backends/pruned.jl
@@ -17,11 +17,21 @@ struct PrunedFIsBackend <: StochasticAD.AbstractFIsBackend end
State maintained by pruning backend.
"""
mutable struct PrunedFIsState
+# tag is in place to avoid relying on mutability for uniqueness.
+tag::Int32
weight::Float64
valid::Bool
-PrunedFIsState(valid = true) = new(0.0, valid)
+function PrunedFIsState(valid = true)
+state = new(0, 0.0, valid)
+state.tag = objectid(state) % typemax(Int32)
+return state
+end
end

+Base.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag
+# c.f. https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51
+Base.hash(state::PrunedFIsState) = state.tag
+
"""
PrunedFIs{V} <: StochasticAD.AbstractFIs{V}

@@ -59,7 +69,7 @@ StochasticAD.create_Δs(::PrunedFIsBackend, V) = PrunedFIs{V}(PrunedFIsState(fal

### Convert type of a backend

-function PrunedFIs{V}(Δs::PrunedFIs) where {V}
+function Base.convert(::Type{PrunedFIs{V}}, Δs::PrunedFIs) where {V}
PrunedFIs{V}(convert(V, Δs.Δ), Δs.state)
end

@@ -79,7 +89,11 @@ pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ

StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight
-StochasticAD.perturbations(Δs::PrunedFIs) = ((pruned_value(Δs), Δs.state.weight),)
+function StochasticAD.perturbations(Δs::PrunedFIs)
+return ((; Δ = pruned_value(Δs),
+weight = Δs.state.valid ? Δs.state.weight : zero(Δs.state.weight),
+state = Δs.state),)
+end

### Unary propagation

@@ -94,9 +108,8 @@ StochasticAD.alltrue(f, Δs::PrunedFIs) = f(Δs.Δ)

### Coupling

-function StochasticAD.get_rep(::Type{<:PrunedFIs}, Δs_all)
-# The code below is a bit ridiculous, but it's faster than `first` for small structures:)
-return StochasticAD.get_any(Δs_all)
+function StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all)
+return empty(FIs)
end

function get_pruned_state(Δs_all)
@@ -105,7 +118,7 @@ function get_pruned_state(Δs_all)
isapproxzero(Δs) && return (total_weight, new_state)
candidate_state = Δs.state
if !candidate_state.valid ||
-((new_state !== nothing) && (candidate_state === new_state))
+((new_state !== nothing) && (candidate_state == new_state))
return (total_weight, new_state)
end
w = candidate_state.weight
@@ -131,14 +144,15 @@ end
# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
# rep is unused.
-function StochasticAD.couple(::Type{<:PrunedFIs}, Δs_all; rep = nothing, out_rep = nothing)
+function StochasticAD.couple(
+::Type{<:PrunedFIs}, Δs_all; rep = nothing, out_rep = nothing, kwargs...)
state = get_pruned_state(Δs_all)
Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
PrunedFIs(Δ_coupled, state)
end

# basically couple combined with a sum.
-function StochasticAD.combine(::Type{<:PrunedFIs}, Δs_all; rep = nothing)
+function StochasticAD.combine(::Type{<:PrunedFIs}, Δs_all; rep = nothing, kwargs...)
state = get_pruned_state(Δs_all)
Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
PrunedFIs(Δ_combined, state)
@@ -151,7 +165,7 @@ function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing)
end

function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}
-Δs.state === state ? pruned_value(Δs) : zero(V)
+Δs.state == state ? pruned_value(Δs) : zero(V)
end

### Miscellaneous
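Note on the `tag` field and the switch from `===` to `==` in this file: state comparison moves from object identity to a stored 32-bit tag, so a copied state still compares equal to its original. The sketch below illustrates that idea with a hypothetical `TaggedState` type. It is an assumption-laden illustration, not the package's code, and it defines the two-argument `hash` method where the diff defines a one-argument one.

```julia
# Hypothetical state type mirroring the shape of `PrunedFIsState` in the diff.
mutable struct TaggedState
    tag::Int32
    weight::Float64
    valid::Bool
    function TaggedState(valid = true)
        state = new(0, 0.0, valid)
        # Derive a 32-bit tag from the object's id at construction time.
        state.tag = objectid(state) % typemax(Int32)
        return state
    end
end

# Equality and hashing go through the tag rather than object identity.
Base.:(==)(a::TaggedState, b::TaggedState) = a.tag == b.tag
# Assumption: two-argument `hash`, unlike the one-argument method in the diff.
Base.hash(s::TaggedState, h::UInt) = hash(s.tag, h)

s = TaggedState()
c = deepcopy(s)   # a distinct object that carries the same tag
c.weight = 1.5    # mutating a field leaves the tag unchanged

lookup = Dict(s => "found")
```

With these definitions `c == s` holds although `c === s` does not, and `lookup[c]` finds the entry stored under `s`. Tags derived this way are not guaranteed to be unique across distinct objects.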
10 changes: 6 additions & 4 deletions src/backends/pruned_aggressive.jl
@@ -78,7 +78,7 @@ end

### Convert type of a backend

-function PrunedFIsAggressive{V}(Δs::PrunedFIsAggressive) where {V}
+function Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V}
PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state)
end

@@ -96,7 +96,9 @@ function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive)
pruned_value(Δs) * Δs.state.weight
end

-StochasticAD.perturbations(Δs::PrunedFIsAggressive) = ((pruned_value(Δs), Δs.state.weight),)
+function StochasticAD.perturbations(Δs::PrunedFIsAggressive)
+((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)
+end

### Unary propagation

@@ -120,15 +122,15 @@ end
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
rep = StochasticAD.get_rep(FIs, Δs_all),
-out_rep = nothing)
+out_rep = nothing, kwargs...)
state = rep.state
Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
PrunedFIsAggressive(Δ_coupled, state.active_tag, state)
end

# basically couple combined with a sum.
function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
-rep = StochasticAD.get_rep(FIs, Δs_all))
+rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
state = rep.state
Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
PrunedFIsAggressive(Δ_combined, state.active_tag, state)
10 changes: 5 additions & 5 deletions src/backends/smoothed.jl
@@ -46,10 +46,9 @@ StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0)

### Convert type of a backend

-function (::Type{<:SmoothedFIs{V}})(Δs::SmoothedFIs) where {V}
-SmoothedFIs{V}(Δs.δ)
+function Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V}
+SmoothedFIs{V}(Δs.δ)::FIs
end
-(::Type{SmoothedFIs{V}})(Δs::SmoothedFIs) where {V} = SmoothedFIs{V}(Δs.δ)

### Getting information about perturbations

@@ -70,11 +69,12 @@ StochasticAD.alltrue(f, Δs::SmoothedFIs) = true

StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all)

-function StochasticAD.couple(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep)
+function StochasticAD.couple(
+::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...)
SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all))
end

-function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing)
+function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...)
V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all)))
Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all))
SmoothedFIs{V_out}(Δ_combined)
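Note on the "Convert type of a backend" hunks across the backend files: each one replaces a constructor-style method such as `PrunedFIs{V}(Δs)` with a `Base.convert` method. A minimal sketch of the same idea on a hypothetical `Box{V}` container is given below; `Box` is an illustration only and is not part of StochasticAD.

```julia
# Hypothetical parametric container standing in for a backend type such as `PrunedFIs{V}`.
struct Box{V}
    value::V
end

# Conversion between value types is expressed as a `Base.convert` method
# rather than as an additional constructor method.
Base.convert(::Type{Box{V}}, b::Box) where {V} = Box{V}(convert(V, b.value))

explicit = convert(Box{Float64}, Box(1))

# With `convert` defined, typed containers also convert on insertion.
boxes = Box{Float64}[]
push!(boxes, Box(2))
```

One consequence of this design is that Julia's implicit conversion points, such as `push!` into a typed vector, pick up the method without any extra code.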
7 changes: 4 additions & 3 deletions src/backends/strategy_wrapper.jl
@@ -7,7 +7,7 @@ export StrategyWrapperFIsBackend, StrategyWrapperFIs

struct StrategyWrapperFIsBackend{
B <: StochasticAD.AbstractFIsBackend,
-S <: StochasticAD.AbstractPerturbationStrategy,
+S <: StochasticAD.AbstractPerturbationStrategy
} <:
StochasticAD.AbstractFIsBackend
backend::B
@@ -17,7 +17,7 @@ end
struct StrategyWrapperFIs{
V,
FIs <: StochasticAD.AbstractFIs{V},
-S <: StochasticAD.AbstractPerturbationStrategy,
+S <: StochasticAD.AbstractPerturbationStrategy
} <:
AbstractWrapperFIs{V, FIs}
Δs::FIs
@@ -38,7 +38,8 @@ function AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapp
return StrategyWrapperFIs(Δs, wrapper_Δs.strategy)
end

-function AbstractWrapperFIsModule.reconstruct_wrapper(::Type{
+function AbstractWrapperFIsModule.reconstruct_wrapper(
+::Type{
<:StrategyWrapperFIs{V, FIs, S},
},
Δs) where {V, FIs, S}