Skip to content

Commit 1af2535

Browse files
authored
Update softmax.jl (#569)
1 parent 050b835 commit 1af2535

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/softmax.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
6262
if all(isfinite, max_)
6363
@fastmath out .= exp.(x .- max_)
6464
else
65-
@fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
65+
_zero, _one, _inf = T(0), T(1), T(Inf)
66+
@fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
6667
end
6768
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
6869
out ./= tmp

0 commit comments

Comments
 (0)