Skip to content

Commit

Permalink
Port oss f16_fast_gemv into fbcode (pytorch#3610)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#688


This diff content includes:
1. Port OSS FastGEMV `fp16` kernel into fbcode and expose to python as a step 1 - `torch.ops.fbgemm.f16_fast_gemv`
https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14
2. Add `fp16_oss_fast_gemv` to quantize ops benchmark script
3. Add two simple tests for custom op`torch.ops.fbgemm.f16_fast_gemv` to test
     - `torch.compile()` able
     -  correctness

**Next step:**
Need fp8 mixed precision support for fast gemv kernel which is what we want

Differential Revision: D68470488
  • Loading branch information
YUNQIUGUO authored and facebook-github-bot committed Jan 27, 2025
1 parent c8da60b commit 0c02e5a
Show file tree
Hide file tree
Showing 9 changed files with 591 additions and 1 deletion.
5 changes: 4 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ endif()
# CUDA-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
src/quantize/cutlass_extensions/*.cu
src/quantize/cutlass_extensions/**/*.cu)
src/quantize/cutlass_extensions/**/*.cu
src/quantize/fast_gemv/*.cu
src/quantize/fast_gemv/**/*.cu
src/quantize/fast_gemv/**/*.cuh)

# HIP-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip
Expand Down
21 changes: 21 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Siping Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,38 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class FP16OSSFastGemv(QuantizeOpBase):
"""
FP16 oss fast gemv.
"""

def quantize(self, x, w):
return x, w

def compute(self, x, w):
out = torch.ops.fbgemm.f16_fast_gemv(x, w)
return out

def quantize_and_compute(self, x, w):
# dummy quantize
x, w = self.quantize(x, w)
return self.compute(x, w)

@property
def name(self) -> str:
return "fp16_oss_fast_gemv"

@property
def hip(self) -> bool:
# This implementation is specific to cublas.
return False

@property
def cuda(self) -> bool:
return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

at::Tensor f16_fast_gemv(at::Tensor X, at::Tensor W) {
// note: oss fast gemv implementation accepts vector shape as (size, 1) i.e.
// (K, M)
// X: K x M
// W: N x K
auto m = X.size(1);
auto n = W.size(0);
auto k = W.size(1);

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(W.is_cuda() && W.is_contiguous());

auto block_dim_x = k / 8;
auto block_dim_y = MAX_THREADS_PER_BLOCK / block_dim_x;
dim3 block_dim(block_dim_x, block_dim_y);
dim3 grid_dim(1, n / block_dim_y);
unsigned int num_per_thread = k / block_dim_x;

auto stream = at::cuda::getCurrentCUDAStream();

auto Y = at::empty({n, m}, X.options().dtype(at::kHalf));

gemv_fp16<<<grid_dim, block_dim, 0, stream>>>(
(half*)W.data_ptr(), // mat
(half*)X.data_ptr(), // vec
(half*)Y.data_ptr(), // res
k,
num_per_thread);
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

#else

at::Tensor f16_fast_gemv(at::Tensor X, at::Tensor W) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
#endif

} // namespace fbgemm_gpu
Loading

0 comments on commit 0c02e5a

Please sign in to comment.