Skip to content

Commit

Permalink
Zero beta (#122)
Browse files Browse the repository at this point in the history
* fix alpha != 1

* Bump version

* Ensure zero-betas ignore NaNs in dest, fixes #121.

* Bump version
  • Loading branch information
chriselrod authored Nov 1, 2021
1 parent cd0be03 commit b694678
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 52 deletions.
108 changes: 56 additions & 52 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ end
return C
end
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt)
_matmul_serial!(C, A, B, α, β, MKN)
return C
_matmul_serial!(C, A, B, α, β, MKN)
return C
end

"""
Expand All @@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul
Otherwise, based on the array's size, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`.
"""
@inline function _matmul_serial!(
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
loopmul!(pC, pA, pB, α, β, M, K, N)
return
else
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
return
end
((β ≢ Zero()) && iszero(β)) && return _matmul_serial!(C, A, B, α, Zero(), MKN)
(β isa Bool) && return _matmul_serial!(C, A, B, α, One(), MKN)
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
loopmul!(pC, pA, pB, α, β, M, K, N)
return
else
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
return
end
end
end # function

function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T
Expand Down Expand Up @@ -266,35 +268,37 @@ end

# passing MKN directly would let someone skip the size check.
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
W = pick_vector_width(T)
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
mᵣ, nᵣ = matmul_params(Val(T))
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
else
(nᵣ ≥ N) && @goto LOOPMUL
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
else
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
end
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
return
@label LOOPMUL
loopmul!(pC, pA, pB, α, β, M, K, N)
return
end
((β ≢ Zero()) && iszero(β)) && return _matmul!(C, A, B, α, Zero(), nthread, MKN)
(β isa Bool) && return _matmul!(C, A, B, α, One(), nthread, MKN)
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
W = pick_vector_width(T)
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
mᵣ, nᵣ = matmul_params(Val(T))
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
else
(nᵣ ≥ N) && @goto LOOPMUL
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
else
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
end
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
return
@label LOOPMUL
loopmul!(pC, pA, pB, α, β, M, K, N)
return
end
end
end

# This function is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N.
Expand Down
4 changes: 4 additions & 0 deletions test/matmul_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ for T ∈ (Float64,Float32,Int64,Int32)
@time test_real(T, m_values, k_values, n_values, testset_name_suffix)
end

A = rand(2,2); B = rand(2,2); AB = A*B; C = fill(NaN, 2, 2);
@test Octavian.matmul!(C, A, B, true, false) ≈ AB
@test Octavian.matmul!(C, A, B, true, true) ≈ 2AB

2 comments on commit b694678

@chriselrod
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/47861

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.8 -m "<description of version>" b694678a2c32bf30b61996820e14604eaa296501
git push origin v0.3.8

Please sign in to comment.