Use/export LogExpFunctions.jl? #252
I guess we should. I'm just wary of adding another dependency; there have already been some complaints about latency (see #224).
Are all these GPU and AD friendly?
I don't know the exact plans of the maintainers, but I think the plan is to remove the dependency eventually. There are some issues regarding Rmath (e.g. JuliaStats/Distributions.jl#1509), and there was a discussion about moving the log/exp functions to a separate package (JuliaStats/StatsFuns.jl#46).
IIRC not, therefore I used the term ...

┌ Warning: calls to Base intrinsics might be GPU incompatible
│ exception =
│ You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│ Stacktrace:
│ [1] log at special/log.jl:289
│ [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68

julia> map(log ∘ expm1, CUDA.rand(Float32, 5))
┌ Warning: calls to Base intrinsics might be GPU incompatible
│ exception =
│ You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│ Stacktrace:
│ [1] log at special/log.jl:289
│ [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68
ERROR: InvalidIRError: compiling kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float32,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, Int64) resulted in invalid LLVM IR
Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
Stacktrace:
 [1] expm1 at math.jl:367
 [2] #62 at operators.jl:875
 [3] _broadcast_getindex_evalf at broadcast.jl:648
 [4] _broadcast_getindex at broadcast.jl:621
 [5] getindex at broadcast.jl:575
 [6] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:62
Stacktrace:
 [1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
 [2] macro expansion at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
 [3] macro expansion at /home/davwi492/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#12",Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:60 [inlined]
 [11] cached_compilation at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::GPUArrays.var"#broadcast_kernel#12", ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
 [13] cufunction at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
 [14] #launch_heuristic#853 at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
 [15] launch_heuristic at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
 [16] copyto! at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:66 [inlined]
 [17] copyto! at ./broadcast.jl:886 [inlined]
 [18] copy at ./broadcast.jl:862 [inlined]
 [19] materialize(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{CuArray{Float32,1}}}) at ./broadcast.jl:837
 [20] map(::Function, ::CuArray{Float32,1}) at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:89
 [21] top-level scope at REPL[52]:1
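For context, a minimal way to check GPU friendliness of the LogExpFunctions primitives (a sketch only, assuming CUDA.jl and LogExpFunctions.jl are installed; not taken from the thread) is to broadcast them over a CuArray and watch for the GPUCompiler warnings shown above:

```julia
# Sketch: broadcast the scalar functions over a CuArray and check for warnings.
using CUDA, LogExpFunctions

x = CUDA.rand(Float32, 5)

log1pexp.(x)   # softplus; should compile to a single GPU kernel
logexpm1.(x)   # inverse of log1pexp; the composition log ∘ expm1 above is the problematic spelling
logsumexp(x)   # array reduction; works on CuArrays according to the issue description
```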
BTW StatsFuns already depends on ChainRulesCore implicitly via SpecialFunctions, so it seems custom ChainRules-based adjoints could be added to StatsFuns without introducing any additional dependencies.
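As an illustration of what such a ChainRules-based adjoint might look like (a hedged sketch, not the actual code in either package; the real rules are discussed later in the thread), the gradient of logsumexp is the softmax of the input:

```julia
# Illustrative rrule for logsumexp: d logsumexp(x)/dx = exp.(x .- logsumexp(x)).
using ChainRulesCore
using LogExpFunctions: logsumexp

function ChainRulesCore.rrule(::typeof(logsumexp), x::AbstractVector)
    lse = logsumexp(x)
    logsumexp_pullback(Δ) = (NoTangent(), Δ .* exp.(x .- lse))
    return lse, logsumexp_pullback
end
```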
We should use https://github.com/JuliaStats/LogExpFunctions.jl, which doesn't depend on Rmath.
Yes, this issue was one motivation for moving the functions to LogExpFunctions 🙂
LogExpFunctions.jl should define the rrules. We could do it here, but the original repo is the natural place. Also, if we need this, we'll need to define separate implementations for CuArrays in NNlibCUDA, along the lines of the sketch below.
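To make that last point concrete, a CuArray-specialized method could simply dispatch on the array type (hypothetical sketch; the method placement in NNlibCUDA is an assumption, not the actual package code):

```julia
# Hypothetical sketch of a CuArray-specialized logsumexp for NNlibCUDA.
# Uses the standard shift-by-the-maximum trick so exp never overflows.
using CUDA
import NNlib

function NNlib.logsumexp(x::CuArray; dims=:)
    m = maximum(x; dims=dims)                      # stays on the GPU
    return m .+ log.(sum(exp.(x .- m); dims=dims))
end
```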
FYI, I recently added the ChainRules definitions to LogExpFunctions.
Great, we can move some of the definitions there.
Which definitions? ChainRules? LogExpFunctions already contains derivatives for all functions defined in LogExpFunctions.
Not something that we typically pay much attention to (although we should!), but are the rules themselves differentiable?
Nobody has tested it, but they should be, as they only involve basic functions or functions from LogExpFunctions for which rules are defined: https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/src/chainrules.jl It might be more efficient though for ... in particular.
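One quick way to probe that claim (a rough, untested sketch, assuming Zygote as the AD backend) is to ask for a second-order derivative through the rule and compare against the known Hessian of logsumexp:

```julia
# Rough check that the logsumexp rule composes for second-order AD:
# the Hessian of logsumexp at x is Diagonal(p) - p*p' with p the softmax of x.
using Zygote, LinearAlgebra
using LogExpFunctions: logsumexp

x = randn(4)
p = exp.(x .- logsumexp(x))        # softmax(x)
H = Zygote.hessian(logsumexp, x)
H ≈ Diagonal(p) - p * p'           # should hold if the rules are differentiable
```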
There are a few rules which have their own rules, as for ... Those ones look likely to work, to me, although perhaps you could find ways to make the second more efficient. Why do they have ...?
Mutation of the primal result of ...
Sure, I guess. I mean, did this come up somewhere? Every rule in https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/arraymath.jl (except for ...) ... Looking in the docs quickly, I don't actually see mention of such questions. Maybe @oxinabox has thoughts?
The implementation of logsumexp in StatsFuns is quite optimized (see, e.g., JuliaStats/StatsFuns.jl#97): it works with GPUs, is numerically more stable than the implementation in NNlib, and uses a one-pass algorithm. I am wondering if NNlib should remove its own implementation and just reexport StatsFuns.logsumexp?

More generally, maybe it would make sense to unify some of the duplicate implementations in both packages of, e.g., softmax, softmax!, sigmoid, and softplus?
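For readers unfamiliar with the one-pass algorithm mentioned above, here is a minimal sketch of the idea (illustrative only, not the optimized StatsFuns code): keep a running maximum and a running sum of exponentials rescaled to that maximum, so the input is traversed once and exp never overflows.

```julia
# Minimal sketch of a one-pass (streaming) logsumexp; not the StatsFuns implementation.
function onepass_logsumexp(xs)
    m = -Inf                        # running maximum
    s = 0.0                         # running sum of exp.(seen .- m)
    for x in xs
        if x <= m
            s += exp(x - m)
        else
            s = s * exp(m - x) + 1  # rescale the accumulated sum to the new maximum
            m = x
        end
    end
    return m + log(s)
end

onepass_logsumexp([1.0, 2.0, 3.0]) ≈ log(sum(exp, [1.0, 2.0, 3.0]))  # true
```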