Skip to content

Commit 3e46234

Browse files
committed
Use logistic function in StatsFuns
1 parent d0cee46 commit 3e46234

File tree

5 files changed

+87
-18
lines changed

5 files changed

+87
-18
lines changed

Manifest.toml

+79-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
11
# This file is machine-generated - editing it directly is not advised
22

3+
[[Base64]]
4+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
5+
36
[[BinaryProvider]]
4-
deps = ["Libdl", "SHA"]
5-
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
7+
deps = ["Libdl", "Logging", "SHA"]
8+
git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
69
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
7-
version = "0.5.8"
10+
version = "0.5.10"
11+
12+
[[CompilerSupportLibraries_jll]]
13+
deps = ["Libdl", "Pkg"]
14+
git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
15+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
16+
version = "0.3.3+0"
17+
18+
[[Dates]]
19+
deps = ["Printf"]
20+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
21+
22+
[[InteractiveUtils]]
23+
deps = ["Markdown"]
24+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
25+
26+
[[LibGit2]]
27+
deps = ["Printf"]
28+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
829

930
[[Libdl]]
1031
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -13,6 +34,31 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1334
deps = ["Libdl"]
1435
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1536

37+
[[Logging]]
38+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
39+
40+
[[Markdown]]
41+
deps = ["Base64"]
42+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
43+
44+
[[OpenSpecFun_jll]]
45+
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
46+
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
47+
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
48+
version = "0.5.3+3"
49+
50+
[[Pkg]]
51+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
52+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
53+
54+
[[Printf]]
55+
deps = ["Unicode"]
56+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
57+
58+
[[REPL]]
59+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
60+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
61+
1662
[[Random]]
1763
deps = ["Serialization"]
1864
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,20 +69,50 @@ git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b"
2369
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
2470
version = "1.0.1"
2571

72+
[[Rmath]]
73+
deps = ["Random", "Rmath_jll"]
74+
git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370"
75+
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
76+
version = "0.6.1"
77+
78+
[[Rmath_jll]]
79+
deps = ["Libdl", "Pkg"]
80+
git-tree-sha1 = "1660f8fefbf5ab9c67560513131d4e933012fc4b"
81+
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
82+
version = "0.2.2+0"
83+
2684
[[SHA]]
2785
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
2886

2987
[[Serialization]]
3088
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
3189

90+
[[Sockets]]
91+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
92+
3293
[[SparseArrays]]
3394
deps = ["LinearAlgebra", "Random"]
3495
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3596

97+
[[SpecialFunctions]]
98+
deps = ["OpenSpecFun_jll"]
99+
git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020"
100+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
101+
version = "0.10.3"
102+
36103
[[Statistics]]
37104
deps = ["LinearAlgebra", "SparseArrays"]
38105
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
39106

107+
[[StatsFuns]]
108+
deps = ["Rmath", "SpecialFunctions"]
109+
git-tree-sha1 = "04a5a8e6ab87966b43f247920eab053fd5fdc925"
110+
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
111+
version = "0.9.5"
112+
40113
[[UUIDs]]
41114
deps = ["Random", "SHA"]
42115
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
116+
117+
[[Unicode]]
118+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
11+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1112

1213
[compat]
1314
BinaryProvider = "0.5"
1415
Requires = "0.5, 1.0"
16+
StatsFuns = "0.9"
1517
julia = "1"
1618

1719
[extras]

src/NNlib.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module NNlib
22
using Requires
3+
using StatsFuns: logistic, softplus
34

45
# Include APIs
56
include("dim_helpers.jl")

src/activation.jl

+2-14
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
1212
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
1313
function.
1414
"""
15-
function σ(x::Real)
16-
t = exp(-abs(x))
17-
ifelse(x 0, inv(one(t) + t), t / (one(t) + t))
18-
end
15+
const σ = logistic
1916
const sigmoid = σ
2017

2118
"""
@@ -181,15 +178,6 @@ See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal
181178
"""
182179
softsign(x::Real) = x / (one(x) + abs(x))
183180

184-
185-
"""
186-
softplus(x) = log(exp(x) + 1)
187-
188-
See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
189-
"""
190-
softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
191-
192-
193181
"""
194182
logcosh(x)
195183
@@ -222,7 +210,7 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A
222210
softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
223211

224212
# Provide an informative error message if activation functions are called with an array
225-
for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
213+
for f in (:hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :logcosh, :mish, :tanhshrink, :softshrink)
226214
@eval $(f)(x::AbstractArray, args...) =
227215
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
228216
end

test/activation.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ end
108108
@testset "Array input" begin
109109
x = rand(5)
110110
for a in ACTIVATION_FUNCTIONS
111-
@test_throws ErrorException a(x)
111+
if a != σ && a != softplus
112+
@test_throws ErrorException a(x)
113+
end
112114
end
113115
end
114116

0 commit comments

Comments
 (0)