
add support for half precision gemm #32

Open
bjarthur wants to merge 1 commit into master

Conversation

@bjarthur commented Nov 16, 2021

In conjunction with FluxML/NNlib.jl#363, this adds support for half-precision gemm, for which Nvidia provides a dedicated kernel. See JuliaGPU/CUDA.jl#1080.
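
For context, a minimal sketch of what such a method might look like (not the actual diff): it forwards NNlib's internal batched-gemm hook to CUBLAS's strided-batched kernel, assuming CUDA.CUBLAS.gemm_strided_batched! has a Float16 method (cublasHgemmStridedBatched, cf. JuliaGPU/CUDA.jl#1080).

```julia
# Hedged sketch, not the PR's actual code: route Float16 batched gemm through
# CUBLAS's strided-batched kernel instead of the generic per-slice fallback.
using NNlib, CUDA

function NNlib._batched_gemm!(::Type{<:CuArray{Float16}}, transA::Char, transB::Char,
                              α::Number, A, B, β::Number, C)
    # Assumes gemm_strided_batched! accepts Float16 arrays (cublasHgemmStridedBatched).
    CUDA.CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
end
```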

@mcabbott
Member

Why do you say this is needed in addition? It looks like an alternative path. But the existing method NNlib._batched_gemm!(::Type{<:CuArray}, ...) ought to match Float16 (if NNlib.jl would let it be called).

What would be good to add here is tests using this precision, which I think should exercise the user-facing batched_mul rather than the internal functions.
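
A hedged sketch of such a test (sizes and tolerance are illustrative): call the user-facing batched_mul on Float16 CuArrays and compare against a Float32 CPU reference.

```julia
using Test, NNlib, CUDA

A = rand(Float16, 4, 5, 3)
B = rand(Float16, 5, 6, 3)

# GPU result in half precision, checked against a higher-precision CPU reference.
C = batched_mul(CuArray(A), CuArray(B))
@test C isa CuArray{Float16, 3}
@test Array(C) ≈ batched_mul(Float32.(A), Float32.(B)) rtol=1e-2
```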

@DhairyaLGandhi
Member

Why would NNlib prevent it from getting called?

@bjarthur
Author

The current code actually works with Float16, but it falls back to batched_mul_generic!, which loops over the last dimension and is therefore painfully slow. I thought about tests, but couldn't come up with a way to verify that the batched Nvidia kernel is called instead.
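
One rough, hedged way to see which path is taken (sizes are illustrative, and this is a sanity check rather than a robust test): the generic fallback launches one gemm per slice of the last dimension, so with a large batch count the strided-batched kernel should be dramatically faster.

```julia
using CUDA, NNlib

A = CuArray(rand(Float16, 64, 64, 512))
B = CuArray(rand(Float16, 64, 64, 512))

batched_mul(A, B)                    # warm-up (compilation)
t = CUDA.@elapsed batched_mul(A, B)  # compare against the time of the looped fallback
```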

@ToucheSir
Member

Yup, the overridden method in NNlib uses BlasFloat, which does not include Float16. Now, one hang-up I see with this PR is that _batched_try_gemm! also only accepts BlasFloat. @bjarthur can you confirm this works locally without any errors?
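
A quick check of the constraint mentioned here: Float16 is not part of LinearAlgebra's BlasFloat union, so any method restricted to T<:BlasFloat will never match Float16 inputs.

```julia
using LinearAlgebra

Float16 <: LinearAlgebra.BlasFloat  # false: BlasFloat covers only Float32/Float64 and their complex versions
```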

@bjarthur
Author

Indeed, it does work locally without any errors; otherwise I would not have submitted it! ;)

@ToucheSir
Member

Great, I think per @mcabbott's comment a test for this would be good :)
