Skip to content

Commit b99596a

Browse files
committed
Use logistic function in StatsFuns
1 parent 1584f86 commit b99596a

File tree

4 files changed

+59
-15
lines changed

4 files changed

+59
-15
lines changed

Manifest.toml

+47
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@
33
[[Base64]]
44
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
55

6+
[[BinDeps]]
7+
deps = ["Compat", "Libdl", "SHA", "URIParser"]
8+
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
9+
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
10+
version = "0.8.10"
11+
612
[[BinaryProvider]]
713
deps = ["Libdl", "Pkg", "SHA", "Test"]
814
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
915
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
1016
version = "0.5.3"
1117

18+
[[Compat]]
19+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20+
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
21+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22+
version = "2.1.0"
23+
1224
[[Crayons]]
1325
deps = ["Test"]
1426
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
@@ -19,6 +31,10 @@ version = "4.0.0"
1931
deps = ["Printf"]
2032
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
2133

34+
[[DelimitedFiles]]
35+
deps = ["Mmap"]
36+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
37+
2238
[[Distributed]]
2339
deps = ["Random", "Serialization", "Sockets"]
2440
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -44,6 +60,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
4460
deps = ["Base64"]
4561
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
4662

63+
[[Mmap]]
64+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
65+
4766
[[Pkg]]
4867
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
4968
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -66,23 +85,45 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
6685
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
6786
version = "0.5.2"
6887

88+
[[Rmath]]
89+
deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"]
90+
git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
91+
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
92+
version = "0.5.0"
93+
6994
[[SHA]]
7095
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
7196

7297
[[Serialization]]
7398
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
7499

100+
[[SharedArrays]]
101+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
102+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
103+
75104
[[Sockets]]
76105
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
77106

78107
[[SparseArrays]]
79108
deps = ["LinearAlgebra", "Random"]
80109
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
81110

111+
[[SpecialFunctions]]
112+
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
113+
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
114+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
115+
version = "0.7.2"
116+
82117
[[Statistics]]
83118
deps = ["LinearAlgebra", "SparseArrays"]
84119
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
85120

121+
[[StatsFuns]]
122+
deps = ["Rmath", "SpecialFunctions", "Test"]
123+
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
124+
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
125+
version = "0.8.0"
126+
86127
[[Test]]
87128
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
88129
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -93,6 +134,12 @@ git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
93134
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
94135
version = "0.5.0"
95136

137+
[[URIParser]]
138+
deps = ["Test", "Unicode"]
139+
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
140+
uuid = "30578b45-9adc-5946-b283-645ec420af67"
141+
version = "0.4.0"
142+
96143
[[UUIDs]]
97144
deps = ["Random", "SHA"]
98145
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
11+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1112
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1213

1314
[extras]

src/NNlib.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module NNlib
2-
using Requires, TimerOutputs
2+
using Requires, TimerOutputs, StatsFuns
33

44
const to = TimerOutput()
55

src/activation.jl

+10-14
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@ export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign,
77
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
88
function.
99
"""
10-
σ(x::Real) = one(x) / (one(x) + exp(-x))
10+
const σ = logistic
1111
const sigmoid = σ
1212

1313
# ForwardDiff numerical stability hack
14-
σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
15-
σ(x::Float32) = σ_stable(x)
1614
@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
17-
σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
15+
function σ(x::ForwardDiff.Dual{T,<:Real}) where T
16+
if x < zero(x)
17+
r = exp(x)
18+
r / (r + one(x))
19+
else
20+
inv(exp(-x) + one(x))
21+
end
22+
end
1823
end
1924

2025

@@ -110,15 +115,6 @@ See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal
110115
"""
111116
softsign(x::Real) = x / (one(x) + abs(x))
112117

113-
114-
"""
115-
softplus(x) = log(exp(x) + 1)
116-
117-
See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
118-
"""
119-
softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
120-
121-
122118
"""
123119
logcosh(x)
124120
@@ -127,7 +123,7 @@ Return `log(cosh(x))` which is computed in a numerically stable way.
127123
logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
128124

129125
# Provide an informative error message if activation functions are called with an array
130-
for f in (, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
126+
for f in (, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :(StatsFuns.softplus), :logcosh)
131127
@eval $(f)(x::AbstractArray, args...) =
132128
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
133129
end

0 commit comments

Comments
 (0)