Skip to content

Commit 8f89322

Browse files
authored
Type stabilize logsoftmax (#584)
1 parent 4206e26 commit 8f89322

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/softmax.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T
113113
if all(isfinite, max_)
114114
out .= x .- max_
115115
else
116-
@. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_)
116+
_zero, _minf, _inf = T(0), T(-Inf), T(Inf)
117+
@. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _zero, _minf), x - max_)
117118
end
118119
@fastmath log_ = log.(sum(exp, out; dims))
119120
out .-= log_

0 commit comments

Comments
 (0)