Package Version
0.6.49
Julia Version
1.8.1
OS / Environment
Ubuntu Linux (Pop!_OS)
Describe the bug
I get a `DimensionMismatch` error from `dot` inside the `Zygote.gradient` call when using `exp` on `Diagonal` matrices.
Steps to Reproduce
Here is code to reproduce the error:
```julia
using Zygote
using LinearAlgebra

d = rand(3)
D = Diagonal(d)

D.diag == diag(D)
# output: true

function exp_diag_function(D::Diagonal)
    return Diagonal(exp.(diag(D)))
end

function exp_diag_field_name(D::Diagonal)
    return Diagonal(exp.(D.diag))
end

∇_diag_function(x) = Zygote.gradient(x -> tr(exp_diag_function(x * D)), x)
∇_diag_function(1.0)
# output: (1.6395871633504406,)

∇_diag_field_name(x) = Zygote.gradient(x -> tr(exp_diag_field_name(x * D)), x)
∇_diag_field_name(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!

∇_built_in(x) = Zygote.gradient(x -> tr(exp(x * D)), x)
∇_built_in(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!
```
Expected Results
There should be no error. Somehow, reading the field directly (`D.diag`) breaks the gradient, while calling the method (`diag(D)`) works. The built-in `exp(x * D)` fails with the same error.
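For reference, the value the failing calls should return can be worked out by hand. Since `D = Diagonal(d)`, `exp(x * D)` is diagonal with entries `exp(x * d[i])`, so `tr(exp(x * D)) == sum(exp.(x .* d))` and its derivative with respect to `x` is `sum(d .* exp.(x .* d))`. The sketch below checks that formula against a central finite difference using only `LinearAlgebra` (no Zygote); the helper names `trace_exp`, `analytic_grad` and `fd_grad` are hypothetical, made up for this illustration.

```julia
using LinearAlgebra

# Hypothetical helper: the same scalar function the reproduction differentiates,
# tr(exp(x * D)) with D = Diagonal(d).
trace_exp(x, d) = tr(exp(x * Diagonal(d)))

# Hypothetical helper: hand-derived gradient. tr(exp(x * D)) == sum(exp.(x .* d)),
# whose derivative with respect to x is sum(d .* exp.(x .* d)).
analytic_grad(x, d) = sum(d .* exp.(x .* d))

# Hypothetical helper: central finite difference, used only as an
# independent cross-check of the analytic formula.
fd_grad(x, d; h = 1e-6) = (trace_exp(x + h, d) - trace_exp(x - h, d)) / (2h)

d = rand(3)
println(analytic_grad(1.0, d))
println(fd_grad(1.0, d))
```

With the same `d` as in the reproduction, `analytic_grad(1.0, d)` is the value that `∇_diag_function`, `∇_diag_field_name` and `∇_built_in` should all agree on.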
Observed Results
Calling `∇_built_in`, which uses the built-in `exp`, throws a `DimensionMismatch` error. Here is the stack trace:
```text
ERROR: DimensionMismatch: x and y are of different lengths!
Stacktrace:
  [1] dot
    @ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:866 [inlined]
  [2] dot
    @ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:856 [inlined]
  [3] #1483
    @ ~/.julia/packages/ChainRules/2ql0h/src/rulesets/Base/arraymath.jl:108 [inlined]
  [4] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:105 [inlined]
  [6] map
    @ ./tuple.jl:223 [inlined]
  [7] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:106 [inlined]
  [8] ZBack
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 [inlined]
  [9] Pullback
    @ ~/projects/Pico.jl/diagonal.jl:36 [inlined]
 [10] (::typeof(∂(#59)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#60#61"{typeof(∂(#59))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
 [12] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
 [13] ∇_built_in(x::Float64)
    @ Main ~/projects/Pico.jl/diagonal.jl:36
 [14] top-level scope
    @ ~/projects/Pico.jl/diagonal.jl:38
```
Relevant log output
No response