permit NNlibCUDA to use Float16 #363

Open: bjarthur wants to merge 1 commit into master
src/batched/batchedmul.jl (1 addition, 1 deletion)
@@ -220,7 +220,7 @@ _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β)
 _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
     _batched_try_gemm!(DT, C, A, B, α, β)

-function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
@mcabbott (Member) commented on Nov 19, 2021:
My concern with this change (removing the {T<:BlasFloat} restriction, which the diff doesn't highlight well) is that it may send weird number types (like Dual or BigFloat) down the path towards batched_gemm!, which won't accept them.

Perhaps, to safely widen here, the method _batched_gemm!(::Type{<:Array}) below needs to be restricted to Array{<:BlasFloat}, with a new method offering another path to batched_mul_generic! at that stage?

The dispatch in this file is pretty convoluted! Maybe there's another, tidier solution.

Float16 would be good to have, though. Thanks for digging.
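
For illustration only, a self-contained toy of the split sketched above; try_mul!, fast_mul!, slow_mul!, and BlasLike are invented stand-ins for _batched_gemm!, batched_gemm!, batched_mul_generic!, and BlasFloat, and the argument lists are simplified:

    using LinearAlgebra

    const BlasLike = Union{Float32, Float64, ComplexF32, ComplexF64}  # stands in for BlasFloat

    slow_mul!(C, A, B) = (println("generic path"); C .= A * B)  # plays the role of batched_mul_generic!
    fast_mul!(C, A, B) = (println("BLAS path"); C .= A * B)     # plays the role of batched_gemm!

    # Keep the fast path restricted to BLAS element types at the last dispatch point ...
    try_mul!(::Type{<:Array{<:BlasLike}}, C, A, B) = fast_mul!(C, A, B)
    # ... and give every other element type (Dual, BigFloat, ...) a route back to the generic kernel:
    try_mul!(::Type{<:Array}, C, A, B) = slow_mul!(C, A, B)

    A = big.(rand(2, 2)); B = big.(rand(2, 2)); C = similar(A)
    try_mul!(typeof(C), C, A, B)         # prints "generic path"; BigFloat never reaches fast_mul!
    F = rand(Float32, 2, 2)
    try_mul!(typeof(F), copy(F), F, F)   # prints "BLAS path"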

@bjarthur (Contributor, Author) commented on Nov 19, 2021:
The only place this method (i.e. _batched_try_gemm!) is currently called is from the method immediately above (i.e. _batched_mul!() where {T<:BlasFloat}). Widening _batched_try_gemm! to types other than BlasFloat permits the proposed new _batched_mul!() where {T<:Float16} in FluxML/NNlibCUDA.jl#32 to call it too. I don't think there's any danger of weird number types getting where they shouldn't.
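
A rough, self-contained sketch of the dispatch being described; FakeGPUArray, mul_entry!, try_gemm!, gemm_like!, and generic! are invented stand-ins, not the real NNlib/NNlibCUDA names, and the signatures are simplified:

    # FakeGPUArray plays the role of CuArray (a DenseArray subtype, like CuArray is).
    struct FakeGPUArray{T,N} <: DenseArray{T,N}
        data::Array{T,N}
    end
    Base.size(x::FakeGPUArray) = size(x.data)
    Base.getindex(x::FakeGPUArray, i::Int...) = x.data[i...]

    const BlasLike = Union{Float32, Float64, ComplexF32, ComplexF64}   # stands in for BlasFloat

    generic!(C, A, B)   = (println("generic fallback"); C)
    gemm_like!(C, A, B) = (println("gemm fast path"); C)

    # Shared chain: after widening, the element type T is unrestricted here.
    try_gemm!(::Type{<:DenseArray{T}}, C, A, B) where {T} = gemm_like!(C, A, B)

    # First dispatch point: plain Arrays only reach the chain when T is a BLAS float.
    mul_entry!(::Type{DT}, C, A, B) where {DT<:Array{T}} where {T<:BlasLike} = try_gemm!(DT, C, A, B)
    mul_entry!(::Type, C, A, B) = generic!(C, A, B)

    # Second dispatch point (what the NNlibCUDA side adds): Float16, but only for the GPU type.
    mul_entry!(::Type{DT}, C, A, B) where {DT<:FakeGPUArray{Float16}} = try_gemm!(DT, C, A, B)

    A16 = zeros(Float16, 2, 2)
    mul_entry!(typeof(A16), A16, A16, A16)   # CPU Float16: prints "generic fallback"
    G16 = FakeGPUArray(zeros(Float16, 2, 2))
    mul_entry!(typeof(G16), G16, G16, G16)   # GPU Float16: prints "gemm fast path"

Run as a script, the CPU Float16 call takes the generic fallback while the FakeGPUArray{Float16} call takes the fast path, which is the behaviour described above: Float16 gets the gemm route only for the GPU array type.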

@mcabbott (Member) commented:
Oh, now I see better what you're proposing. There are two jumps to the CUDA package, in order to allow Float16 only for CuArrays, not for Arrays, which is the desired behaviour. The first jump comes back to this package's chain of functions.

It does seem slightly weird to jump twice. Let me think a bit more; I'd be happier if there were exactly one point in the chain where dispatch cared about CuArrays.

@bjarthur (Contributor, Author) commented:
ping

@mcabbott (Member) commented:
Sorry, I dropped the ball here. I think we should do this; at least, I certainly didn't get around to thinking up a better way.

Could you perhaps add some comments explaining a bit what's going on? Having dispatch at two points, instead of just reading down the page and at some point jumping to CUDA, is one step trickier to read. Maybe the where {DT<:DenseArray{T}} where {T<:BlasFloat} = ... method can explain that there's another path through here for CuArray{Float16}?

+function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T}
     alpha, beta = promote(α, β, zero(T))
     alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)
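
One possible shape for the comments being asked for, sketched on simplified stand-ins rather than the real file; BlasFloat is redefined locally here, and _batched_gemm_path! plus the placeholder bodies are inventions standing in for the rest of the real code:

    # Stand-ins so this snippet loads on its own; the real definitions live in
    # src/batched/batchedmul.jl and in NNlibCUDA, with more logic in their bodies.
    const BlasFloat = Union{Float32, Float64, ComplexF32, ComplexF64}
    batched_mul_generic!(C, A, B, α, β) = C        # placeholder body
    _batched_gemm_path!(DT, C, A, B, α, β) = C     # invented placeholder for the gemm path

    # CPU route: for plain dense arrays, only BLAS element types try the gemm path.
    # NNlibCUDA adds its own _batched_mul! method for CuArray{Float16}, which also
    # forwards to _batched_try_gemm!, so there are two entry points into it.
    _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
        _batched_try_gemm!(DT, C, A, B, α, β)

    # T is deliberately left unrestricted: besides the BlasFloat method above, this
    # is also called from NNlibCUDA for CuArray{Float16} (FluxML/NNlibCUDA.jl#32).
    function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T}
        alpha, beta = promote(α, β, zero(T))
        alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)
        return _batched_gemm_path!(DT, C, A, B, alpha, beta)   # placeholder for the real method's remaining body
    end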