Labels: cuda array
Description
See the following MWE:
```julia
julia> using CUDA, LinearAlgebra

julia> rmul!([NaN, NaN, 1.0, -1.0], false)
4-element Vector{Float64}:
  0.0
  0.0
  0.0
 -0.0

julia> rmul!([NaN, NaN, 1.0, -1.0] |> cu, false)
4-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 NaN
 NaN
  0.0
 -0.0
```
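The CPU result follows Base's scalar rule for multiplying a float by a `Bool`. The quick check below needs no GPU and uses only Base. It shows that `false` discards `NaN` and `Inf` but keeps the sign of the other operand, while an ordinary `0.0` does not do this:

```julia
# Base's scalar rule for Bool * float (runs on the CPU; no GPU needed).
# `===` compares bit patterns, so it also distinguishes 0.0 from -0.0.
@assert false * NaN === 0.0      # strong zero: NaN is discarded
@assert NaN * false === 0.0      # same rule with the Bool on the right
@assert false * -1.0 === -0.0    # the sign of the float operand is kept
@assert false * Inf32 === 0.0f0  # also holds for Float32, the eltype `cu` produced above
@assert isnan(NaN * true)        # `true` acts as the identity
@assert isnan(NaN * 0.0)         # an ordinary float zero is not a strong zero
```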
In Julia Base, `false` is defined as a "strong zero" (see https://github.com/JuliaLang/julia/blob/5e9a32e7af2837e677e60543d4a15faa8d3a7297/base/bool.jl#L178). Hence `NaN * false == 0.0`, while `NaN * true` is `NaN`. For consistency, the following method could be defined for `Bool`:
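```julia
using CUDA, LinearAlgebra

function LinearAlgebra.rmul!(x::CUDA.DenseCuArray{<:CUDA.CUBLAS.CublasFloat}, k::Bool)
    k && return x  # multiplying by `true` is the identity
    # `false` is a strong zero: Base's Bool * Number maps NaN/Inf to ±0 and keeps
    # each entry's sign. Unlike `copysign`, this also works for complex eltypes.
    return x .= false .* x
end
```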
This would bypass the fallback from `rmul!` to `scal!` defined here (lines 9 to 14 in 792aec5):
```julia
LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) =
    scal!(length(x), k, x)

# Work around ambiguity with GPUArrays wrapper
LinearAlgebra.rmul!(x::DenseCuArray{<:CublasFloat}, k::Real) =
    invoke(rmul!, Tuple{typeof(x), Number}, x, k)
```
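As a sketch of what the accompanying tests could check, the example below uses a hypothetical CPU stand-in, `rmul_bool!`, which is not part of CUDA.jl. It has the same body as the proposed method and compares the result against Base's scalar `x * k` for every real and complex CUBLAS eltype. A GPU test would presumably compare `Array(rmul!(cu(x), k))` against the same expected values, but that part is not shown here and is only an assumption.

```julia
using Test

# Hypothetical CPU stand-in for the proposed CuArray method, so the intended
# semantics can be checked without a GPU.
function rmul_bool!(x::AbstractArray, k::Bool)
    k && return x            # `true` is the multiplicative identity
    return x .= false .* x   # strong zero: NaN/Inf -> ±0, sign preserved
end

@testset "Bool scaling matches Base's scalar semantics" begin
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        x = T[NaN, NaN, 1, -1, Inf]
        for k in (false, true)
            expected = [xi * k for xi in x]   # Base's scalar rule, element by element
            # `isequal` treats NaN == NaN and distinguishes 0.0 from -0.0
            @test isequal(rmul_bool!(copy(x), k), expected)
        end
    end
end
```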
I'd be happy to create a PR + tests.
lanceXwq