diff --git a/docs/src/reference.md b/docs/src/reference.md
index 5edde719..835bd113 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -31,6 +31,8 @@ swish
 hardswish
 tanhshrink
 trelu
+telu
+telu_fast
 ```
 
 ## Attention
diff --git a/src/activations.jl b/src/activations.jl
index 4ed58622..db5fd844 100644
--- a/src/activations.jl
+++ b/src/activations.jl
@@ -7,8 +7,8 @@ ACTIVATIONS = [
     :σ, :hardσ, :hardtanh, :relu,
     :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu,
     :celu, :softplus, :softsign, :logσ, :logcosh,
-    :mish, :tanhshrink, :softshrink, :trelu, :lisht,
-    :tanh_fast, :sigmoid_fast,
+    :mish, :tanhshrink, :softshrink, :trelu, :lisht, :telu,
+    :tanh_fast, :sigmoid_fast, :telu_fast,
 ]
 
 # of type float (to allow for integer inputs)
@@ -749,6 +749,74 @@ function softshrink(x, λ = 0.5)
     ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
 end
 
+"""
+    telu(x) = x * tanh(exp(x))
+
+Activation function from the paper ["TeLU Activation Function for Fast and Stable Deep Learning"](https://arxiv.org/abs/2412.20269).
+
+```julia-repl
+julia> lineplot(telu, -2, 2, height=7)
+      ┌────────────────────────────────────────┐
+    2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ telu(x)
+      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│
+      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+      │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⡧⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
+      │⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠊⠉⠉⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+   -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+      └────────────────────────────────────────┘
+      ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
+      ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+julia> lineplot(telu, -5, 0, height=7)
+      ┌────────────────────────────────────────┐
+    0 │⠤⠤⢄⣀⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡸│ telu(x)
+      │⠀⠀⠀⠀⠀⠈⠉⠉⠒⠒⠤⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠇│
+      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠢⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀│
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜⠀⠀│
+      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠦⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀⠀⠀│
+      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠤⣀⡀⠀⠀⠀⣀⡤⠃⠀⠀⠀⠀│
+ -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀│
+      └────────────────────────────────────────┘
+      ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
+      ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+```
+"""
+telu(x) = x * tanh(exp(x))
+
+"""
+    telu_fast(x)
+
+This is a faster but less accurate version of `telu`. It comes with a hard-coded derivative,
+`deriv_telu_fast`, which is likewise faster but less accurate than `deriv_telu`.
+"""
+telu_fast(x) = @fastmath x * tanh_fast(exp(x))
+
+# Closed-form derivative of `telu`, adapted from a Julia Discourse post:
+function deriv_telu(x)
+    exp_x = exp(x)
+    tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
+end
+
+@inline function _deriv_telu_taylor_expansion(x::T) where T  # tanh(1) + 2x * sech(1)^2, the first-order expansion of deriv_telu at x = 0
+    tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2
+end
+
+function deriv_telu_fast(x::T, Ω) where T
+    ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x),  # near 0, use the first-order Taylor expansion
+        ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω)))  # for large x, tanh(exp(x)) == 1, so return 1 and avoid `exp(x)` overflow
+end
+
+@inline function _deriv_telu_fast(x, Ω)
+    tanh_exp_x = Ω / x
+    sech_exp_x_squared = 1 - tanh_exp_x^2
+    tanh_exp_x + x * exp(x) * sech_exp_x_squared
+end
+
+# one-argument form, used only for testing accuracy
+_deriv_telu_fast(x) = deriv_telu_fast(x, telu_fast(x))
+
 # Define broadcasts for activation functions on arrays
 for f in ACTIVATIONS
     @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
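A note on the closed form used in `deriv_telu` above: since `telu(x) = x * tanh(exp(x))`, the product rule gives `telu'(x) = tanh(exp(x)) + x * exp(x) * sech(exp(x))^2`, and the `4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2` term is that second summand with `sech` rewritten in terms of exponentials and the factor `exp(x)` folded into the denominator. Below is a minimal finite-difference sanity check, not part of the patch; it assumes this branch of NNlib is loaded, and the `finite_diff` helper is ad hoc:

```julia
using NNlib: telu, deriv_telu   # deriv_telu is internal but accessible on this branch

# ad hoc central finite difference; step size chosen for Float64 accuracy
finite_diff(f, x; h = cbrt(eps(x))) = (f(x + h) - f(x - h)) / (2h)

for x in (-3.0, -1.0, -1e-5, 0.3, 2.0)
    @assert isapprox(deriv_telu(x), finite_diff(telu, x); rtol = 1e-6)
end
```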
@@ -888,9 +956,11 @@ UNARY_ACTS = [ # f, dfdx # mish (:tanhshrink, :((x - Ω)^2)), (:softshrink, :(Ω != 0)), + (:telu, :(deriv_telu(x))), ## Fast variants are the same! (:tanh_fast, :(conj(1 - Ω^2))), (:sigmoid_fast, :(conj(Ω * (1 - Ω)))), + (:telu_fast, :(deriv_telu_fast(x, Ω))) ] for (f, dfdx) in UNARY_ACTS diff --git a/test/activations.jl b/test/activations.jl index 3a14bfde..3d303fbc 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -26,6 +26,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(0.0) == 0.0 @test tanhshrink(0.0) == 0.0 @test softshrink(0.0) == 0.0 +@test telu(0.0) == 0.0 @test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0)) @test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6)) @@ -48,6 +49,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(1.0) ≈ tanh(log(1.0 + exp(1.0))) @test tanhshrink(1.0) ≈ 0.23840584404423515 @test softshrink(1.0) == 0.5 +@test telu(1.0) ≈ 0.99132891580059984 @test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0)) @test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6 )) @@ -70,6 +72,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0))) @test tanhshrink(-1.0) ≈ -0.23840584404423515 @test softshrink(-1.0) == -0.5 +@test telu(-1.0) ≈ -0.35213549054658698 @testset "Float inference" begin @testset "$(a): " for a in ACTIVATION_FUNCTIONS @@ -111,10 +114,10 @@ end # Ideally +-Inf would not lead to NaN, but perhaps # these aren't worth the complication of fixing: - a == softsign && continue + a in [softsign, telu] && continue @test !isnan(a(Inf32)) - a in [gelu, swish, hardswish, logcosh, mish] && continue + a in [gelu, swish, hardswish, logcosh, mish, telu, telu_fast] && continue @test !isnan(a(-Inf32)) end end @@ -211,7 +214,7 @@ end ## Faster variants -using NNlib: tanh_fast, sigmoid_fast +using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, _deriv_telu_fast function countepsfrom(x::T, xtrue) where {T<:AbstractFloat} target = T(xtrue) @@ -228,7 +231,7 @@ function find_worst(f, g, xs) c, xs[i] end -@testset "tanh_fast & sigmoid_fast: Float64" begin +@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float64" begin x64 = 1e-6:1e-4:5 xbig = vcat(6:3:200.0, 1000, 10^6, typemax(Float64)) @@ -262,9 +265,29 @@ end @test sigmoid_fast.(xbig) ≈ sigmoid.(xbig) @test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig) end + @testset "telu" begin + mean_eps(telu, telu, x64) # 0.1146 + worst_eps(telu, telu, x64) # 2 + + @test mean_eps(telu_fast, telu, x64) < 0.14 # 0.1338 + @test worst_eps(telu_fast, telu, x64) <= 4 # 3 + + @test telu_fast.(xbig[1:end-1]) ≈ telu.(xbig[1:end-1]) + @test telu_fast.(-xbig[1:end-1]) ≈ telu.(-xbig[1:end-1]) + end + @testset "deriv_telu" begin + mean_eps(deriv_telu, deriv_telu, x64) # 0.09304 + worst_eps(deriv_telu, deriv_telu, x64) # 2 + + @test mean_eps(_deriv_telu_fast, deriv_telu, x64) < 4.1 # 4.06396 + @test worst_eps(_deriv_telu_fast, deriv_telu, x64) <= 125 # 120 + + @test _deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1]) + @test _deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1]) + end end -@testset "tanh_fast & sigmoid_fast: Float32" begin +@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float32" begin x32 = 1f-6:1f-4:5 xbig32 = vcat(6:3:200f0, 1000, typemax(Float32)) @@ -298,6 +321,26 @@ end @test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32) @test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32) end + @testset "telu" begin + 
mean_eps(telu, telu, x32) # 0.09418 + worst_eps(telu, telu, x32) # 2 + + @test mean_eps(telu_fast, telu, x32) < 0.26 # 0.2555 + @test worst_eps(telu_fast, telu, x32) <= 5 # 4 + + @test telu_fast.(xbig32[1:end-1]) ≈ telu.(xbig32[1:end-1]) + @test telu_fast.(-xbig32[1:end-1]) ≈ telu.(-xbig32[1:end-1]) + end + @testset "deriv_telu" begin + mean_eps(deriv_telu, deriv_telu, x32) # 0.07228 + worst_eps(deriv_telu, deriv_telu, x32) # 1 + + @test mean_eps(_deriv_telu_fast, deriv_telu, x32) < 2.4 # 2.31772 + @test worst_eps(_deriv_telu_fast, deriv_telu, x32) <= 70 # 66 + + @test _deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1]) + @test _deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1]) + end end ## Autodiff tests
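For context on the `UNARY_ACTS` additions: they register `deriv_telu(x)` and `deriv_telu_fast(x, Ω)` as the derivative expressions used by the gradient rules NNlib generates from that table, so reverse-mode AD should pick them up automatically. A minimal sketch of what that looks like from the user side, not part of the patch, assuming this branch of NNlib and Zygote (already used by the test suite) are available:

```julia
using NNlib: telu, telu_fast, deriv_telu
using Zygote

x = 0.7
g,      = Zygote.gradient(telu, x)       # should use the (:telu, deriv_telu(x)) entry
g_fast, = Zygote.gradient(telu_fast, x)  # should use the deriv_telu_fast(x, Ω) entry

@assert isapprox(g, deriv_telu(x); rtol = 1e-12)
@assert isapprox(g, g_fast; rtol = 1e-6)   # fast variant is close but not bitwise identical
```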