Add TeLU activation functions telu and telu_fast #622

Open · wants to merge 4 commits into base: master
2 changes: 2 additions & 0 deletions docs/src/reference.md
@@ -31,6 +31,8 @@ swish
hardswish
tanhshrink
trelu
telu
telu_fast
```

## Attention
74 changes: 72 additions & 2 deletions src/activations.jl
@@ -7,8 +7,8 @@ ACTIVATIONS = [
:σ, :hardσ, :hardtanh, :relu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu,
:celu, :softplus, :softsign, :logσ, :logcosh,
:mish, :tanhshrink, :softshrink, :trelu, :lisht,
:tanh_fast, :sigmoid_fast,
:mish, :tanhshrink, :softshrink, :trelu, :lisht, :telu,
:tanh_fast, :sigmoid_fast, :telu_fast
]

# of type float (to allow for integer inputs)
@@ -749,6 +749,74 @@ function softshrink(x, λ = 0.5)
ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
end

"""
telu(x) = x * tanh(exp(x))
See e.g. ["TeLU Activation Function for Fast and Stable Deep Learning"](https://arxiv.org/abs/2412.20269).
```julia-repl
julia> lineplot(telu, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ telu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⡧⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠊⠉⠉⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot(telu, -5, 0, height=7)
┌────────────────────────────────────────┐
0 │⠤⠤⢄⣀⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡸│ telu(x)
│⠀⠀⠀⠀⠀⠈⠉⠉⠒⠒⠤⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠇│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠢⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠦⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠤⣀⡀⠀⠀⠀⣀⡤⠃⠀⠀⠀⠀│
-0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
telu(x) = x * tanh(exp(x))
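For a quick numeric cross-check of the closed form, the reference values below are taken from the tests added later in this PR (an illustrative sketch, not part of the diff):

```julia
# Spot checks of telu(x) = x * tanh(exp(x)); values match test/activations.jl below.
@assert telu(0.0) == 0.0
@assert telu(1.0) ≈ 0.99132891580059984    # 1 * tanh(ℯ)
@assert telu(-1.0) ≈ -0.35213549054658698  # -1 * tanh(1/ℯ)
```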

"""
telu_fast(x)
This is a faster but less accurate version of `telu`. It is paired with a hard-coded derivative,
`deriv_telu_fast`, which is likewise faster but less accurate than `deriv_telu`.

Review comment (Member):
Is this meaningfully less accurate? In the tests there should be some functions for measuring error, `countepsfrom` and friends.

My guess is that for NN purposes, we will only want the fast version, and probably `@fastmath x * tanh_fast(exp(x))` to speed up `exp` too.

Review comment (Member):

In my gist, but translated to this notation -- there is hardly any accuracy change:

julia> worst_eps(telu_fast, telu, -5:0.01f0:5)  # comparing to bigfloat
3

julia> worst_eps(telu, telu, -5:0.01f0:5)
2

"""
telu_fast(x) = @fastmath x * tanh_fast(exp(x))
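Given the review question above about whether only the fast variant matters for NN use, a rough throughput comparison on an NN-sized batch could look like the sketch below (hypothetical benchmark, not part of the PR; `telu` and `telu_fast` as defined in this file, timings machine-dependent):

```julia
using BenchmarkTools   # assumed available in the benchmarking environment

x = randn(Float32, 1024, 256)   # a typical activation-sized batch
y = similar(x)

@btime $y .= telu.($x)          # accurate variant: tanh(exp(x))
@btime $y .= telu_fast.($x)     # @fastmath exp plus tanh_fast
```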

# Adapted from the Discourse post: <https://discourse.julialang.org/t/how-to-compute-tanhexp-telu-function-accurately/124464/7>
function deriv_telu(x)
exp_x = exp(x)
tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
end

@inline function _deriv_telu_taylor_expansion(x::T) where T
tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2
end

function deriv_telu_fast(x::T, Ω) where T
ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return the first-order Taylor expansion
ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
end

@inline function _deriv_telu_fast(x, Ω)
tanh_exp_x = Ω / x
sech_exp_x_squared = 1 - tanh_exp_x^2
tanh_exp_x + x * exp(x) * sech_exp_x_squared
end

# for testing accuracy
_deriv_telu_fast(x) = deriv_telu_fast(x, telu_fast(x))
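As a sanity check on the derivative code above, the analytic expressions can be compared against a central finite difference of `telu` itself (a quick verification sketch, not part of the PR):

```julia
# Central-difference check of deriv_telu and the fast single-argument helper
# against the closed form telu(x) = x * tanh(exp(x)); h is an ad-hoc Float64 step.
fd_deriv(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

for x in (-4.0, -1.0, -0.1, 0.1, 1.0, 4.0)
    @assert isapprox(deriv_telu(x), fd_deriv(telu, x); atol = 1e-6)
    @assert isapprox(_deriv_telu_fast(x), fd_deriv(telu, x); atol = 1e-6)
end
```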

# Define broadcasts for activation functions on arrays
for f in ACTIVATIONS
@eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
@@ -888,9 +956,11 @@ UNARY_ACTS = [ # f, dfdx
# mish
(:tanhshrink, :((x - Ω)^2)),
(:softshrink, :(Ω != 0)),
(:telu, :(deriv_telu(x))),
## Fast variants are the same!
(:tanh_fast, :(conj(1 - Ω^2))),
(:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
(:telu_fast, :(deriv_telu_fast(x, Ω)))
]

for (f, dfdx) in UNARY_ACTS
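For orientation, each `(f, dfdx)` pair in `UNARY_ACTS` is expanded by the loop above into a derivative definition. Registering `(:telu, :(deriv_telu(x)))` is roughly equivalent to the hand-written rule below (a sketch of the pattern, assuming the loop emits ChainRulesCore rules; not the verbatim macro expansion):

```julia
using ChainRulesCore

# Roughly the scalar rule implied by the (:telu, :(deriv_telu(x))) entry.
function ChainRulesCore.rrule(::typeof(telu), x::Number)
    Ω = telu(x)
    telu_pullback(Δ) = (NoTangent(), Δ * deriv_telu(x))
    return Ω, telu_pullback
end
```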
53 changes: 48 additions & 5 deletions test/activations.jl
@@ -26,6 +26,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test mish(0.0) == 0.0
@test tanhshrink(0.0) == 0.0
@test softshrink(0.0) == 0.0
@test telu(0.0) == 0.0

@test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0))
@test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6))
@@ -48,6 +49,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
@test tanhshrink(1.0) ≈ 0.23840584404423515
@test softshrink(1.0) == 0.5
@test telu(1.0) ≈ 0.99132891580059984

@test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0))
@test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6 ))
@@ -70,6 +72,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
@test tanhshrink(-1.0) ≈ -0.23840584404423515
@test softshrink(-1.0) == -0.5
@test telu(-1.0) ≈ -0.35213549054658698

@testset "Float inference" begin
@testset "$(a): " for a in ACTIVATION_FUNCTIONS
@@ -111,10 +114,10 @@ end

# Ideally +-Inf would not lead to NaN, but perhaps
# these aren't worth the complication of fixing:
a == softsign && continue
a in [softsign, telu] && continue
@test !isnan(a(Inf32))

a in [gelu, swish, hardswish, logcosh, mish] && continue
a in [gelu, swish, hardswish, logcosh, mish, telu, telu_fast] && continue
@test !isnan(a(-Inf32))
end
end
@@ -211,7 +214,7 @@ end

## Faster variants

using NNlib: tanh_fast, sigmoid_fast
using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, _deriv_telu_fast

function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
target = T(xtrue)
@@ -228,7 +231,7 @@ function find_worst(f, g, xs)
c, xs[i]
end
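The new testsets below rely on `mean_eps` and `worst_eps`, which are not shown in this diff; in terms of `countepsfrom` above they might plausibly be defined as follows (an assumption about the test helpers, the file's actual definitions may differ):

```julia
using Statistics: mean

# ULP-style error of f against a high-precision evaluation of g, summarised over xs.
mean_eps(f, g, xs)  = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
worst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
```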

@testset "tanh_fast & sigmoid_fast: Float64" begin
@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float64" begin

x64 = 1e-6:1e-4:5
xbig = vcat(6:3:200.0, 1000, 10^6, typemax(Float64))
@@ -262,9 +265,29 @@ end
@test sigmoid_fast.(xbig) ≈ sigmoid.(xbig)
@test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig)
end
@testset "telu" begin
mean_eps(telu, telu, x64) # 0.1146
worst_eps(telu, telu, x64) # 2

@test mean_eps(telu_fast, telu, x64) < 0.14 # 0.1338
@test worst_eps(telu_fast, telu, x64) <= 4 # 3

@test telu_fast.(xbig[1:end-1]) ≈ telu.(xbig[1:end-1])
@test telu_fast.(-xbig[1:end-1]) ≈ telu.(-xbig[1:end-1])
end
@testset "deriv_telu" begin
mean_eps(deriv_telu, deriv_telu, x64) # 0.09304
worst_eps(deriv_telu, deriv_telu, x64) # 2

@test mean_eps(_deriv_telu_fast, deriv_telu, x64) < 4.1 # 4.06396
@test worst_eps(_deriv_telu_fast, deriv_telu, x64) <= 125 # 120

@test _deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1])
@test _deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1])
end
end

@testset "tanh_fast & sigmoid_fast: Float32" begin
@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float32" begin

x32 = 1f-6:1f-4:5
xbig32 = vcat(6:3:200f0, 1000, typemax(Float32))
@@ -298,6 +321,26 @@ end
@test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32)
@test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32)
end
@testset "telu" begin
mean_eps(telu, telu, x32) # 0.09418
worst_eps(telu, telu, x32) # 2

@test mean_eps(telu_fast, telu, x32) < 0.26 # 0.2555
@test worst_eps(telu_fast, telu, x32) <= 5 # 4

@test telu_fast.(xbig32[1:end-1]) ≈ telu.(xbig32[1:end-1])
@test telu_fast.(-xbig32[1:end-1]) ≈ telu.(-xbig32[1:end-1])
end
@testset "deriv_telu" begin
mean_eps(deriv_telu, deriv_telu, x32) # 0.07228
worst_eps(deriv_telu, deriv_telu, x32) # 1

@test mean_eps(_deriv_telu_fast, deriv_telu, x32) < 2.4 # 2.31772
@test worst_eps(_deriv_telu_fast, deriv_telu, x32) <= 70 # 66

@test _deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1])
@test _deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1])
end
end

## Autodiff tests