-
Notifications
You must be signed in to change notification settings - Fork 162
Closed
Description
Found during work on #334.
# ------------ Static DSL ------------ #
# `bang`: samples the choice :z from normal(x + y, 3.0).
# Both arguments carry the `(grad)` annotation so argument gradients are tracked.
@gen (static) function bang((grad)(x::Float64), (grad)(y::Float64))
    noise_std::Float64 = 3.0
    # `z ~ dist(...)` is sugar for `z = @trace(dist(...), :z)`; the address stays :z.
    z ~ normal(x + y, noise_std)
    return z
end
# `fuzz`: samples the choice :z from normal(x + 2y, 3.0).
# It mirrors `bang` but weights `y` by 2; both arguments are gradient-annotated.
@gen (static) function fuzz((grad)(x::Float64), (grad)(y::Float64))
    noise_std::Float64 = 3.0
    # Tilde form of `@trace(normal(...), :z)`; the address stays :z.
    z ~ normal(x + 2 * y, noise_std)
    return z
end
# Switch combinator over the two branches; presumably an integer index selects
# `bang` (1) or `fuzz` (2) — confirm against Gen's `Switch` docs.
sc = Switch(bang, fuzz)

# `bam`: calls the switch at address :x with fixed arguments (5.0, 3.0);
# `s` is forwarded as the branch selector.
@gen (static) function bam(s::Int)
    # Explicit form of `x ~ sc(...)`; the address must remain :x (used by the choicemap below).
    x = @trace(sc(s, 5.0, 3.0), :x)
    return x
end
# Load the code generated for the static-DSL functions defined above
# (needed by the Gen version this repro targets — may be unnecessary in newer releases; confirm).
Gen.@load_generated_functions()
# Sample an initial trace of `bam` with s = 1.
tr = simulate(bam, (1, ))
# Constrain the nested choice :z inside the switch call at :x to 5.0.
chm = choicemap((:x => :z, 5.0))
# Update the trace, changing the branch argument from 1 to 2 and applying the constraint.
new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm)
display(discard)
display(get_choices(new_tr))
# Regenerate with the same argument (1) and an empty selection.
new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select())
# Request gradients for every choice in the original trace. Per Gen's
# `choice_gradients(trace, selection, retgrad)` signature (confirm), 1.0 is the
# return-value gradient. This is the call that fails with
# `UndefVarError: RandomChoiceNodes not defined` in static_ir/backprop.jl.
sel = AllSelection()
arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0)
The final `choice_gradients` call throws:
ERROR: LoadError: UndefVarError: RandomChoiceNodes not defined
Stacktrace:
[1] get_selected_choices(::Main.SwitchComb.Gen.AllAddressSchema, ::Main.SwitchComb.Gen.StaticIR) at /home/mccoy/code/julia/Gen.jl/src/static_ir/backprop.jl:356
[2] codegen_choice_gradients(::Type{Main.SwitchComb.var"##StaticIRTrace_bam#1717"}, ::Type{T} where T, ::Type{T} where T) at /home/mccoy/code/julia/Gen.jl/src/static_ir/backprop.jl:400
[3] #s412#18 at /home/mccoy/code/julia/Gen.jl/src/static_ir/backprop.jl:504 [inlined]
[4] #s412#18(::Any, ::Any, ::Any, ::Any, ::Any) at ./none:0
[5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:527
[6] top-level scope at /home/mccoy/code/julia/Gen.jl/test/modeling_library/switch.jl:72
[7] include(::String) at ./client.jl:457
[8] top-level scope at REPL[2]:1
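This looks specific to the static DSL. Every Gen frame in the stacktrace is in `src/static_ir/backprop.jl`, where `get_selected_choices` hits the undefined name `RandomChoiceNodes`. I have not confirmed this, though.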
Metadata
Metadata
Assignees
Labels
No labels