From 8dfc80005b7c238475bad2dd00330438268d486a Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 22:38:52 +0530 Subject: [PATCH 01/17] Simplify the workflow for computing model gradients --- .gitignore | 4 +- JuliaBUGS/Project.toml | 2 + JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl | 48 +++++++-- JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl | 41 ++++++- JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl | 16 ++- JuliaBUGS/src/JuliaBUGS.jl | 100 +++++++++++++++++- JuliaBUGS/src/model/Model.jl | 3 + JuliaBUGS/src/model/logdensityproblems.jl | 82 ++++++++++++++ .../test/BUGSPrimitives/distributions.jl | 3 +- JuliaBUGS/test/Project.toml | 2 + JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 62 +++++++++-- JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl | 10 +- JuliaBUGS/test/model/bugsmodel.jl | 52 +++++++++ JuliaBUGS/test/parallel_sampling.jl | 5 +- 14 files changed, 397 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 779fd4d05..e5c898ac2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,6 @@ Manifest.toml *.local.* # gitingest generated files -digest.txt \ No newline at end of file +digest.txt + +tmp/ \ No newline at end of file diff --git a/JuliaBUGS/Project.toml b/JuliaBUGS/Project.toml index 8077d3517..5cfec993d 100644 --- a/JuliaBUGS/Project.toml +++ b/JuliaBUGS/Project.toml @@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8" AdvancedMH = "0.8" BangBang = "0.4.1" Bijectors = "0.13, 0.14, 0.15.5" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" GLMakie = "0.10, 0.11, 0.12, 0.13" diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index 179ef02b5..bf052e3ed 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedHMCExt using AbstractMCMC using AdvancedHMC using ADTypes +import DifferentiationInterface as DI using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems using JuliaBUGS.LogDensityProblemsAD +using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -40,10 +42,13 @@ end function _gibbs_internal_hmc( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + # Create gradient model on-the-fly using DifferentiationInterface + x = getparams(cond_model) + prep = DI.prepare_gradient( + _logdensity_switched, ad_backend, x, DI.Constant(cond_model) ) + ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take HMC/NUTS step if isnothing(state) @@ -53,7 +58,7 @@ function _gibbs_internal_hmc( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else # Use existing state for subsequent steps @@ -67,7 +72,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, sampler::AdvancedHMC.AbstractHMCSampler, state, chain_type::Type{Chains}; @@ -98,4 +103,35 @@ function AbstractMCMC.bundle_samples( ) end +# Keep backward compatibility with LogDensityProblemsAD wrapper +function AbstractMCMC.bundle_samples( + ts::Vector{<:AdvancedHMC.Transition}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + sampler::AdvancedHMC.AbstractHMCSampler, + state, + chain_type::Type{Chains}; + discard_initial=0, + thinning=1, + kwargs..., +) + param_samples = [t.z.θ for t in ts] + + stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) + stats_values = [ + vcat(ts[i].z.ℓπ.value, collect(values(AdvancedHMC.stat(ts[i])))) for + i in eachindex(ts) + ] + + # Delegate to gen_chains for proper parameter naming from BUGSModel + return JuliaBUGS.gen_chains( + logdensitymodel, + param_samples, + stats_names, + stats_values; + discard_initial=discard_initial, + thinning=thinning, + kwargs..., + ) +end + end diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index ca30555be..1d07ade50 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -3,10 +3,12 @@ module JuliaBUGSAdvancedMHExt using AbstractMCMC using AdvancedMH using ADTypes +import DifferentiationInterface as DI using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems using JuliaBUGS.LogDensityProblemsAD +using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -52,10 +54,13 @@ end function _gibbs_internal_mh( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation for gradient-based proposals - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + # Create gradient model on-the-fly using DifferentiationInterface + x = getparams(cond_model) + prep = DI.prepare_gradient( + _logdensity_switched, ad_backend, x, DI.Constant(cond_model) ) + ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take MH step with gradient information if isnothing(state) @@ -64,7 +69,7 @@ function _gibbs_internal_mh( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else t, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0) @@ -103,6 +108,32 @@ function AbstractMCMC.bundle_samples( ) end +function AbstractMCMC.bundle_samples( + ts::Vector{<:AdvancedMH.Transition}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, + sampler::AdvancedMH.MHSampler, + state, + chain_type::Type{Chains}; + discard_initial=0, + thinning=1, + kwargs..., +) + param_samples = [t.params for t in ts] + stats_names = [:lp] + stats_values = [[t.lp] for t in ts] + + return JuliaBUGS.gen_chains( + logdensitymodel, + param_samples, + stats_names, + stats_values; + discard_initial=discard_initial, + thinning=thinning, + kwargs..., + ) +end + +# Keep backward compatibility with LogDensityProblemsAD wrapper function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedMH.Transition}, logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index eec864093..40d77e848 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -2,7 +2,7 @@ module JuliaBUGSMCMCChainsExt using AbstractMCMC using JuliaBUGS -using JuliaBUGS: BUGSModel, find_generated_quantities_variables, evaluate!!, getparams +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors using JuliaBUGS.LogDensityProblemsAD @@ -21,6 +21,20 @@ function JuliaBUGS.gen_chains( ) end +function JuliaBUGS.gen_chains( + model::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, + samples, + stats_names, + stats_values; + kwargs..., +) + # Extract BUGSModel from gradient wrapper + bugs_model = model.logdensity.base_model + + return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...) +end + +# Keep backward compatibility with LogDensityProblemsAD wrapper function JuliaBUGS.gen_chains( model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, samples, diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index e340d8470..3e780011c 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -6,6 +6,7 @@ using Accessors using ADTypes using BangBang using Bijectors: Bijectors +using DifferentiationInterface using Distributions using Graphs, MetaGraphsNext using LinearAlgebra @@ -17,6 +18,7 @@ using Serialization: Serialization using StaticArrays import Base: ==, hash, Symbol, size +import DifferentiationInterface as DI import Distributions: truncated export @bugs @@ -239,13 +241,48 @@ function validate_bugs_expression(expr, line_num) end """ - compile(model_def, data[, initial_params]; skip_validation=false) + compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing) Compile the model with model definition and data. Optionally, initializations can be provided. If initializations are not provided, values will be sampled from the prior distributions. By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro). Set `skip_validation=true` to skip validation (for @model macro usage). + +If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC +samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance. + +# Arguments +- `model_def::Expr`: Model definition from @bugs macro +- `data::NamedTuple`: Observed data +- `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional) +- `skip_validation::Bool=false`: Skip function validation (for @model macro) +- `eval_module::Module=@__MODULE__`: Module for evaluation +- `adtype`: AD backend specification. Can be: + - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest) + - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation + - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` + - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` + - `:Zygote` - Shorthand for `AutoZygote()` + - Any other `ADTypes.AbstractADType` + +# Examples +```julia +# Basic compilation +model = compile(model_def, data) + +# With gradient support using explicit ADType +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# With gradient support using symbol shorthand +model = compile(model_def, data; adtype=:ReverseDiff) # Same as above + +# Using ForwardDiff for small models +model = compile(model_def, data; adtype=:ForwardDiff) + +# Sample with NUTS +chain = AbstractMCMC.sample(model, NUTS(0.8), 1000) +``` """ function compile( model_def::Expr, @@ -253,6 +290,7 @@ function compile( initial_params::NamedTuple=NamedTuple(); skip_validation::Bool=false, eval_module::Module=@__MODULE__, + adtype::Union{Nothing,ADTypes.AbstractADType,Symbol}=nothing, ) # Validate functions by default (for @bugs macro usage) # Skip validation only for @model macro @@ -281,7 +319,65 @@ function compile( values(eval_env), ), ) - return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) + base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) + + # If adtype provided, wrap with gradient capabilities + if adtype !== nothing + # Convert symbol to ADType if needed + adtype_obj = _resolve_adtype(adtype) + return _wrap_with_gradient(base_model, adtype_obj) + end + + return base_model +end + +""" + _resolve_adtype(adtype) -> ADTypes.AbstractADType + +Convert symbol shortcuts to ADTypes, or return the ADType as-is. + +Supported symbol shortcuts: +- `:ReverseDiff` -> `AutoReverseDiff(compile=true)` +- `:ForwardDiff` -> `AutoForwardDiff()` +- `:Zygote` -> `AutoZygote()` +- `:Enzyme` -> `AutoEnzyme()` +""" +function _resolve_adtype(adtype::Symbol) + if adtype === :ReverseDiff + return ADTypes.AutoReverseDiff(compile=true) + elseif adtype === :ForwardDiff + return ADTypes.AutoForwardDiff() + elseif adtype === :Zygote + return ADTypes.AutoZygote() + elseif adtype === :Enzyme + return ADTypes.AutoEnzyme() + else + error("Unknown AD backend symbol: $adtype. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * + "Or use an ADTypes object like AutoReverseDiff(compile=true).") + end +end + +# Pass through ADTypes objects unchanged +_resolve_adtype(adtype::ADTypes.AbstractADType) = adtype + +# Helper function to prepare gradient - separated to handle world age issues +function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) + # Get initial parameters for preparation + # Use invokelatest to handle world age issues with generated functions + x = Base.invokelatest(getparams, base_model) + + # Prepare gradient using DifferentiationInterface + # Use invokelatest to handle world age issues when calling logdensity during preparation + prep = Base.invokelatest( + DI.prepare_gradient, + Model._logdensity_switched, + adtype, + x, + DI.Constant(base_model) + ) + + return Model.BUGSModelWithGradient(adtype, prep, base_model) end # function compile( # model_str::String, diff --git a/JuliaBUGS/src/model/Model.jl b/JuliaBUGS/src/model/Model.jl index 37ca24aa8..2efe7adb1 100644 --- a/JuliaBUGS/src/model/Model.jl +++ b/JuliaBUGS/src/model/Model.jl @@ -2,8 +2,10 @@ module Model using Accessors using AbstractPPL +using ADTypes using BangBang using Bijectors +import DifferentiationInterface as DI using Distributions using Graphs using LinearAlgebra @@ -21,5 +23,6 @@ include("logdensityproblems.jl") export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode export regenerate_log_density_function, set_observed_values! export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!! +export BUGSModelWithGradient, _logdensity_switched end # Model diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 07d82b018..1b0381cae 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -24,3 +24,85 @@ end function LogDensityProblems.capabilities(::AbstractBUGSModel) return LogDensityProblems.LogDensityOrder{0}() end + +""" + BUGSModelWithGradient{B,P,M} + +Wraps a BUGSModel with automatic differentiation capabilities using DifferentiationInterface. +Implements both `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. + +# Fields +- `backend::B`: ADTypes backend (e.g., AutoReverseDiff()) +- `prep::P`: Prepared gradient from DifferentiationInterface (can be nothing) +- `base_model::M`: The underlying BUGSModel + +# Example +```julia +model_def = @bugs begin + x ~ dnorm(0, 1) +end +data = NamedTuple() + +# Create model with gradient capabilities +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# Use with gradient-based MCMC +chain = AbstractMCMC.sample(rng, model, NUTS(0.8), 1000) +``` +""" +struct BUGSModelWithGradient{B<:ADTypes.AbstractADType,P,M<:BUGSModel} + backend::B + prep::P + base_model::M +end + +# Forward base BUGSModel interface +function LogDensityProblems.logdensity(model::BUGSModelWithGradient, x::AbstractVector) + return LogDensityProblems.logdensity(model.base_model, x) +end + +function LogDensityProblems.dimension(model::BUGSModelWithGradient) + return LogDensityProblems.dimension(model.base_model) +end + +function LogDensityProblems.capabilities(::Type{<:BUGSModelWithGradient}) + return LogDensityProblems.LogDensityOrder{1}() # Gradient available +end + +""" + _logdensity_switched(x, base_model_constant) + +Helper function that switches argument order for DifferentiationInterface compatibility. +DI expects the active argument (to differentiate w.r.t.) to come first. +""" +function _logdensity_switched(x::AbstractVector, base_model_constant::DI.Constant) + base_model = DI.unwrap(base_model_constant) + return LogDensityProblems.logdensity(base_model, x) +end + +# Fallback for testing during preparation (when DI calls without Constant wrapper) +function _logdensity_switched(x::AbstractVector, base_model::BUGSModel) + return LogDensityProblems.logdensity(base_model, x) +end + +""" + LogDensityProblems.logdensity_and_gradient(model::BUGSModelWithGradient, x) + +Compute log density and its gradient using DifferentiationInterface. +Uses prepared gradient if available, otherwise falls back to unprepared computation. +""" +function LogDensityProblems.logdensity_and_gradient( + model::BUGSModelWithGradient, x::AbstractVector +) + # Active argument (x) comes first for DI + # Base model passed as Constant context + if model.prep === nothing + return DI.value_and_gradient( + _logdensity_switched, model.backend, x, DI.Constant(model.base_model) + ) + else + return DI.value_and_gradient( + _logdensity_switched, model.prep, model.backend, x, DI.Constant(model.base_model) + ) + end +end diff --git a/JuliaBUGS/test/BUGSPrimitives/distributions.jl b/JuliaBUGS/test/BUGSPrimitives/distributions.jl index 69505e2f7..0262b051d 100644 --- a/JuliaBUGS/test/BUGSPrimitives/distributions.jl +++ b/JuliaBUGS/test/BUGSPrimitives/distributions.jl @@ -15,9 +15,8 @@ end A[1:2, 1:2] ~ dwish(B[:, :], 2) C[1:2] ~ dmnorm(mu[:], A[:, :]) end - model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],)) + ad_model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff()) - ad_model = ADgradient(:ReverseDiff, model) theta = [ 0.7931743744870574, 0.5151017206811268, diff --git a/JuliaBUGS/test/Project.toml b/JuliaBUGS/test/Project.toml index 9c02130ee..056862146 100644 --- a/JuliaBUGS/test/Project.toml +++ b/JuliaBUGS/test/Project.toml @@ -7,6 +7,7 @@ AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -36,6 +37,7 @@ AdvancedHMC = "0.6, 0.7" AdvancedMH = "0.8" BangBang = "0.4.1" ChainRules = "1" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" Graphs = "1" diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 3eafb4736..cb6a62c68 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -6,10 +6,9 @@ y = x[1] + x[3] end data = (mu=[0, 0], sigma=[1 0; 0 1]) - model = compile(model_def, data) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 10, 0 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) samples_and_stats = AbstractMCMC.sample( StableRNG(1234), @@ -27,18 +26,67 @@ [Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y] end + @testset "Symbol AD backend shortcuts" begin + model_def = @bugs begin + mu ~ dnorm(0, 1) + for i in 1:N + y[i] ~ dnorm(mu, 1) + end + end + data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) + + # Test that symbol shortcut works + ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) + ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + + @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test that both produce equivalent results + n_samples, n_adapts = 100, 100 + D = LogDensityProblems.dimension(ad_model_symbol) + initial_θ = rand(StableRNG(123), D) + + samples_symbol = AbstractMCMC.sample( + StableRNG(1234), + ad_model_symbol, + NUTS(0.8), + n_samples; + progress=false, + chain_type=Chains, + n_adapts=n_adapts, + init_params=initial_θ, + discard_initial=n_adapts, + ) + + samples_explicit = AbstractMCMC.sample( + StableRNG(1234), + ad_model_explicit, + NUTS(0.8), + n_samples; + progress=false, + chain_type=Chains, + n_adapts=n_adapts, + init_params=initial_θ, + discard_initial=n_adapts, + ) + + # Results should be very similar (same RNG seed) + @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ + summarize(samples_explicit)[:mu].nt.mean[1] rtol=0.1 + end + @testset "Inference results on examples: $example" for example in [:seeds, :rats, :stacks] (; model_def, data, inits, reference_results) = Base.getfield( JuliaBUGS.BUGSExamples, example ) - model = JuliaBUGS.compile(model_def, data, inits) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = JuliaBUGS.compile(model_def, data, inits; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 1000, 1000 - D = LogDensityProblems.dimension(model) - initial_θ = JuliaBUGS.getparams(model) + D = LogDensityProblems.dimension(ad_model) + initial_θ = JuliaBUGS.getparams(ad_model.base_model) samples_and_stats = AbstractMCMC.sample( StableRNG(1234), diff --git a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl index b6b512980..7d6f35832 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl @@ -26,11 +26,10 @@ y=[1.58, 4.80, 7.10, 8.86, 11.73, 14.52, 18.22, 18.73, 21.04, 22.93], ) - model = compile(model_def, data, (;)) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 2000, 1000 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) hmc_chain = AbstractMCMC.sample( @@ -73,7 +72,7 @@ n_samples, n_adapts = 20000, 5000 mh_chain = AbstractMCMC.sample( - model, + ad_model.base_model, RWMH(MvNormal(zeros(D), I)), n_samples; progress=false, @@ -107,8 +106,7 @@ sigma[2] ~ InverseGamma(2, 3) sigma[3] ~ InverseGamma(2, 3) end - model = compile(model_def, (;)) - ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, (;); adtype=AutoReverseDiff(compile=true)) hmc_chain = AbstractMCMC.sample( ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains ) diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index 3f602ad72..ddfe835c5 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -402,4 +402,56 @@ end @test occursin("Variable sizes and types:", output) end end + + @testset "AD Type Parameter" begin + model_def = @bugs begin + mu ~ dnorm(0, 1) + y ~ dnorm(mu, 1) + end + data = (y=1.5,) + + @testset "Symbol shortcuts" begin + # Test :ReverseDiff shortcut + model_rd = compile(model_def, data; adtype=:ReverseDiff) + @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test equivalence with explicit ADType + model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test that unknown symbol throws error + @test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend) + end + + @testset "Explicit ADTypes" begin + # Test with compile=true + model_compile = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient + + # Test with compile=false + model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + @test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient + end + + @testset "Default behavior (no adtype)" begin + # Without adtype, should return regular BUGSModel + model_default = compile(model_def, data) + @test model_default isa BUGSModel + @test !(model_default isa JuliaBUGS.Model.BUGSModelWithGradient) + end + + @testset "Gradient computation" begin + model = compile(model_def, data; adtype=:ReverseDiff) + test_point = [0.0] + + # Test that gradient can be computed + ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point) + + @test ℓ isa Real + @test grad isa Vector + @test length(grad) == 1 + @test isfinite(ℓ) + @test all(isfinite, grad) + end + end end diff --git a/JuliaBUGS/test/parallel_sampling.jl b/JuliaBUGS/test/parallel_sampling.jl index 7871aca7f..8b857e4a0 100644 --- a/JuliaBUGS/test/parallel_sampling.jl +++ b/JuliaBUGS/test/parallel_sampling.jl @@ -19,9 +19,8 @@ data = (N=N, x=x_data) inits = (mu=0.0, tau=1.0) - model = compile(model_def, data, inits) - # Use compile=Val(false) for thread safety with ReverseDiff - ad_model = ADgradient(:ReverseDiff, model; compile=Val(false)) + # Use compile=false for thread safety with ReverseDiff + ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(compile=false)) # Single chain reference n_samples = 200 From d77bbac62ee735aa50e96a02d5d6d0971d31412b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 23:33:18 +0530 Subject: [PATCH 02/17] update docs and bump version to 0.10.4 --- JuliaBUGS/History.md | 10 ++ JuliaBUGS/Project.toml | 2 +- JuliaBUGS/docs/src/example.md | 200 +++++++++++++++++++++++++++------- 3 files changed, 169 insertions(+), 43 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index cb5567431..066d5ad3a 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -1,5 +1,15 @@ # JuliaBUGS Changelog +## 0.10.4 + +- **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. + - Add `adtype` parameter to `compile()` function for specifying AD backends + - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme` + - Gradient computation is prepared during compilation for optimal performance + - Example: `model = compile(model_def, data; adtype=:ReverseDiff)` + - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)` + - Backward compatible: models without `adtype` work as before + ## 0.10.1 Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0) diff --git a/JuliaBUGS/Project.toml b/JuliaBUGS/Project.toml index 5cfec993d..65bd31780 100644 --- a/JuliaBUGS/Project.toml +++ b/JuliaBUGS/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.10.3" +version = "0.10.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index eaba1a01e..3cf73a4cc 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -190,33 +190,54 @@ initialize!(model, initializations) initialize!(model, rand(26)) ``` -`LogDensityProblemsAD.jl` defined some extensions that support automatic differentiation packages. -For example, with `ReverseDiff.jl` +### Automatic Differentiation + +JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods like Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS). + +#### Specifying an AD Backend + +To compile a model with gradient support, pass the `adtype` parameter to `compile`: ```julia -using LogDensityProblemsAD, ReverseDiff +# Using explicit ADType from ADTypes.jl +using ADTypes +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# Using convenient symbol shortcuts +model = compile(model_def, data; adtype=:ReverseDiff) # Equivalent to above +``` -ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) +Available AD backends include: +- `:ReverseDiff` - ReverseDiff with tape compilation (recommended for most models) +- `:ForwardDiff` - ForwardDiff (efficient for models with few parameters) +- `:Zygote` - Zygote (source-to-source AD) +- `:Enzyme` - Enzyme (experimental, high-performance) + +For fine-grained control, use explicit `ADTypes` constructors: + +```julia +# ReverseDiff without compilation +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) ``` -Here `ad_model` will also implement all the interfaces of [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/). -`LogDensityProblemsAD.jl` will automatically add the interface function [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient) to the model, which will return the log density and gradient of the model. -And `ad_model` can be used in the same way as `model` in the example below. +The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient. ### Inference -For a differentiable model, we can use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) to perform inference. -For instance, +For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`: ```julia -using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains +using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff + +# Compile with gradient support +model = compile(model_def, data; adtype=:ReverseDiff) n_samples, n_adapts = 2000, 1000 D = LogDensityProblems.dimension(model); initial_θ = rand(D) samples_and_stats = AbstractMCMC.sample( - ad_model, + model, NUTS(0.8), n_samples; chain_type = Chains, @@ -224,6 +245,7 @@ samples_and_stats = AbstractMCMC.sample( init_params = initial_θ, discard_initial = n_adapts ) +describe(samples_and_stats) ``` This will return the MCMC Chain, @@ -234,39 +256,72 @@ Chains MCMC chain (2000×40×1 Array{Real, 3}): Iterations = 1001:1:3000 Number of chains = 1 Samples per chain = 2000 -parameters = alpha0, alpha12, alpha1, alpha2, tau, b[16], b[12], b[10], b[14], b[13], b[7], b[6], b[20], b[1], b[4], b[5], b[2], b[18], b[8], b[3], b[9], b[21], b[17], b[15], b[11], b[19], sigma +parameters = tau, alpha12, alpha2, alpha1, alpha0, b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19], b[20], b[21], sigma internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt Summary Statistics - parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec - Symbol Float64 Float64 Float64 Real Float64 Float64 Missing - - alpha0 -0.5642 0.2320 0.0084 766.9305 1022.5211 1.0021 missing - alpha12 -0.8489 0.5247 0.0170 946.0418 1044.1109 1.0002 missing - alpha1 0.0587 0.3715 0.0119 966.4367 1233.2257 1.0007 missing - alpha2 1.3852 0.3410 0.0127 712.2978 974.1566 1.0002 missing - tau 1.8880 0.7705 0.0447 348.9331 338.3655 1.0030 missing - b[16] -0.2445 0.4459 0.0132 1528.0578 843.8225 1.0003 missing - b[12] 0.2050 0.3602 0.0086 1868.6126 1202.1363 0.9996 missing - b[10] -0.3500 0.2893 0.0090 1047.3119 1245.9358 1.0008 missing - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted + parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec + Symbol Float64 Float64 Float64 Real Float64 Float64 Missing + + tau 73.1490 193.8441 43.2582 56.3430 20.6688 1.0155 missing + alpha12 -0.8052 0.4392 0.0158 761.2180 1049.1664 1.0020 missing + alpha2 1.3428 0.2813 0.0140 422.8810 1013.2570 1.0061 missing + alpha1 0.0845 0.3126 0.0113 773.2202 981.8487 1.0051 missing + alpha0 -0.5480 0.1944 0.0087 537.6212 1156.2083 1.0014 missing + b[1] -0.1905 0.2540 0.0129 374.3372 971.7526 1.0034 missing + b[2] 0.0161 0.2178 0.0056 1505.6353 1002.8787 1.0001 missing + b[3] -0.1986 0.2375 0.0128 367.6766 1287.8215 1.0015 missing + b[4] 0.2792 0.2498 0.0163 201.1558 1168.7538 1.0068 missing + b[5] 0.1170 0.2397 0.0092 659.5422 1484.8584 1.0016 missing + b[6] 0.0667 0.2821 0.0074 1745.5567 902.1014 1.0067 missing + b[7] 0.0597 0.2218 0.0055 1589.5590 1145.6017 1.0065 missing + b[8] 0.1769 0.2316 0.0102 554.5974 1318.8089 1.0001 missing + b[9] -0.1257 0.2233 0.0073 930.0346 1186.4283 1.0031 missing + b[10] -0.2513 0.2392 0.0159 213.6323 1142.4487 1.0096 missing + b[11] 0.0768 0.2783 0.0081 1376.5999 1218.1537 1.0009 missing + b[12] 0.1171 0.2768 0.0079 1354.9409 1130.8217 1.0052 missing + b[13] -0.0688 0.2433 0.0055 1895.0387 1527.7066 1.0010 missing + b[14] -0.1363 0.2558 0.0075 1276.0992 1208.8587 1.0001 missing + b[15] 0.2334 0.2757 0.0135 439.2241 837.3396 1.0036 missing + b[16] -0.1212 0.3024 0.0106 1093.4416 914.9457 0.9997 missing + b[17] -0.2120 0.3142 0.0166 360.6420 702.4098 1.0009 missing + b[18] 0.0346 0.2282 0.0056 1665.0325 1281.7179 1.0011 missing + b[19] -0.0244 0.2400 0.0052 2186.7638 1179.6971 1.0132 missing + b[20] 0.2108 0.2421 0.0131 349.7657 1263.5781 1.0016 missing + b[21] -0.0509 0.2813 0.0061 2200.5614 916.6256 0.9998 missing + sigma 0.2797 0.1362 0.0168 56.3430 21.4971 1.0123 missing Quantiles - parameters 2.5% 25.0% 50.0% 75.0% 97.5% - Symbol Float64 Float64 Float64 Float64 Float64 - - alpha0 -1.0143 -0.7143 -0.5590 -0.4100 -0.1185 - alpha12 -1.9063 -1.1812 -0.8296 -0.5153 0.1521 - alpha1 -0.6550 -0.1822 0.0512 0.2885 0.8180 - alpha2 0.7214 1.1663 1.3782 1.5998 2.0986 - tau 0.5461 1.3941 1.8353 2.3115 3.6225 - b[16] -1.2359 -0.4836 -0.1909 0.0345 0.5070 - b[12] -0.4493 -0.0370 0.1910 0.4375 0.9828 - b[10] -0.9570 -0.5264 -0.3331 -0.1514 0.1613 - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted - + parameters 2.5% 25.0% 50.0% 75.0% 97.5% + Symbol Float64 Float64 Float64 Float64 Float64 + + tau 3.1280 7.4608 13.0338 28.2289 929.6520 + alpha12 -1.6645 -1.0887 -0.7952 -0.5635 0.1162 + alpha2 0.8398 1.1494 1.3233 1.5337 1.9177 + alpha1 -0.5796 -0.1059 0.1042 0.2883 0.6702 + alpha0 -0.9340 -0.6751 -0.5463 -0.4086 -0.1752 + b[1] -0.7430 -0.3415 -0.1566 -0.0074 0.2535 + b[2] -0.4261 -0.1083 0.0192 0.1420 0.4810 + b[3] -0.7394 -0.3377 -0.1687 -0.0242 0.2041 + b[4] -0.1108 0.0873 0.2409 0.4375 0.8267 + b[5] -0.3141 -0.0458 0.0900 0.2563 0.6489 + b[6] -0.4679 -0.0896 0.0291 0.2202 0.7060 + b[7] -0.3861 -0.0685 0.0534 0.1847 0.5207 + b[8] -0.2326 0.0221 0.1505 0.3162 0.6861 + b[9] -0.6007 -0.2482 -0.0984 0.0057 0.2771 + b[10] -0.7936 -0.4108 -0.2255 -0.0617 0.1290 + b[11] -0.4381 -0.0796 0.0353 0.2178 0.7232 + b[12] -0.3806 -0.0451 0.0750 0.2671 0.7625 + b[13] -0.5841 -0.2135 -0.0443 0.0652 0.4055 + b[14] -0.6854 -0.2872 -0.1015 0.0147 0.3476 + b[15] -0.2054 0.0257 0.1898 0.4004 0.8660 + b[16] -0.8173 -0.2829 -0.0804 0.0532 0.4094 + b[17] -0.9071 -0.3911 -0.1595 0.0099 0.2864 + b[18] -0.4526 -0.0919 0.0140 0.1686 0.4985 + b[19] -0.5055 -0.1547 -0.0091 0.1134 0.4528 + b[20] -0.2120 0.0318 0.1788 0.3673 0.7416 + b[21] -0.6482 -0.2044 -0.0263 0.1051 0.5246 + sigma 0.0328 0.1882 0.2770 0.3661 0.5654 ``` This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). @@ -283,7 +338,7 @@ The model compilation code remains the same, and we can sample multiple chains i ```julia n_chains = 4 samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCThreads(), n_samples, @@ -311,7 +366,7 @@ For example: ```julia @everywhere begin - using JuliaBUGS, LogDensityProblems, LogDensityProblemsAD, AbstractMCMC, AdvancedHMC, MCMCChains, ReverseDiff # also other packages one may need + using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff # Define the functions to use # Use `@bugs_primitive` to register the functions to use in the model @@ -322,7 +377,7 @@ end n_chains = nprocs() - 1 # use all the processes except the parent process samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCDistributed(), n_samples, @@ -342,6 +397,67 @@ In this case, we pass two additional arguments to `AbstractMCMC.sample`: Note that the `init_params` argument is now a vector of initial parameters for each chain. Sometimes the progress logger can cause problems in distributed setting, so we can disable it by setting `progress = false`. +## Choosing an Automatic Differentiation Backend + +JuliaBUGS integrates with multiple automatic differentiation (AD) backends through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), providing flexibility to choose the most suitable backend for your model. + +### Available Backends + +The following AD backends are supported via convenient symbol shortcuts: + +- **`:ReverseDiff`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. +- **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). +- **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. +- **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. + +### Usage Examples + +#### Basic Usage with Symbol Shortcuts + +The simplest way to specify an AD backend is using symbol shortcuts: + +```julia +# ReverseDiff with tape compilation (recommended for most models) +model = compile(model_def, data; adtype=:ReverseDiff) + +# ForwardDiff (good for small models with few parameters) +model = compile(model_def, data; adtype=:ForwardDiff) + +# Zygote (source-to-source AD) +model = compile(model_def, data; adtype=:Zygote) +``` + +#### Advanced Configuration + +For fine-grained control, use explicit `ADTypes` constructors: + +```julia +using ADTypes + +# ReverseDiff without tape compilation +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + +# ReverseDiff with compilation (equivalent to :ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) +``` + +### Performance Considerations + +- **ReverseDiff with compilation** (`:ReverseDiff`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. + +- **ForwardDiff** (`:ForwardDiff`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. + +- **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. + +### Compatibility + +All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with: + +- [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) — NUTS and HMC samplers +- Any other sampler that works with the LogDensityProblems interface + +The gradient computation is prepared during model compilation for optimal performance during sampling. + ## More Examples We have transcribed all the examples from the first volume of the BUGS Examples ([original](https://www.multibugs.org/examples/latest/VolumeI.html) and [transcribed](https://github.com/TuringLang/JuliaBUGS.jl/tree/main/JuliaBUGS/src/BUGSExamples/Volume_1)). All programs and data are included, and can be compiled using the steps described in the tutorial above. From 199ba715c61fe3006d209b5a4f85257f412435a2 Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Thu, 2 Oct 2025 23:56:45 +0530 Subject: [PATCH 03/17] run JuliaFormatter --- JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl | 4 +--- JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl | 4 +--- JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl | 7 +++++- JuliaBUGS/src/JuliaBUGS.jl | 24 +++++++++---------- JuliaBUGS/src/model/logdensityproblems.jl | 6 ++++- .../test/BUGSPrimitives/distributions.jl | 4 +++- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 24 ++++++++++--------- JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl | 4 ++-- JuliaBUGS/test/model/bugsmodel.jl | 20 +++++++++------- JuliaBUGS/test/parallel_sampling.jl | 2 +- 10 files changed, 55 insertions(+), 44 deletions(-) diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index bf052e3ed..88f7a1d9b 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -44,9 +44,7 @@ function _gibbs_internal_hmc( ) # Create gradient model on-the-fly using DifferentiationInterface x = getparams(cond_model) - prep = DI.prepare_gradient( - _logdensity_switched, ad_backend, x, DI.Constant(cond_model) - ) + prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index 1d07ade50..88c43249b 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -56,9 +56,7 @@ function _gibbs_internal_mh( ) # Create gradient model on-the-fly using DifferentiationInterface x = getparams(cond_model) - prep = DI.prepare_gradient( - _logdensity_switched, ad_backend, x, DI.Constant(cond_model) - ) + prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index 40d77e848..a69007b74 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -2,7 +2,12 @@ module JuliaBUGSMCMCChainsExt using AbstractMCMC using JuliaBUGS -using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams +using JuliaBUGS: + BUGSModel, + BUGSModelWithGradient, + find_generated_quantities_variables, + evaluate!!, + getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors using JuliaBUGS.LogDensityProblemsAD diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index 3e780011c..de31e43a7 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -320,14 +320,14 @@ function compile( ), ) base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) - + # If adtype provided, wrap with gradient capabilities if adtype !== nothing # Convert symbol to ADType if needed adtype_obj = _resolve_adtype(adtype) return _wrap_with_gradient(base_model, adtype_obj) end - + return base_model end @@ -344,7 +344,7 @@ Supported symbol shortcuts: """ function _resolve_adtype(adtype::Symbol) if adtype === :ReverseDiff - return ADTypes.AutoReverseDiff(compile=true) + return ADTypes.AutoReverseDiff(; compile=true) elseif adtype === :ForwardDiff return ADTypes.AutoForwardDiff() elseif adtype === :Zygote @@ -352,9 +352,11 @@ function _resolve_adtype(adtype::Symbol) elseif adtype === :Enzyme return ADTypes.AutoEnzyme() else - error("Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * - "Or use an ADTypes object like AutoReverseDiff(compile=true).") + error( + "Unknown AD backend symbol: $adtype. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * + "Or use an ADTypes object like AutoReverseDiff(compile=true).", + ) end end @@ -366,17 +368,13 @@ function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.Abstra # Get initial parameters for preparation # Use invokelatest to handle world age issues with generated functions x = Base.invokelatest(getparams, base_model) - + # Prepare gradient using DifferentiationInterface # Use invokelatest to handle world age issues when calling logdensity during preparation prep = Base.invokelatest( - DI.prepare_gradient, - Model._logdensity_switched, - adtype, - x, - DI.Constant(base_model) + DI.prepare_gradient, Model._logdensity_switched, adtype, x, DI.Constant(base_model) ) - + return Model.BUGSModelWithGradient(adtype, prep, base_model) end # function compile( diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 1b0381cae..2b97c5c7d 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -102,7 +102,11 @@ function LogDensityProblems.logdensity_and_gradient( ) else return DI.value_and_gradient( - _logdensity_switched, model.prep, model.backend, x, DI.Constant(model.base_model) + _logdensity_switched, + model.prep, + model.backend, + x, + DI.Constant(model.base_model), ) end end diff --git a/JuliaBUGS/test/BUGSPrimitives/distributions.jl b/JuliaBUGS/test/BUGSPrimitives/distributions.jl index 0262b051d..82c4f04af 100644 --- a/JuliaBUGS/test/BUGSPrimitives/distributions.jl +++ b/JuliaBUGS/test/BUGSPrimitives/distributions.jl @@ -15,7 +15,9 @@ end A[1:2, 1:2] ~ dwish(B[:, :], 2) C[1:2] ~ dmnorm(mu[:], A[:, :]) end - ad_model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff()) + ad_model = compile( + model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff() + ) theta = [ 0.7931743744870574, diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index cb6a62c68..5877aa298 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -6,7 +6,7 @@ y = x[1] + x[3] end data = (mu=[0, 0], sigma=[1 0; 0 1]) - ad_model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 10, 0 D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) @@ -34,19 +34,19 @@ end end data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) - + # Test that symbol shortcut works ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) - ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - + ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test that both produce equivalent results n_samples, n_adapts = 100, 100 D = LogDensityProblems.dimension(ad_model_symbol) initial_θ = rand(StableRNG(123), D) - + samples_symbol = AbstractMCMC.sample( StableRNG(1234), ad_model_symbol, @@ -58,7 +58,7 @@ init_params=initial_θ, discard_initial=n_adapts, ) - + samples_explicit = AbstractMCMC.sample( StableRNG(1234), ad_model_explicit, @@ -70,10 +70,10 @@ init_params=initial_θ, discard_initial=n_adapts, ) - + # Results should be very similar (same RNG seed) - @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ - summarize(samples_explicit)[:mu].nt.mean[1] rtol=0.1 + @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ + summarize(samples_explicit)[:mu].nt.mean[1] rtol = 0.1 end @testset "Inference results on examples: $example" for example in @@ -81,7 +81,9 @@ (; model_def, data, inits, reference_results) = Base.getfield( JuliaBUGS.BUGSExamples, example ) - ad_model = JuliaBUGS.compile(model_def, data, inits; adtype=AutoReverseDiff(compile=true)) + ad_model = JuliaBUGS.compile( + model_def, data, inits; adtype=AutoReverseDiff(; compile=true) + ) n_samples, n_adapts = 1000, 1000 diff --git a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl index 7d6f35832..2e7c16367 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl @@ -26,7 +26,7 @@ y=[1.58, 4.80, 7.10, 8.86, 11.73, 14.52, 18.22, 18.73, 21.04, 22.93], ) - ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 2000, 1000 D = LogDensityProblems.dimension(ad_model) @@ -106,7 +106,7 @@ sigma[2] ~ InverseGamma(2, 3) sigma[3] ~ InverseGamma(2, 3) end - ad_model = compile(model_def, (;); adtype=AutoReverseDiff(compile=true)) + ad_model = compile(model_def, (;); adtype=AutoReverseDiff(; compile=true)) hmc_chain = AbstractMCMC.sample( ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains ) diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index ddfe835c5..7a68d6264 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -414,22 +414,26 @@ end # Test :ReverseDiff shortcut model_rd = compile(model_def, data; adtype=:ReverseDiff) @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test equivalence with explicit ADType - model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + model_explicit = compile( + model_def, data; adtype=AutoReverseDiff(; compile=true) + ) @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test that unknown symbol throws error @test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend) end @testset "Explicit ADTypes" begin # Test with compile=true - model_compile = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient - + # Test with compile=false - model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) + model_nocompile = compile( + model_def, data; adtype=AutoReverseDiff(; compile=false) + ) @test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient end @@ -443,10 +447,10 @@ end @testset "Gradient computation" begin model = compile(model_def, data; adtype=:ReverseDiff) test_point = [0.0] - + # Test that gradient can be computed ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point) - + @test ℓ isa Real @test grad isa Vector @test length(grad) == 1 diff --git a/JuliaBUGS/test/parallel_sampling.jl b/JuliaBUGS/test/parallel_sampling.jl index 8b857e4a0..23dfa4ebf 100644 --- a/JuliaBUGS/test/parallel_sampling.jl +++ b/JuliaBUGS/test/parallel_sampling.jl @@ -20,7 +20,7 @@ inits = (mu=0.0, tau=1.0) # Use compile=false for thread safety with ReverseDiff - ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(compile=false)) + ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(; compile=false)) # Single chain reference n_samples = 200 From 1e515d1e1edb3a2bb24c66e3b273c2f76035548b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Fri, 3 Oct 2025 00:26:53 +0530 Subject: [PATCH 04/17] try to fix benchmark failures --- JuliaBUGS/benchmark/benchmark.jl | 9 +++------ JuliaBUGS/benchmark/run_benchmarks.jl | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/JuliaBUGS/benchmark/benchmark.jl b/JuliaBUGS/benchmark/benchmark.jl index 1c558fd16..14f86dd32 100644 --- a/JuliaBUGS/benchmark/benchmark.jl +++ b/JuliaBUGS/benchmark/benchmark.jl @@ -83,16 +83,13 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult}) ), ) end + DataFrames.rename!(df, :Density_Time => "Density Time (µs)", :Density_Gradient_Time => "Density+Gradient Time (µs)") return df end function _print_results_table( - results::OrderedDict{Symbol,BenchmarkResult}; backend=Val(:text) + results::OrderedDict{Symbol,BenchmarkResult}; backend=:text ) df = _create_results_dataframe(results) - return pretty_table( - df; - header=["Model", "Parameters", "Density Time (µs)", "Density+Gradient Time (µs)"], - backend=backend, - ) + return pretty_table(df; backend=backend) end diff --git a/JuliaBUGS/benchmark/run_benchmarks.jl b/JuliaBUGS/benchmark/run_benchmarks.jl index 3194239f3..88459701c 100644 --- a/JuliaBUGS/benchmark/run_benchmarks.jl +++ b/JuliaBUGS/benchmark/run_benchmarks.jl @@ -45,7 +45,7 @@ for (model_name, model) in zip(examples_to_benchmark, juliabugs_models) end println("### Stan results:") -_print_results_table(stan_results; backend=Val(:markdown)) +_print_results_table(stan_results; backend=:markdown) println("### JuliaBUGS Mooncake results:") -_print_results_table(juliabugs_results; backend=Val(:markdown)) +_print_results_table(juliabugs_results; backend=:markdown) From f0135a8df6b142d9272fc8dbce5d63af21b65f60 Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Fri, 3 Oct 2025 00:45:37 +0530 Subject: [PATCH 05/17] format --- JuliaBUGS/benchmark/benchmark.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/JuliaBUGS/benchmark/benchmark.jl b/JuliaBUGS/benchmark/benchmark.jl index 14f86dd32..30d5985a6 100644 --- a/JuliaBUGS/benchmark/benchmark.jl +++ b/JuliaBUGS/benchmark/benchmark.jl @@ -83,13 +83,15 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult}) ), ) end - DataFrames.rename!(df, :Density_Time => "Density Time (µs)", :Density_Gradient_Time => "Density+Gradient Time (µs)") + DataFrames.rename!( + df, + :Density_Time => "Density Time (µs)", + :Density_Gradient_Time => "Density+Gradient Time (µs)", + ) return df end -function _print_results_table( - results::OrderedDict{Symbol,BenchmarkResult}; backend=:text -) +function _print_results_table(results::OrderedDict{Symbol,BenchmarkResult}; backend=:text) df = _create_results_dataframe(results) return pretty_table(df; backend=backend) end From c80dc9f73267431246b03cf24fad040867e3411b Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Mon, 6 Oct 2025 15:38:44 +0530 Subject: [PATCH 06/17] add Mooncake and update docs --- JuliaBUGS/History.md | 2 +- JuliaBUGS/docs/src/example.md | 4 ++++ JuliaBUGS/src/JuliaBUGS.jl | 7 ++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index 066d5ad3a..65f850dd2 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -4,7 +4,7 @@ - **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. - Add `adtype` parameter to `compile()` function for specifying AD backends - - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme` + - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`, `:Mooncake` - Gradient computation is prepared during compilation for optimal performance - Example: `model = compile(model_def, data; adtype=:ReverseDiff)` - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)` diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 3cf73a4cc..2fc947740 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -409,6 +409,7 @@ The following AD backends are supported via convenient symbol shortcuts: - **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). - **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. - **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. +- **`:Mooncake`** — High-performance reverse-mode AD with advanced optimizations. ### Usage Examples @@ -449,6 +450,9 @@ model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. +!!! warning "Compiled tapes and control flow" + Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `:ForwardDiff` or `:Mooncake`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. + ### Compatibility All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with: diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index de31e43a7..ef7dfd106 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -264,6 +264,8 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` - `:Zygote` - Shorthand for `AutoZygote()` + - `:Enzyme` - Shorthand for `AutoEnzyme()` + - `:Mooncake` - Shorthand for `AutoMooncake()` - Any other `ADTypes.AbstractADType` # Examples @@ -341,6 +343,7 @@ Supported symbol shortcuts: - `:ForwardDiff` -> `AutoForwardDiff()` - `:Zygote` -> `AutoZygote()` - `:Enzyme` -> `AutoEnzyme()` +- `:Mooncake` -> `AutoMooncake()` """ function _resolve_adtype(adtype::Symbol) if adtype === :ReverseDiff @@ -351,10 +354,12 @@ function _resolve_adtype(adtype::Symbol) return ADTypes.AutoZygote() elseif adtype === :Enzyme return ADTypes.AutoEnzyme() + elseif adtype === :Mooncake + return ADTypes.AutoMooncake(; config=nothing) else error( "Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " * + "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme, :Mooncake. " * "Or use an ADTypes object like AutoReverseDiff(compile=true).", ) end From ef36d7d2fc9dde0d7a9898e4e05a7baa7110698f Mon Sep 17 00:00:00 2001 From: Shravan Goswami Date: Tue, 7 Oct 2025 12:13:12 +0530 Subject: [PATCH 07/17] remove symbols --- JuliaBUGS/History.md | 7 +-- JuliaBUGS/docs/src/example.md | 52 ++++++++--------- JuliaBUGS/examples/sir.jl | 3 +- JuliaBUGS/src/JuliaBUGS.jl | 57 +++---------------- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 26 ++++----- JuliaBUGS/test/model/bugsmodel.jl | 19 +------ 6 files changed, 54 insertions(+), 110 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index 65f850dd2..c004752ac 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -3,11 +3,10 @@ ## 0.10.4 - **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. - - Add `adtype` parameter to `compile()` function for specifying AD backends - - Support convenient symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`, `:Mooncake` + - Add `adtype` parameter to `compile()` function for specifying AD backends via [ADTypes.jl](https://github.com/SciML/ADTypes.jl) + - Supports multiple backends: `AutoReverseDiff`, `AutoForwardDiff`, `AutoZygote`, `AutoEnzyme`, `AutoMooncake` - Gradient computation is prepared during compilation for optimal performance - - Example: `model = compile(model_def, data; adtype=:ReverseDiff)` - - Full control available via explicit ADTypes: `adtype=AutoReverseDiff(compile=true)` + - Example: `model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))` - Backward compatible: models without `adtype` work as before ## 0.10.1 diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 2fc947740..57ea54466 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -199,21 +199,19 @@ JuliaBUGS integrates with automatic differentiation (AD) through [Differentiatio To compile a model with gradient support, pass the `adtype` parameter to `compile`: ```julia -# Using explicit ADType from ADTypes.jl +# Compile with gradient support using ADTypes from ADTypes.jl using ADTypes model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - -# Using convenient symbol shortcuts -model = compile(model_def, data; adtype=:ReverseDiff) # Equivalent to above ``` Available AD backends include: -- `:ReverseDiff` - ReverseDiff with tape compilation (recommended for most models) -- `:ForwardDiff` - ForwardDiff (efficient for models with few parameters) -- `:Zygote` - Zygote (source-to-source AD) -- `:Enzyme` - Enzyme (experimental, high-performance) +- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models) +- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters) +- `AutoZygote()` - Zygote (source-to-source AD) +- `AutoEnzyme()` - Enzyme (experimental, high-performance) +- `AutoMooncake()` - Mooncake (high-performance reverse-mode AD) -For fine-grained control, use explicit `ADTypes` constructors: +For fine-grained control, you can configure the AD backend: ```julia # ReverseDiff without compilation @@ -230,7 +228,7 @@ For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/Turin using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff # Compile with gradient support -model = compile(model_def, data; adtype=:ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 2000, 1000 @@ -403,34 +401,36 @@ JuliaBUGS integrates with multiple automatic differentiation (AD) backends throu ### Available Backends -The following AD backends are supported via convenient symbol shortcuts: +The following AD backends are supported via [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -- **`:ReverseDiff`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. -- **`:ForwardDiff`** — Forward-mode AD, efficient for models with few parameters (typically < 20). -- **`:Zygote`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. -- **`:Enzyme`** — Experimental high-performance AD backend with LLVM-level transformations. -- **`:Mooncake`** — High-performance reverse-mode AD with advanced optimizations. +- **`AutoReverseDiff(compile=true)`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. +- **`AutoForwardDiff()`** — Forward-mode AD, efficient for models with few parameters (typically < 20). +- **`AutoZygote()`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. +- **`AutoEnzyme()`** — Experimental high-performance AD backend with LLVM-level transformations. +- **`AutoMooncake()`** — High-performance reverse-mode AD with advanced optimizations. ### Usage Examples -#### Basic Usage with Symbol Shortcuts +#### Basic Usage -The simplest way to specify an AD backend is using symbol shortcuts: +Specify an AD backend using ADTypes: ```julia +using ADTypes + # ReverseDiff with tape compilation (recommended for most models) -model = compile(model_def, data; adtype=:ReverseDiff) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) # ForwardDiff (good for small models with few parameters) -model = compile(model_def, data; adtype=:ForwardDiff) +model = compile(model_def, data; adtype=AutoForwardDiff()) # Zygote (source-to-source AD) -model = compile(model_def, data; adtype=:Zygote) +model = compile(model_def, data; adtype=AutoZygote()) ``` #### Advanced Configuration -For fine-grained control, use explicit `ADTypes` constructors: +For fine-grained control, you can configure the AD backends: ```julia using ADTypes @@ -438,20 +438,20 @@ using ADTypes # ReverseDiff without tape compilation model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) -# ReverseDiff with compilation (equivalent to :ReverseDiff) +# ReverseDiff with compilation (default, recommended) model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) ``` ### Performance Considerations -- **ReverseDiff with compilation** (`:ReverseDiff`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. +- **ReverseDiff with compilation** (`AutoReverseDiff(compile=true)`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. -- **ForwardDiff** (`:ForwardDiff`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. +- **ForwardDiff** (`AutoForwardDiff()`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. - **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. !!! warning "Compiled tapes and control flow" - Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `:ForwardDiff` or `:Mooncake`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. + Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `AutoForwardDiff()` or `AutoMooncake()`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. ### Compatibility diff --git a/JuliaBUGS/examples/sir.jl b/JuliaBUGS/examples/sir.jl index 108d47ce1..ccd4f3279 100644 --- a/JuliaBUGS/examples/sir.jl +++ b/JuliaBUGS/examples/sir.jl @@ -7,6 +7,7 @@ using JuliaBUGS: @model using Distributions using DifferentialEquations using LogDensityProblems, LogDensityProblemsAD +using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains using Distributed # For distributed example @@ -113,7 +114,7 @@ model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) # --- MCMC Sampling: NUTS with ForwardDiff AD --- # Create an AD-aware wrapper for the model using ForwardDiff for gradients -ad_model_forwarddiff = ADgradient(:ForwardDiff, model) +ad_model_forwarddiff = ADgradient(AutoForwardDiff(), model) # MCMC settings n_samples = 1000 diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index ef7dfd106..a6e4e5f0c 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -258,14 +258,13 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation - `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional) - `skip_validation::Bool=false`: Skip function validation (for @model macro) - `eval_module::Module=@__MODULE__`: Module for evaluation -- `adtype`: AD backend specification. Can be: +- `adtype`: AD backend specification using ADTypes. Examples: - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest) - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation - - `:ReverseDiff` - Shorthand for `AutoReverseDiff(compile=true)` - - `:ForwardDiff` - Shorthand for `AutoForwardDiff()` - - `:Zygote` - Shorthand for `AutoZygote()` - - `:Enzyme` - Shorthand for `AutoEnzyme()` - - `:Mooncake` - Shorthand for `AutoMooncake()` + - `AutoForwardDiff()` - ForwardDiff backend + - `AutoZygote()` - Zygote backend + - `AutoEnzyme()` - Enzyme backend + - `AutoMooncake()` - Mooncake backend - Any other `ADTypes.AbstractADType` # Examples @@ -273,14 +272,11 @@ samplers like HMC/NUTS. The gradient computation is prepared during compilation # Basic compilation model = compile(model_def, data) -# With gradient support using explicit ADType +# With gradient support using ReverseDiff (recommended for most models) model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) -# With gradient support using symbol shorthand -model = compile(model_def, data; adtype=:ReverseDiff) # Same as above - # Using ForwardDiff for small models -model = compile(model_def, data; adtype=:ForwardDiff) +model = compile(model_def, data; adtype=AutoForwardDiff()) # Sample with NUTS chain = AbstractMCMC.sample(model, NUTS(0.8), 1000) @@ -325,49 +321,12 @@ function compile( # If adtype provided, wrap with gradient capabilities if adtype !== nothing - # Convert symbol to ADType if needed - adtype_obj = _resolve_adtype(adtype) - return _wrap_with_gradient(base_model, adtype_obj) + return _wrap_with_gradient(base_model, adtype) end return base_model end -""" - _resolve_adtype(adtype) -> ADTypes.AbstractADType - -Convert symbol shortcuts to ADTypes, or return the ADType as-is. - -Supported symbol shortcuts: -- `:ReverseDiff` -> `AutoReverseDiff(compile=true)` -- `:ForwardDiff` -> `AutoForwardDiff()` -- `:Zygote` -> `AutoZygote()` -- `:Enzyme` -> `AutoEnzyme()` -- `:Mooncake` -> `AutoMooncake()` -""" -function _resolve_adtype(adtype::Symbol) - if adtype === :ReverseDiff - return ADTypes.AutoReverseDiff(; compile=true) - elseif adtype === :ForwardDiff - return ADTypes.AutoForwardDiff() - elseif adtype === :Zygote - return ADTypes.AutoZygote() - elseif adtype === :Enzyme - return ADTypes.AutoEnzyme() - elseif adtype === :Mooncake - return ADTypes.AutoMooncake(; config=nothing) - else - error( - "Unknown AD backend symbol: $adtype. " * - "Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme, :Mooncake. " * - "Or use an ADTypes object like AutoReverseDiff(compile=true).", - ) - end -end - -# Pass through ADTypes objects unchanged -_resolve_adtype(adtype::ADTypes.AbstractADType) = adtype - # Helper function to prepare gradient - separated to handle world age issues function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) # Get initial parameters for preparation diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 5877aa298..843963864 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -26,7 +26,7 @@ [Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y] end - @testset "Symbol AD backend shortcuts" begin + @testset "AD backend sampling" begin model_def = @bugs begin mu ~ dnorm(0, 1) for i in 1:N @@ -35,21 +35,21 @@ end data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) - # Test that symbol shortcut works - ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff) - ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + # Test that ReverseDiff backend works + ad_model_compiled = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) + ad_model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(; compile=false)) - @test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient - @test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_compiled isa JuliaBUGS.Model.BUGSModelWithGradient + @test ad_model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient # Test that both produce equivalent results n_samples, n_adapts = 100, 100 - D = LogDensityProblems.dimension(ad_model_symbol) + D = LogDensityProblems.dimension(ad_model_compiled) initial_θ = rand(StableRNG(123), D) - samples_symbol = AbstractMCMC.sample( + samples_compiled = AbstractMCMC.sample( StableRNG(1234), - ad_model_symbol, + ad_model_compiled, NUTS(0.8), n_samples; progress=false, @@ -59,9 +59,9 @@ discard_initial=n_adapts, ) - samples_explicit = AbstractMCMC.sample( + samples_nocompile = AbstractMCMC.sample( StableRNG(1234), - ad_model_explicit, + ad_model_nocompile, NUTS(0.8), n_samples; progress=false, @@ -72,8 +72,8 @@ ) # Results should be very similar (same RNG seed) - @test summarize(samples_symbol)[:mu].nt.mean[1] ≈ - summarize(samples_explicit)[:mu].nt.mean[1] rtol = 0.1 + @test summarize(samples_compiled)[:mu].nt.mean[1] ≈ + summarize(samples_nocompile)[:mu].nt.mean[1] rtol = 0.1 end @testset "Inference results on examples: $example" for example in diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index 7a68d6264..71b19834d 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -410,22 +410,7 @@ end end data = (y=1.5,) - @testset "Symbol shortcuts" begin - # Test :ReverseDiff shortcut - model_rd = compile(model_def, data; adtype=:ReverseDiff) - @test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test equivalence with explicit ADType - model_explicit = compile( - model_def, data; adtype=AutoReverseDiff(; compile=true) - ) - @test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test that unknown symbol throws error - @test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend) - end - - @testset "Explicit ADTypes" begin + @testset "ADTypes backends" begin # Test with compile=true model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient @@ -445,7 +430,7 @@ end end @testset "Gradient computation" begin - model = compile(model_def, data; adtype=:ReverseDiff) + model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) test_point = [0.0] # Test that gradient can be computed From 2f8b72d24f2717742c4eae199b4b499b7d54b6cc Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sun, 26 Oct 2025 23:43:24 +0000 Subject: [PATCH 08/17] Update JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 843963864..6549aa93e 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -37,7 +37,9 @@ # Test that ReverseDiff backend works ad_model_compiled = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) - ad_model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(; compile=false)) + ad_model_nocompile = compile( + model_def, data; adtype=AutoReverseDiff(; compile=false) + ) @test ad_model_compiled isa JuliaBUGS.Model.BUGSModelWithGradient @test ad_model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient From a44fda50f2d241fb81cf77bc2c42451dcc587a86 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 3 Dec 2025 10:04:11 +0000 Subject: [PATCH 09/17] Refactor BUGSModelWithGradient and remove LogDensityProblemsAD support BREAKING: Use compile(...; adtype=...) or BUGSModelWithGradient(model, adtype) instead of ADgradient() --- JuliaBUGS/History.md | 7 +- JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl | 39 +--------- JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl | 35 +-------- JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl | 15 ---- JuliaBUGS/src/JuliaBUGS.jl | 11 +-- JuliaBUGS/src/gibbs.jl | 2 +- JuliaBUGS/src/model/Model.jl | 13 +++- JuliaBUGS/src/model/logdensityproblems.jl | 94 +++++++++++++++-------- 8 files changed, 83 insertions(+), 133 deletions(-) diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index c004752ac..24fb9c5f9 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -7,7 +7,12 @@ - Supports multiple backends: `AutoReverseDiff`, `AutoForwardDiff`, `AutoZygote`, `AutoEnzyme`, `AutoMooncake` - Gradient computation is prepared during compilation for optimal performance - Example: `model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))` - - Backward compatible: models without `adtype` work as before + - New `BUGSModelWithGradient(model, adtype)` constructor for adding gradients to existing models + - Models without `adtype` work as before (no gradients) + +- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported for gradient computation. + - **Old**: `ad_model = ADgradient(:ReverseDiff, model)` + - **New**: `model = compile(model_def, data; adtype=AutoReverseDiff())` or `BUGSModelWithGradient(model, AutoReverseDiff())` ## 0.10.1 diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index 88f7a1d9b..eca960856 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -3,12 +3,9 @@ module JuliaBUGSAdvancedHMCExt using AbstractMCMC using AdvancedHMC using ADTypes -import DifferentiationInterface as DI using JuliaBUGS using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -42,10 +39,9 @@ end function _gibbs_internal_hmc( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Create gradient model on-the-fly using DifferentiationInterface + # Create gradient model on-the-fly + ad_model = BUGSModelWithGradient(cond_model, ad_backend) x = getparams(cond_model) - prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) - ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take HMC/NUTS step @@ -101,35 +97,4 @@ function AbstractMCMC.bundle_samples( ) end -# Keep backward compatibility with LogDensityProblemsAD wrapper -function AbstractMCMC.bundle_samples( - ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, - sampler::AdvancedHMC.AbstractHMCSampler, - state, - chain_type::Type{Chains}; - discard_initial=0, - thinning=1, - kwargs..., -) - param_samples = [t.z.θ for t in ts] - - stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) - stats_values = [ - vcat(ts[i].z.ℓπ.value, collect(values(AdvancedHMC.stat(ts[i])))) for - i in eachindex(ts) - ] - - # Delegate to gen_chains for proper parameter naming from BUGSModel - return JuliaBUGS.gen_chains( - logdensitymodel, - param_samples, - stats_names, - stats_values; - discard_initial=discard_initial, - thinning=thinning, - kwargs..., - ) -end - end diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index 88c43249b..edab75d99 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -3,12 +3,9 @@ module JuliaBUGSAdvancedMHExt using AbstractMCMC using AdvancedMH using ADTypes -import DifferentiationInterface as DI using JuliaBUGS using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using JuliaBUGS.Model: _logdensity_switched using JuliaBUGS.Random using MCMCChains: Chains @@ -54,10 +51,9 @@ end function _gibbs_internal_mh( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Create gradient model on-the-fly using DifferentiationInterface + # Create gradient model on-the-fly + ad_model = BUGSModelWithGradient(cond_model, ad_backend) x = getparams(cond_model) - prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model)) - ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model) logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take MH step with gradient information @@ -131,31 +127,4 @@ function AbstractMCMC.bundle_samples( ) end -# Keep backward compatibility with LogDensityProblemsAD wrapper -function AbstractMCMC.bundle_samples( - ts::Vector{<:AdvancedMH.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, - sampler::AdvancedMH.MHSampler, - state, - chain_type::Type{Chains}; - discard_initial=0, - thinning=1, - kwargs..., -) - # Same extraction for gradient-based MH samplers - param_samples = [t.params for t in ts] - stats_names = [:lp] - stats_values = [[t.lp] for t in ts] - - return JuliaBUGS.gen_chains( - logdensitymodel, - param_samples, - stats_names, - stats_values; - discard_initial=discard_initial, - thinning=thinning, - kwargs..., - ) -end - end diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index a69007b74..224883c36 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -10,7 +10,6 @@ using JuliaBUGS: getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors -using JuliaBUGS.LogDensityProblemsAD using MCMCChains: Chains function JuliaBUGS.gen_chains( @@ -39,20 +38,6 @@ function JuliaBUGS.gen_chains( return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...) end -# Keep backward compatibility with LogDensityProblemsAD wrapper -function JuliaBUGS.gen_chains( - model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, - samples, - stats_names, - stats_values; - kwargs..., -) - # Extract BUGSModel from ADGradient wrapper - bugs_model = model.logdensity.ℓ - - return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...) -end - """ elementwise_varnames(vn::VarName, val) diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index 4ee74962d..2c840c424 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -332,17 +332,8 @@ end # Helper function to prepare gradient - separated to handle world age issues function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) - # Get initial parameters for preparation # Use invokelatest to handle world age issues with generated functions - x = Base.invokelatest(getparams, base_model) - - # Prepare gradient using DifferentiationInterface - # Use invokelatest to handle world age issues when calling logdensity during preparation - prep = Base.invokelatest( - DI.prepare_gradient, Model._logdensity_switched, adtype, x, DI.Constant(base_model) - ) - - return Model.BUGSModelWithGradient(adtype, prep, base_model) + return Base.invokelatest(Model.BUGSModelWithGradient, base_model, adtype) end # function compile( # model_str::String, diff --git a/JuliaBUGS/src/gibbs.jl b/JuliaBUGS/src/gibbs.jl index fa71d85b9..8a1c4e870 100644 --- a/JuliaBUGS/src/gibbs.jl +++ b/JuliaBUGS/src/gibbs.jl @@ -432,7 +432,7 @@ function AbstractMCMC.step( # For gradient-based samplers, wrap with AD _, ad_backend = sub_sampler logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + Model.BUGSModelWithGradient(cond_model, ad_backend) ) else # For non-gradient samplers, use model directly diff --git a/JuliaBUGS/src/model/Model.jl b/JuliaBUGS/src/model/Model.jl index 2efe7adb1..c59d123b5 100644 --- a/JuliaBUGS/src/model/Model.jl +++ b/JuliaBUGS/src/model/Model.jl @@ -20,9 +20,14 @@ include("evaluation.jl") include("abstractppl.jl") include("logdensityproblems.jl") -export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode -export regenerate_log_density_function, set_observed_values! -export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!! -export BUGSModelWithGradient, _logdensity_switched +# Public user-facing API +export parameters, variables, initialize!, getparams, settrans +export set_evaluation_mode, set_observed_values! + +# Evaluation mode types +export UseGraph, UseGeneratedLogDensityFunction + +# Gradient wrapper +export BUGSModelWithGradient end # Model diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 2b97c5c7d..30e195d82 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -26,14 +26,14 @@ function LogDensityProblems.capabilities(::AbstractBUGSModel) end """ - BUGSModelWithGradient{B,P,M} + BUGSModelWithGradient{AD,P,M} Wraps a BUGSModel with automatic differentiation capabilities using DifferentiationInterface. Implements both `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. # Fields -- `backend::B`: ADTypes backend (e.g., AutoReverseDiff()) -- `prep::P`: Prepared gradient from DifferentiationInterface (can be nothing) +- `adtype::AD`: ADTypes backend (e.g., AutoReverseDiff()) +- `prep::P`: Prepared gradient from DifferentiationInterface - `base_model::M`: The underlying BUGSModel # Example @@ -50,12 +50,59 @@ model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) chain = AbstractMCMC.sample(rng, model, NUTS(0.8), 1000) ``` """ -struct BUGSModelWithGradient{B<:ADTypes.AbstractADType,P,M<:BUGSModel} - backend::B +struct BUGSModelWithGradient{AD<:ADTypes.AbstractADType,P,M<:BUGSModel} + adtype::AD prep::P base_model::M end +""" + BUGSModelWithGradient(model::BUGSModel, adtype::ADTypes.AbstractADType) + +Construct a gradient-enabled model wrapper from a BUGSModel and an AD backend. + +# AD Backend Compatibility + +Different AD backends have different compatibility with evaluation modes: + +- **`UseGeneratedLogDensityFunction`**: Only compatible with mutation-supporting backends + like `AutoMooncake` and `AutoEnzyme`. The generated functions mutate arrays in-place. +- **`UseGraph`**: Compatible with `AutoReverseDiff`, `AutoForwardDiff`, and other + tape-based or forward-mode backends. Also works with Mooncake and Enzyme. + +If an incompatible combination is detected, a warning is issued and the model is +automatically switched to `UseGraph` mode. + +# Example +```julia +model = compile(model_def, data) +grad_model = BUGSModelWithGradient(model, AutoReverseDiff(compile=true)) +``` +""" +function BUGSModelWithGradient(model::BUGSModel, adtype::ADTypes.AbstractADType) + # Check AD backend compatibility with evaluation mode + model = _check_ad_compatibility(model, adtype) + + x = getparams(model) + prep = DI.prepare_gradient(_logdensity_for_gradient, adtype, x, DI.Constant(model)) + return BUGSModelWithGradient(adtype, prep, model) +end + +# AD backends that support mutation (required for UseGeneratedLogDensityFunction) +_supports_mutation(::ADTypes.AutoMooncake) = true +_supports_mutation(::ADTypes.AutoEnzyme) = true +_supports_mutation(::ADTypes.AbstractADType) = false + +function _check_ad_compatibility(model::BUGSModel, adtype::ADTypes.AbstractADType) + if model.evaluation_mode isa UseGeneratedLogDensityFunction && + !_supports_mutation(adtype) + @warn "AD backend $(typeof(adtype)) does not support mutation required by " * + "UseGeneratedLogDensityFunction mode. Switching to UseGraph mode." maxlog = 1 + return set_evaluation_mode(model, UseGraph()) + end + return model +end + # Forward base BUGSModel interface function LogDensityProblems.logdensity(model::BUGSModelWithGradient, x::AbstractVector) return LogDensityProblems.logdensity(model.base_model, x) @@ -70,43 +117,26 @@ function LogDensityProblems.capabilities(::Type{<:BUGSModelWithGradient}) end """ - _logdensity_switched(x, base_model_constant) + _logdensity_for_gradient(x, model_constant) -Helper function that switches argument order for DifferentiationInterface compatibility. -DI expects the active argument (to differentiate w.r.t.) to come first. +Target function for gradient computation via DifferentiationInterface. +The parameter vector `x` comes first (the argument to differentiate w.r.t.), +and the model is wrapped in `DI.Constant` to indicate it's not differentiated. """ -function _logdensity_switched(x::AbstractVector, base_model_constant::DI.Constant) - base_model = DI.unwrap(base_model_constant) - return LogDensityProblems.logdensity(base_model, x) -end - -# Fallback for testing during preparation (when DI calls without Constant wrapper) -function _logdensity_switched(x::AbstractVector, base_model::BUGSModel) - return LogDensityProblems.logdensity(base_model, x) +function _logdensity_for_gradient(x::AbstractVector, model_constant::DI.Constant) + model = DI.unwrap(model_constant) + return _eval_logdensity(model, model.evaluation_mode, x) end """ LogDensityProblems.logdensity_and_gradient(model::BUGSModelWithGradient, x) Compute log density and its gradient using DifferentiationInterface. -Uses prepared gradient if available, otherwise falls back to unprepared computation. """ function LogDensityProblems.logdensity_and_gradient( model::BUGSModelWithGradient, x::AbstractVector ) - # Active argument (x) comes first for DI - # Base model passed as Constant context - if model.prep === nothing - return DI.value_and_gradient( - _logdensity_switched, model.backend, x, DI.Constant(model.base_model) - ) - else - return DI.value_and_gradient( - _logdensity_switched, - model.prep, - model.backend, - x, - DI.Constant(model.base_model), - ) - end + return DI.value_and_gradient( + _logdensity_for_gradient, model.prep, model.adtype, x, DI.Constant(model.base_model) + ) end From 7d033d8427816828138df1f6bbf84f7e22221d8a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 07:32:23 +0000 Subject: [PATCH 10/17] Fix AD gradient function signature and add AD compatibility tests --- JuliaBUGS/src/model/logdensityproblems.jl | 7 +++---- JuliaBUGS/test/Project.toml | 10 +++++++--- JuliaBUGS/test/runtests.jl | 3 +++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 60d30b44b..da44db04d 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -148,14 +148,13 @@ function LogDensityProblems.capabilities(::Type{<:BUGSModelWithGradient}) end """ - _logdensity_for_gradient(x, model_constant) + _logdensity_for_gradient(x, model) Target function for gradient computation via DifferentiationInterface. The parameter vector `x` comes first (the argument to differentiate w.r.t.), -and the model is wrapped in `DI.Constant` to indicate it's not differentiated. +and the model is passed as a constant context (not differentiated). """ -function _logdensity_for_gradient(x::AbstractVector, model_constant::DI.Constant) - model = DI.unwrap(model_constant) +function _logdensity_for_gradient(x::AbstractVector, model::BUGSModel) return _eval_logdensity(model, model.evaluation_mode, x) end diff --git a/JuliaBUGS/test/Project.toml b/JuliaBUGS/test/Project.toml index 056862146..f03a6ce3b 100644 --- a/JuliaBUGS/test/Project.toml +++ b/JuliaBUGS/test/Project.toml @@ -8,6 +8,8 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -32,13 +34,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ADTypes = "1.14.0" AbstractMCMC = "5" -AbstractPPL = "0.8.4, 0.9, 0.10, 0.11" -AdvancedHMC = "0.6, 0.7" +AbstractPPL = "0.8.4, 0.9, 0.10, 0.11, 0.12, 0.13" +AdvancedHMC = "0.6, 0.7, 0.8" AdvancedMH = "0.8" BangBang = "0.4.1" ChainRules = "1" DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" +ForwardDiff = "1" +Mooncake = "0.4" Documenter = "0.27, 1" Graphs = "1" JuliaSyntax = "1" @@ -46,7 +50,7 @@ LinearAlgebra = "1.10" LogDensityProblems = "2" LogDensityProblemsAD = "1" LogExpFunctions = "0.3" -MCMCChains = "6" +MCMCChains = "6, 7" MacroTools = "0.5" MetaGraphsNext = "0.6, 0.7" OrderedCollections = "1" diff --git a/JuliaBUGS/test/runtests.jl b/JuliaBUGS/test/runtests.jl index 98a021394..7bbb1d102 100644 --- a/JuliaBUGS/test/runtests.jl +++ b/JuliaBUGS/test/runtests.jl @@ -40,6 +40,8 @@ using AdvancedHMC using AdvancedMH using MCMCChains using ReverseDiff +using ForwardDiff +using Mooncake JuliaBUGS.@bugs_primitive Beta Bernoulli Categorical Exponential Gamma InverseGamma Normal Uniform LogNormal Poisson JuliaBUGS.@bugs_primitive Diagonal Dirichlet LKJ MvNormal @@ -96,6 +98,7 @@ const TEST_GROUPS = OrderedDict{String,Function}( "inference_mh" => () -> include("independent_mh.jl"), "gibbs" => () -> include("gibbs.jl"), "parallel_sampling" => () -> include("parallel_sampling.jl"), + "ad_compatibility" => () -> include("ad_compatibility.jl"), "experimental" => () -> include("experimental/ProbabilisticGraphicalModels/runtests.jl"), ) From 94fb6c5853ebdffd0a27739c6dd987b2100b470c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 07:36:16 +0000 Subject: [PATCH 11/17] format --- JuliaBUGS/test/ad_compatibility.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/JuliaBUGS/test/ad_compatibility.jl b/JuliaBUGS/test/ad_compatibility.jl index 822ec1261..c8ffe6d9c 100644 --- a/JuliaBUGS/test/ad_compatibility.jl +++ b/JuliaBUGS/test/ad_compatibility.jl @@ -50,7 +50,9 @@ @testset "UseGeneratedLogDensityFunction mode" begin model = compile(model_def, data) - model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction()) + model = JuliaBUGS.set_evaluation_mode( + model, JuliaBUGS.UseGeneratedLogDensityFunction() + ) @test model.evaluation_mode isa JuliaBUGS.UseGeneratedLogDensityFunction x = JuliaBUGS.getparams(model) @@ -78,8 +80,11 @@ end @testset "AutoMooncake - should work without warning" begin - grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoMooncake(; config=nothing)) - @test grad_model.base_model.evaluation_mode isa JuliaBUGS.UseGeneratedLogDensityFunction + grad_model = JuliaBUGS.BUGSModelWithGradient( + model, AutoMooncake(; config=nothing) + ) + @test grad_model.base_model.evaluation_mode isa + JuliaBUGS.UseGeneratedLogDensityFunction logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) @test isfinite(logp) From 8aee4d23b7345f9cb3a9ef25748c48e7ee31e16b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 07:52:11 +0000 Subject: [PATCH 12/17] remove redundant lines in benchmark code --- JuliaBUGS/benchmark/benchmark.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/JuliaBUGS/benchmark/benchmark.jl b/JuliaBUGS/benchmark/benchmark.jl index 0b5936624..25a558a0f 100644 --- a/JuliaBUGS/benchmark/benchmark.jl +++ b/JuliaBUGS/benchmark/benchmark.jl @@ -85,11 +85,6 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult}) ), ) end - DataFrames.rename!( - df, - :Density_Time => "Density Time (µs)", - :Density_Gradient_Time => "Density+Gradient Time (µs)", - ) return df end From cd1df5ed79c126786b7cf3e15c50beb9e640df7d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 08:38:07 +0000 Subject: [PATCH 13/17] various cleaning up --- .gitignore | 4 +- JuliaBUGS/History.md | 21 +++---- JuliaBUGS/src/JuliaBUGS.jl | 58 +++---------------- JuliaBUGS/src/model/logdensityproblems.jl | 23 ++------ JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 52 ----------------- JuliaBUGS/test/model/bugsmodel.jl | 41 ------------- 6 files changed, 22 insertions(+), 177 deletions(-) diff --git a/.gitignore b/.gitignore index e5c898ac2..779fd4d05 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,4 @@ Manifest.toml *.local.* # gitingest generated files -digest.txt - -tmp/ \ No newline at end of file +digest.txt \ No newline at end of file diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index 24fb9c5f9..0e43ea767 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -1,18 +1,13 @@ # JuliaBUGS Changelog -## 0.10.4 - -- **DifferentiationInterface.jl integration**: JuliaBUGS now uses [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation, providing a unified interface to multiple AD backends. - - Add `adtype` parameter to `compile()` function for specifying AD backends via [ADTypes.jl](https://github.com/SciML/ADTypes.jl) - - Supports multiple backends: `AutoReverseDiff`, `AutoForwardDiff`, `AutoZygote`, `AutoEnzyme`, `AutoMooncake` - - Gradient computation is prepared during compilation for optimal performance - - Example: `model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))` - - New `BUGSModelWithGradient(model, adtype)` constructor for adding gradients to existing models - - Models without `adtype` work as before (no gradients) - -- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported for gradient computation. - - **Old**: `ad_model = ADgradient(:ReverseDiff, model)` - - **New**: `model = compile(model_def, data; adtype=AutoReverseDiff())` or `BUGSModelWithGradient(model, AutoReverseDiff())` +## 0.12 + +- **DifferentiationInterface.jl integration**: Use `adtype` parameter in `compile()` to enable gradient-based inference via [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + - Example: `model = compile(model_def, data; adtype=AutoReverseDiff())` + - Supports `AutoReverseDiff`, `AutoForwardDiff`, `AutoMooncake` + +- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported. + - Use `compile(...; adtype=...)` or `BUGSModelWithGradient(model, adtype)` instead. ## 0.10.1 diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index 704800dc7..5e279e4e4 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -236,48 +236,20 @@ function validate_bugs_expression(expr, line_num) end """ - compile(model_def, data[, initial_params]; skip_validation=false, adtype=nothing) + compile(model_def, data[, initial_params]; adtype=nothing) -Compile the model with model definition and data. Optionally, initializations can be provided. -If initializations are not provided, values will be sampled from the prior distributions. - -By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro). -Set `skip_validation=true` to skip validation (for @model macro usage). - -The compiled model uses `UseGraph` evaluation mode by default. To use the optimized generated -log-density function, call `set_evaluation_mode(model, UseGeneratedLogDensityFunction())`. - -If `adtype` is provided, returns a `BUGSModelWithGradient` that supports gradient-based MCMC -samplers like HMC/NUTS. The gradient computation is prepared during compilation for optimal performance. +Compile a BUGS model. Returns `BUGSModel`, or `BUGSModelWithGradient` if `adtype` is provided. # Arguments -- `model_def::Expr`: Model definition from @bugs macro +- `model_def::Expr`: Model definition from `@bugs` macro - `data::NamedTuple`: Observed data -- `initial_params::NamedTuple=NamedTuple()`: Initial parameter values (optional) -- `skip_validation::Bool=false`: Skip function validation (for @model macro) -- `eval_module::Module=@__MODULE__`: Module for evaluation -- `adtype`: AD backend specification using ADTypes. Examples: - - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (fastest) - - `AutoReverseDiff(compile=false)` - ReverseDiff without compilation - - `AutoForwardDiff()` - ForwardDiff backend - - `AutoZygote()` - Zygote backend - - `AutoEnzyme()` - Enzyme backend - - `AutoMooncake()` - Mooncake backend - - Any other `ADTypes.AbstractADType` +- `initial_params::NamedTuple`: Initial parameter values (optional, defaults to prior samples) +- `adtype`: AD backend from ADTypes.jl (e.g., `AutoReverseDiff()`, `AutoForwardDiff()`, `AutoMooncake()`) # Examples ```julia -# Basic compilation model = compile(model_def, data) - -# With gradient support using ReverseDiff (recommended for most models) -model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - -# Using ForwardDiff for small models -model = compile(model_def, data; adtype=AutoForwardDiff()) - -# Sample with NUTS -chain = AbstractMCMC.sample(model, NUTS(0.8), 1000) +model = compile(model_def, data; adtype=AutoReverseDiff()) ``` """ function compile( @@ -319,28 +291,12 @@ function compile( # If adtype provided, wrap with gradient capabilities if adtype !== nothing - return _wrap_with_gradient(base_model, adtype) + return Base.invokelatest(Model.BUGSModelWithGradient, base_model, adtype) end return base_model end -# Helper function to prepare gradient - separated to handle world age issues -function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.AbstractADType) - # Use invokelatest to handle world age issues with generated functions - return Base.invokelatest(Model.BUGSModelWithGradient, base_model, adtype) -end -# function compile( -# model_str::String, -# data::NamedTuple, -# initial_params::NamedTuple=NamedTuple(); -# replace_period::Bool=true, -# no_enclosure::Bool=false, -# ) -# model_def = _bugs_string_input(model_str, replace_period, no_enclosure) -# return compile(model_def, data, initial_params) -# end - """ register_bugs_function(func_name::Symbol) diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index da44db04d..a8eb9e45c 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -59,27 +59,16 @@ end """ BUGSModelWithGradient{AD,P,M} -Wraps a BUGSModel with automatic differentiation capabilities using DifferentiationInterface. -Implements both `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. +Wrap a `BUGSModel` with AD capabilities for gradient-based inference. + +Implements `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. # Fields -- `adtype::AD`: ADTypes backend (e.g., AutoReverseDiff()) +- `adtype::AD`: AD backend (e.g., `AutoReverseDiff()`) - `prep::P`: Prepared gradient from DifferentiationInterface -- `base_model::M`: The underlying BUGSModel +- `base_model::M`: The underlying `BUGSModel` -# Example -```julia -model_def = @bugs begin - x ~ dnorm(0, 1) -end -data = NamedTuple() - -# Create model with gradient capabilities -model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - -# Use with gradient-based MCMC -chain = AbstractMCMC.sample(rng, model, NUTS(0.8), 1000) -``` +See also [`compile`](@ref). """ struct BUGSModelWithGradient{AD<:ADTypes.AbstractADType,P,M<:BUGSModel} adtype::AD diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index f140065ba..61d2c39d3 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -27,58 +27,6 @@ [Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y] end - @testset "AD backend sampling" begin - model_def = @bugs begin - mu ~ dnorm(0, 1) - for i in 1:N - y[i] ~ dnorm(mu, 1) - end - end - data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) - - # Test that ReverseDiff backend works - ad_model_compiled = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) - ad_model_nocompile = compile( - model_def, data; adtype=AutoReverseDiff(; compile=false) - ) - - @test ad_model_compiled isa JuliaBUGS.Model.BUGSModelWithGradient - @test ad_model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test that both produce equivalent results - n_samples, n_adapts = 100, 100 - D = LogDensityProblems.dimension(ad_model_compiled) - initial_θ = rand(StableRNG(123), D) - - samples_compiled = AbstractMCMC.sample( - StableRNG(1234), - ad_model_compiled, - NUTS(0.8), - n_samples; - progress=false, - chain_type=Chains, - n_adapts=n_adapts, - init_params=initial_θ, - discard_initial=n_adapts, - ) - - samples_nocompile = AbstractMCMC.sample( - StableRNG(1234), - ad_model_nocompile, - NUTS(0.8), - n_samples; - progress=false, - chain_type=Chains, - n_adapts=n_adapts, - init_params=initial_θ, - discard_initial=n_adapts, - ) - - # Results should be very similar (same RNG seed) - @test summarize(samples_compiled)[:mu].nt.mean[1] ≈ - summarize(samples_nocompile)[:mu].nt.mean[1] rtol = 0.1 - end - @testset "Inference results on examples: $example" for example in [:seeds, :rats, :stacks] (; model_def, data, inits, reference_results) = Base.getfield( diff --git a/JuliaBUGS/test/model/bugsmodel.jl b/JuliaBUGS/test/model/bugsmodel.jl index 5abacc34f..9173ed822 100644 --- a/JuliaBUGS/test/model/bugsmodel.jl +++ b/JuliaBUGS/test/model/bugsmodel.jl @@ -385,45 +385,4 @@ end @test occursin("Variable sizes and types:", output) end end - - @testset "AD Type Parameter" begin - model_def = @bugs begin - mu ~ dnorm(0, 1) - y ~ dnorm(mu, 1) - end - data = (y=1.5,) - - @testset "ADTypes backends" begin - # Test with compile=true - model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) - @test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient - - # Test with compile=false - model_nocompile = compile( - model_def, data; adtype=AutoReverseDiff(; compile=false) - ) - @test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient - end - - @testset "Default behavior (no adtype)" begin - # Without adtype, should return regular BUGSModel - model_default = compile(model_def, data) - @test model_default isa BUGSModel - @test !(model_default isa JuliaBUGS.Model.BUGSModelWithGradient) - end - - @testset "Gradient computation" begin - model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) - test_point = [0.0] - - # Test that gradient can be computed - ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point) - - @test ℓ isa Real - @test grad isa Vector - @test length(grad) == 1 - @test isfinite(ℓ) - @test all(isfinite, grad) - end - end end From a501476abbe44a6928b4a6624ae9928f8f6d4289 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 08:59:15 +0000 Subject: [PATCH 14/17] Update examples to use new BUGSModelWithGradient API --- JuliaBUGS/examples/Project.toml | 5 +- JuliaBUGS/examples/bnn.jl | 16 +++--- JuliaBUGS/examples/gp.jl | 95 +++------------------------------ JuliaBUGS/examples/sir.jl | 8 +-- 4 files changed, 23 insertions(+), 101 deletions(-) diff --git a/JuliaBUGS/examples/Project.toml b/JuliaBUGS/examples/Project.toml index 401edf467..8a043e94c 100644 --- a/JuliaBUGS/examples/Project.toml +++ b/JuliaBUGS/examples/Project.toml @@ -4,17 +4,18 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[sources] +JuliaBUGS = {path = ".."} diff --git a/JuliaBUGS/examples/bnn.jl b/JuliaBUGS/examples/bnn.jl index f7666904f..afb21e0a8 100644 --- a/JuliaBUGS/examples/bnn.jl +++ b/JuliaBUGS/examples/bnn.jl @@ -1,17 +1,16 @@ using JuliaBUGS +using Distributions: Bernoulli, MvNormal using AbstractMCMC using ADTypes using AdvancedHMC -using DifferentiationInterface using FillArrays +using ForwardDiff using Functors using LinearAlgebra using LogDensityProblems -using LogDensityProblemsAD using Lux using MCMCChains -using Mooncake using Random ## data simulation @@ -84,7 +83,7 @@ function make_prediction(parameters, xs; ps=ps, nn=nn) return Lux.apply(nn, f32(xs), f32(vector_to_parameters(parameters, ps))) end -JuliaBUGS.@bugs_primitive parameter_distribution make_prediction +JuliaBUGS.@bugs_primitive parameter_distribution make_prediction Bernoulli @eval JuliaBUGS begin ps = Main.ps @@ -96,16 +95,17 @@ end data = (nparameters=Lux.parameterlength(nn), xs=xs_hcat, ts=ts, N=length(ts), sigma=sigma) +# Use ForwardDiff with UseGraph mode (required for user-defined primitives) model = compile(model_def, data) - -ad_model = ADgradient(AutoMooncake(; config=Mooncake.Config()), model) +model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) +model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # sampling is slow, so sample 10 of them to verify that this can work samples_and_stats = AbstractMCMC.sample( - ad_model, + model, NUTS(0.65), 10; chain_type=Chains, - # n_adapts=1000, + # n_adapts=1000, # discard_initial=1000 ) diff --git a/JuliaBUGS/examples/gp.jl b/JuliaBUGS/examples/gp.jl index fd8188863..562937e61 100644 --- a/JuliaBUGS/examples/gp.jl +++ b/JuliaBUGS/examples/gp.jl @@ -7,14 +7,11 @@ using JuliaBUGS using JuliaBUGS: @model # Required packages for GP modeling and MCMC -using AbstractGPs, Distributions, LogExpFunctions -using LogDensityProblems, LogDensityProblemsAD +using AbstractGPs, Distributions, LogExpFunctions, ForwardDiff +using LogDensityProblems +using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains -# Differentiation backend -using DifferentiationInterface -using Mooncake: Mooncake - # --- Data Definition --- # Golf putting data from Gelman et al. (BDA3, Chapter 5) @@ -120,94 +117,18 @@ model = gp_golf_putting( data.jitter, # Numerical stability term ) -# Generate the log density function for optimal performance -model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction()) - -# --- MCMC Setup with Custom LogDensityProblems Wrapper --- - -# We need a wrapper around the JuliaBUGS model to interface with LogDensityProblems -# and utilize automatic differentiation (AD) via Mooncake.jl for gradient computation, -# which is required by AdvancedHMC. - -struct BUGSMooncakeModel{T,P} - model::T # The JuliaBUGS model - prep::P # Pre-allocated workspace for gradient computation using Mooncake -end - -# Define the function to compute the log density using the JuliaBUGS model's internal function -f(x) = model.log_density_computation_function(model.evaluation_env, x) - -# Prepare the differentiation backend (Mooncake) -backend = AutoMooncake(; config=nothing) -x_init = rand(LogDensityProblems.dimension(model)) # Initial point for testing/preparation -prep = prepare_gradient(f, backend, x_init) - -# Create the wrapped model instance -bugsmooncake = BUGSMooncakeModel(model, prep) - -# --- LogDensityProblems Interface Implementation for the Wrapper --- - -# Define logdensity function for the wrapper -function LogDensityProblems.logdensity(model::BUGSMooncakeModel, x::AbstractVector) - return f(x) # Calls the underlying JuliaBUGS log density function -end - -# Define logdensity_and_gradient function using the prepared DifferentiationInterface setup -function LogDensityProblems.logdensity_and_gradient( - model::BUGSMooncakeModel, x::AbstractVector -) - # Computes both the log density and its gradient using Mooncake AD - return DifferentiationInterface.value_and_gradient( - f, model.prep, AutoMooncake(; config=nothing), x - ) -end - -# Define dimension function -function LogDensityProblems.dimension(model::BUGSMooncakeModel) - return LogDensityProblems.dimension(model.model) # Delegates to the original model -end - -# Define a custom bundle_samples function to convert the AdvancedHMC.Transition to a Chains object -function AbstractMCMC.bundle_samples( - ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSMooncakeModel}, - sampler::AdvancedHMC.AbstractHMCSampler, - state, - chain_type::Type{Chains}; - discard_initial=0, - thinning=1, - kwargs..., -) - stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) - stats_values = [ - vcat([ts[i].z.ℓπ.value..., collect(values(AdvancedHMC.stat(ts[i])))...]) for - i in eachindex(ts) - ] - - return JuliaBUGS.gen_chains( - logdensitymodel.logdensity.model, - [t.z.θ for t in ts], - stats_names, - stats_values; - discard_initial=discard_initial, - thinning=thinning, - kwargs..., - ) -end - -# Specify capabilities (indicates gradient availability) -function LogDensityProblems.capabilities(::Type{<:BUGSMooncakeModel}) - return LogDensityProblems.LogDensityOrder{1}() # Can compute up to the gradient -end +# Use graph evaluation mode with ForwardDiff AD (required for user-defined primitives) +model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) +grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # --- MCMC Sampling --- # Sample from the posterior distribution using AdvancedHMC's NUTS sampler samples_and_stats = AbstractMCMC.sample( - AbstractMCMC.LogDensityModel(bugsmooncake), # Wrap the model for AbstractMCMC + grad_model, AdvancedHMC.NUTS(0.65), # No-U-Turn Sampler 1000; # Total number of samples chain_type=Chains, # Store results as MCMCChains object n_adapts=500, # Number of adaptation steps for NUTS - discard_initial=500, # Number of initial samples (warmup) to discard; + discard_initial=500, # Number of initial samples (warmup) to discard ) diff --git a/JuliaBUGS/examples/sir.jl b/JuliaBUGS/examples/sir.jl index ccd4f3279..5e915e5b3 100644 --- a/JuliaBUGS/examples/sir.jl +++ b/JuliaBUGS/examples/sir.jl @@ -6,7 +6,7 @@ using JuliaBUGS using JuliaBUGS: @model using Distributions using DifferentialEquations -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains using Distributed # For distributed example @@ -113,8 +113,8 @@ model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) # --- MCMC Sampling: NUTS with ForwardDiff AD --- -# Create an AD-aware wrapper for the model using ForwardDiff for gradients -ad_model_forwarddiff = ADgradient(AutoForwardDiff(), model) +# Create gradient-enabled model using ForwardDiff +grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # MCMC settings n_samples = 1000 @@ -122,7 +122,7 @@ n_adapts = 500 # Run the NUTS sampler samples_nuts_fwd = AbstractMCMC.sample( - ad_model_forwarddiff, + grad_model, AdvancedHMC.NUTS(0.65), # No-U-Turn Sampler with step size adaptation target n_samples; chain_type=Chains, # Store results as MCMCChains object From e059ad80f12ba0e9656fcd5692208f23fee2f5f5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 09:30:54 +0000 Subject: [PATCH 15/17] update example doc --- JuliaBUGS/docs/Project.toml | 9 ++ JuliaBUGS/docs/src/example.md | 198 +++++++++------------------------- 2 files changed, 58 insertions(+), 149 deletions(-) diff --git a/JuliaBUGS/docs/Project.toml b/JuliaBUGS/docs/Project.toml index a8e5b92ac..ca5e166dc 100644 --- a/JuliaBUGS/docs/Project.toml +++ b/JuliaBUGS/docs/Project.toml @@ -1,7 +1,16 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[sources] +JuliaBUGS = {path = ".."} [compat] Documenter = "1.14" diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 57ea54466..771d4da9f 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -2,6 +2,7 @@ ```@setup abc using JuliaBUGS +using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ADTypes, ReverseDiff data = ( r = [10, 23, 23, 26, 17, 5, 53, 55, 32, 46, 10, 8, 10, 8, 23, 0, 3, 22, 15, 32, 3], @@ -207,9 +208,7 @@ model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) Available AD backends include: - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models) - `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters) -- `AutoZygote()` - Zygote (source-to-source AD) -- `AutoEnzyme()` - Enzyme (experimental, high-performance) -- `AutoMooncake()` - Mooncake (high-performance reverse-mode AD) +- `AutoMooncake()` - Mooncake (requires `UseGeneratedLogDensityFunction()` mode) For fine-grained control, you can configure the AD backend: @@ -224,9 +223,7 @@ The compiled model with gradient support implements the [`LogDensityProblems.jl` For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`: -```julia -using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ReverseDiff - +```@example abc # Compile with gradient support model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) @@ -243,86 +240,56 @@ samples_and_stats = AbstractMCMC.sample( init_params = initial_θ, discard_initial = n_adapts ) -describe(samples_and_stats) +samples_and_stats ``` -This will return the MCMC Chain, - -```plaintext -Chains MCMC chain (2000×40×1 Array{Real, 3}): - -Iterations = 1001:1:3000 -Number of chains = 1 -Samples per chain = 2000 -parameters = tau, alpha12, alpha2, alpha1, alpha0, b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19], b[20], b[21], sigma -internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt - -Summary Statistics - parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec - Symbol Float64 Float64 Float64 Real Float64 Float64 Missing - - tau 73.1490 193.8441 43.2582 56.3430 20.6688 1.0155 missing - alpha12 -0.8052 0.4392 0.0158 761.2180 1049.1664 1.0020 missing - alpha2 1.3428 0.2813 0.0140 422.8810 1013.2570 1.0061 missing - alpha1 0.0845 0.3126 0.0113 773.2202 981.8487 1.0051 missing - alpha0 -0.5480 0.1944 0.0087 537.6212 1156.2083 1.0014 missing - b[1] -0.1905 0.2540 0.0129 374.3372 971.7526 1.0034 missing - b[2] 0.0161 0.2178 0.0056 1505.6353 1002.8787 1.0001 missing - b[3] -0.1986 0.2375 0.0128 367.6766 1287.8215 1.0015 missing - b[4] 0.2792 0.2498 0.0163 201.1558 1168.7538 1.0068 missing - b[5] 0.1170 0.2397 0.0092 659.5422 1484.8584 1.0016 missing - b[6] 0.0667 0.2821 0.0074 1745.5567 902.1014 1.0067 missing - b[7] 0.0597 0.2218 0.0055 1589.5590 1145.6017 1.0065 missing - b[8] 0.1769 0.2316 0.0102 554.5974 1318.8089 1.0001 missing - b[9] -0.1257 0.2233 0.0073 930.0346 1186.4283 1.0031 missing - b[10] -0.2513 0.2392 0.0159 213.6323 1142.4487 1.0096 missing - b[11] 0.0768 0.2783 0.0081 1376.5999 1218.1537 1.0009 missing - b[12] 0.1171 0.2768 0.0079 1354.9409 1130.8217 1.0052 missing - b[13] -0.0688 0.2433 0.0055 1895.0387 1527.7066 1.0010 missing - b[14] -0.1363 0.2558 0.0075 1276.0992 1208.8587 1.0001 missing - b[15] 0.2334 0.2757 0.0135 439.2241 837.3396 1.0036 missing - b[16] -0.1212 0.3024 0.0106 1093.4416 914.9457 0.9997 missing - b[17] -0.2120 0.3142 0.0166 360.6420 702.4098 1.0009 missing - b[18] 0.0346 0.2282 0.0056 1665.0325 1281.7179 1.0011 missing - b[19] -0.0244 0.2400 0.0052 2186.7638 1179.6971 1.0132 missing - b[20] 0.2108 0.2421 0.0131 349.7657 1263.5781 1.0016 missing - b[21] -0.0509 0.2813 0.0061 2200.5614 916.6256 0.9998 missing - sigma 0.2797 0.1362 0.0168 56.3430 21.4971 1.0123 missing - -Quantiles - parameters 2.5% 25.0% 50.0% 75.0% 97.5% - Symbol Float64 Float64 Float64 Float64 Float64 - - tau 3.1280 7.4608 13.0338 28.2289 929.6520 - alpha12 -1.6645 -1.0887 -0.7952 -0.5635 0.1162 - alpha2 0.8398 1.1494 1.3233 1.5337 1.9177 - alpha1 -0.5796 -0.1059 0.1042 0.2883 0.6702 - alpha0 -0.9340 -0.6751 -0.5463 -0.4086 -0.1752 - b[1] -0.7430 -0.3415 -0.1566 -0.0074 0.2535 - b[2] -0.4261 -0.1083 0.0192 0.1420 0.4810 - b[3] -0.7394 -0.3377 -0.1687 -0.0242 0.2041 - b[4] -0.1108 0.0873 0.2409 0.4375 0.8267 - b[5] -0.3141 -0.0458 0.0900 0.2563 0.6489 - b[6] -0.4679 -0.0896 0.0291 0.2202 0.7060 - b[7] -0.3861 -0.0685 0.0534 0.1847 0.5207 - b[8] -0.2326 0.0221 0.1505 0.3162 0.6861 - b[9] -0.6007 -0.2482 -0.0984 0.0057 0.2771 - b[10] -0.7936 -0.4108 -0.2255 -0.0617 0.1290 - b[11] -0.4381 -0.0796 0.0353 0.2178 0.7232 - b[12] -0.3806 -0.0451 0.0750 0.2671 0.7625 - b[13] -0.5841 -0.2135 -0.0443 0.0652 0.4055 - b[14] -0.6854 -0.2872 -0.1015 0.0147 0.3476 - b[15] -0.2054 0.0257 0.1898 0.4004 0.8660 - b[16] -0.8173 -0.2829 -0.0804 0.0532 0.4094 - b[17] -0.9071 -0.3911 -0.1595 0.0099 0.2864 - b[18] -0.4526 -0.0919 0.0140 0.1686 0.4985 - b[19] -0.5055 -0.1547 -0.0091 0.1134 0.4528 - b[20] -0.2120 0.0318 0.1788 0.3673 0.7416 - b[21] -0.6482 -0.2044 -0.0263 0.1051 0.5246 - sigma 0.0328 0.1882 0.2770 0.3661 0.5654 +This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). + +## Evaluation Modes and Automatic Differentiation + +JuliaBUGS supports multiple evaluation modes and AD backends. The evaluation mode determines how the log density is computed, and constrains which AD backends can be used. + +### Evaluation Modes + +| Mode | AD Backends | +|------|-------------| +| `UseGraph()` (default) | ReverseDiff, ForwardDiff | +| `UseGeneratedLogDensityFunction()` | Mooncake | + +- **`UseGraph()`**: Evaluates by traversing the computational graph. Supports user-defined primitives registered via `@bugs_primitive`. +- **`UseGeneratedLogDensityFunction()`**: Generates and compiles a Julia function for the log density. + +### AD Backends with `UseGraph()` Mode + +Use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) or [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with the default `UseGraph()` mode: + +```julia +using ADTypes + +# ReverseDiff with tape compilation (recommended for large models) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# ForwardDiff (efficient for small models with < 20 parameters) +model = compile(model_def, data; adtype=AutoForwardDiff()) + +# ReverseDiff without compilation (supports control flow) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) ``` -This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). +!!! warning "Compiled ReverseDiff does not support control flow" + Compiled tapes record a fixed execution path. If your model contains value-dependent control flow (e.g., `if x > 0`, `while`, truncation), the tape will only capture one branch and produce **incorrect gradients** when the control flow takes a different path. Use `AutoReverseDiff(compile=false)` or `AutoForwardDiff()` for models with control flow. + +### AD Backend with `UseGeneratedLogDensityFunction()` Mode + +Use [Mooncake.jl](https://github.com/compintell/Mooncake.jl) with the generated log density function mode: + +```julia +using ADTypes + +model = compile(model_def, data) +model = set_evaluation_mode(model, UseGeneratedLogDensityFunction()) +model = BUGSModelWithGradient(model, AutoMooncake(; config=nothing)) +``` ## Parallel and Distributed Sampling with `AbstractMCMC` @@ -395,73 +362,6 @@ In this case, we pass two additional arguments to `AbstractMCMC.sample`: Note that the `init_params` argument is now a vector of initial parameters for each chain. Sometimes the progress logger can cause problems in distributed setting, so we can disable it by setting `progress = false`. -## Choosing an Automatic Differentiation Backend - -JuliaBUGS integrates with multiple automatic differentiation (AD) backends through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), providing flexibility to choose the most suitable backend for your model. - -### Available Backends - -The following AD backends are supported via [ADTypes.jl](https://github.com/SciML/ADTypes.jl): - -- **`AutoReverseDiff(compile=true)`** (recommended) — Tape-based reverse-mode AD, highly efficient for models with many parameters. Uses compilation by default for optimal performance. -- **`AutoForwardDiff()`** — Forward-mode AD, efficient for models with few parameters (typically < 20). -- **`AutoZygote()`** — Source-to-source reverse-mode AD, general-purpose but may be slower than ReverseDiff for many models. -- **`AutoEnzyme()`** — Experimental high-performance AD backend with LLVM-level transformations. -- **`AutoMooncake()`** — High-performance reverse-mode AD with advanced optimizations. - -### Usage Examples - -#### Basic Usage - -Specify an AD backend using ADTypes: - -```julia -using ADTypes - -# ReverseDiff with tape compilation (recommended for most models) -model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) - -# ForwardDiff (good for small models with few parameters) -model = compile(model_def, data; adtype=AutoForwardDiff()) - -# Zygote (source-to-source AD) -model = compile(model_def, data; adtype=AutoZygote()) -``` - -#### Advanced Configuration - -For fine-grained control, you can configure the AD backends: - -```julia -using ADTypes - -# ReverseDiff without tape compilation -model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) - -# ReverseDiff with compilation (default, recommended) -model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) -``` - -### Performance Considerations - -- **ReverseDiff with compilation** (`AutoReverseDiff(compile=true)`) is recommended for most models, especially those with many parameters. Compilation adds a one-time overhead but significantly speeds up subsequent gradient evaluations. - -- **ForwardDiff** (`AutoForwardDiff()`) is often faster for models with few parameters (< 20), as it avoids tape construction overhead. - -- **Tape compilation trade-off**: While `AutoReverseDiff(compile=true)` has higher initial compilation time, it provides faster gradient evaluations during sampling. For quick prototyping or models that will only be sampled a few times, `AutoReverseDiff(compile=false)` may be preferable. - -!!! warning "Compiled tapes and control flow" - Compiled ReverseDiff tapes cannot handle value-dependent control flow (e.g., `if x[1] > 0`). If your model has such control flow, use `AutoReverseDiff(compile=false)` or a different backend like `AutoForwardDiff()` or `AutoMooncake()`. See the [ReverseDiff documentation](https://juliadiff.org/ReverseDiff.jl/stable/api/#The-AbstractTape-API) for details. - -### Compatibility - -All models compiled with an `adtype` implement the full [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, making them compatible with: - -- [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) — NUTS and HMC samplers -- Any other sampler that works with the LogDensityProblems interface - -The gradient computation is prepared during model compilation for optimal performance during sampling. - ## More Examples We have transcribed all the examples from the first volume of the BUGS Examples ([original](https://www.multibugs.org/examples/latest/VolumeI.html) and [transcribed](https://github.com/TuringLang/JuliaBUGS.jl/tree/main/JuliaBUGS/src/BUGSExamples/Volume_1)). All programs and data are included, and can be compiled using the steps described in the tutorial above. From 4509463a8432990cadd677e69b5adbe3db146c32 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 09:46:37 +0000 Subject: [PATCH 16/17] fix hmc test error --- JuliaBUGS/docs/src/example.md | 9 ++++++++- JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index 771d4da9f..da6a3147a 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -205,6 +205,13 @@ using ADTypes model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) ``` +Alternatively, if you already have a compiled `BUGSModel`, you can wrap it with `BUGSModelWithGradient` without recompiling: + +```julia +base_model = compile(model_def, data) +model = BUGSModelWithGradient(base_model, AutoReverseDiff(compile=true)) +``` + Available AD backends include: - `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models) - `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters) @@ -240,7 +247,7 @@ samples_and_stats = AbstractMCMC.sample( init_params = initial_θ, discard_initial = n_adapts ) -samples_and_stats +describe(samples_and_stats) ``` This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 61d2c39d3..31cedc9de 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -39,7 +39,7 @@ n_samples, n_adapts = 1000, 1000 D = LogDensityProblems.dimension(ad_model) - initial_θ = JuliaBUGS.getparams(ad_model.base_model) + initial_θ = Base.invokelatest(JuliaBUGS.getparams, ad_model.base_model) samples_and_stats = Base.invokelatest( AbstractMCMC.sample, From d6fc462b0f38b8e39f5aba31e14befedb6d0e309 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 09:55:41 +0000 Subject: [PATCH 17/17] disable progress bar example doc --- JuliaBUGS/docs/src/example.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index da6a3147a..e8a162f23 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -245,7 +245,8 @@ samples_and_stats = AbstractMCMC.sample( chain_type = Chains, n_adapts = n_adapts, init_params = initial_θ, - discard_initial = n_adapts + discard_initial = n_adapts, + progress = false ) describe(samples_and_stats) ```