diff --git a/src/batchedmul.jl b/src/batchedmul.jl index 471cac0..e9ad135 100644 --- a/src/batchedmul.jl +++ b/src/batchedmul.jl @@ -1,3 +1,6 @@ +NNlib._batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:CuArray{T}} where {T<:Float16} = + NNlib._batched_try_gemm!(DT, C, A, B, α, β) + # Batched matrix multiplication # 1st argument is produced by NNlib.storage_type(A) NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =