Port oss f16_fast_gemv into fbcode (#3610)
Summary:
X-link: facebookresearch/FBGEMM#688


This diff includes:
1. Port the OSS FastGEMV `fp16` kernel into fbcode and expose it to Python as step 1 - `torch.ops.fbgemm.f16_fast_gemv`
https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14
2. Add `fp16_oss_fast_gemv` to the quantize ops benchmark script
3. Add two simple tests for the custom op `torch.ops.fbgemm.f16_fast_gemv`, covering
     - `torch.compile()` compatibility
     - correctness (see the sketch below)
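
A minimal sketch of what those checks might look like (illustrative only, not the actual test code added in this diff; shapes, tolerances, and the reference matmul are assumptions):

```python
import torch

# Hypothetical sanity check for the new op (not the test file added in this diff).
M, N, K = 1, 1280, 8192
x = torch.randn(M, K, dtype=torch.half, device="cuda")  # activation: M x K
w = torch.randn(N, K, dtype=torch.half, device="cuda")  # weight: N x K

# Eager call of the custom op; output is M x N in fp16.
y = torch.ops.fbgemm.fp16_fast_gemv(x, w)

# Correctness: compare against a plain fp16 matmul reference.
y_ref = x @ w.T
torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)

# torch.compile() compatibility: the op should also run inside a compiled function.
compiled = torch.compile(lambda a, b: torch.ops.fbgemm.fp16_fast_gemv(a, b))
torch.testing.assert_close(compiled(x, w), y_ref, atol=1e-2, rtol=1e-2)
```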

Perf numbers comparing `bf16_baseline`, `fp16_oss_fast_gemv`, `cuda_lite_fp8`, `marlin_bf16i4`, and `machete_bf16i4`:

======================

### Benchmark Results

| **M** | **N** | **K** | **Method** | **Elapsed Time (ms)** | **TFLOPS** | **GB/s** |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1280 | 8192 | bf16_baseline | 0.024 | 0.860 | 861.042 |
| 1 | 1280 | 8192 | fp16_oss_fast_gemv | 0.019 | 1.126 | 1127.391 |
| 1 | 1280 | 8192 | cuda_lite_fp8 | 0.015 | 1.357 | 679.032 |
| 1 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 0.768 | 192.612 |
| 1 | 1280 | 8192 | machete_bf16i4 | 0.026 | 0.810 | 203.219 |
| 1 | 8192 | 1024 | bf16_baseline | 0.018 | 0.952 | 953.176 |
| 1 | 8192 | 1024 | fp16_oss_fast_gemv | 0.010 | 1.763 | 1765.033 |
| 1 | 8192 | 1024 | cuda_lite_fp8 | 0.014 | 1.198 | 600.054 |
| 1 | 8192 | 1024 | marlin_bf16i4 | 0.015 | 1.144 | 287.150 |
| 1 | 8192 | 1024 | machete_bf16i4 | 0.014 | 1.187 | 298.096 |
| 1 | 7168 | 8192 | bf16_baseline | 0.073 | 1.609 | 1608.983 |
| 1 | 7168 | 8192 | fp16_oss_fast_gemv | 0.069 | 1.697 | 1697.308 |
| 1 | 7168 | 8192 | cuda_lite_fp8 | 0.044 | 2.679 | 1340.093 |
| 1 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 3.590 | 898.436 |
| 1 | 7168 | 8192 | machete_bf16i4 | 0.039 | 3.017 | 755.147 |
| 1 | 8192 | 3584 | bf16_baseline | 0.045 | 1.312 | 1312.239 |
| 1 | 8192 | 3584 | fp16_oss_fast_gemv | 0.026 | 2.268 | 1134.843 |
| 1 | 8192 | 3584 | cuda_lite_fp8 | 0.026 | 2.271 | 1136.151 |
| 1 | 8192 | 3584 | marlin_bf16i4 | 0.021 | 2.808 | 703.164 |
| 1 | 8192 | 3584 | machete_bf16i4 | 0.024 | 2.460 | 615.990 |

Note that the precisions of the fast_gemv kernel (fp16) and cuda_lite (fp8) do not match yet, so fp8 support needs to come in for a fairer comparison.
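
For reference, the TFLOPS and GB/s columns can be reproduced approximately with a roofline-style calculation; the exact formula used by the benchmark script is not shown here, and the memory-traffic model below is an assumption:

```python
def gemv_throughput(m, n, k, elapsed_ms, elem_bytes=2, out_bytes=2):
    """Approximate TFLOPS and GB/s for a GEMV-like problem with X: M x K, W: N x K."""
    flops = 2.0 * m * n * k  # one multiply + one add per MAC
    # Assumed traffic: read W (n*k) and X (m*k), write Y (m*n).
    bytes_moved = (n * k + m * k) * elem_bytes + m * n * out_bytes
    seconds = elapsed_ms / 1e3
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9

# Example: the fp16_oss_fast_gemv row for N=8192, K=1024 at ~0.010 ms
# yields roughly 1.7 TFLOPS and ~1.7 TB/s, consistent with the table above.
print(gemv_throughput(1, 8192, 1024, 0.010))
```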


Heuristic sweep results for the 4 problem sizes we care about:

P1722806148
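
The sweep itself is internal (P1722806148), but such a script roughly follows the sketch below; the candidate block sizes and the timing helper are assumptions, and the divisibility constraints mirror the checks in the kernel wrapper further down:

```python
import itertools
import torch

def valid_block_dims(n, k, xs=(32, 64, 128, 256), ys=(1, 2, 4, 8)):
    """Enumerate (block_dim.x, block_dim.y) candidates the kernel accepts:
    n % y == 0, k % x == 0, and (k // x) % 8 == 0."""
    return [(x, y) for x, y in itertools.product(xs, ys)
            if n % y == 0 and k % x == 0 and (k // x) % 8 == 0]

def time_ms(fn, iters=100):
    """CUDA-event timing helper applied to whichever parametrized launch is being swept."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# The four problem sizes from the table above (M = 1).
for n, k in [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]:
    print((n, k), valid_block_dims(n, k))
```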


**Next step:**
Add fp8 mixed-precision support for the fast gemv kernel, which is what we ultimately want.

Differential Revision: D68470488
YUNQIUGUO authored and facebook-github-bot committed Jan 31, 2025
1 parent 3266957 commit 949bdfc
Showing 8 changed files with 615 additions and 2 deletions.
5 changes: 4 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -33,7 +33,10 @@ endif()
# CUDA-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
    src/quantize/cutlass_extensions/*.cu
-   src/quantize/cutlass_extensions/**/*.cu)
+   src/quantize/cutlass_extensions/**/*.cu
+   src/quantize/fast_gemv/*.cu
+   src/quantize/fast_gemv/**/*.cu
+   src/quantize/fast_gemv/**/*.cuh)

# HIP-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -361,6 +361,38 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class Fp16OSSFastGemv(QuantizeOpBase):
    """
    FP16 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        # dummy quantize
        return x, w

    def compute(self, x, w):
        out = torch.ops.fbgemm.fp16_fast_gemv(x, w)
        return out

    def quantize_and_compute(self, x, w):
        x, w = self.quantize(x, w)
        return self.compute(x, w)

    @property
    def name(self) -> str:
        return "fp16_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This op uses a CUDA-only kernel and is not supported on HIP.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
@@ -0,0 +1,96 @@
/* This source code is copied from https://github.com/wangsiping97/FastGEMV
MIT License
Copyright (c) 2023 Siping Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

dim3 get_best_block_dim(int m, int n, int k) {
  if (m == 1 && n == 1280 && k == 8192) {
    return dim3(128, 4);
  } else if (m == 1 && n == 8192 && k == 1024) {
    return dim3(32, 8);
  } else if (m == 1 && n == 7168 && k == 8192) {
    return dim3(256, 1);
  } else if (m == 1 && n == 8192 && k == 3584) {
    return dim3(64, 2);
  } else {
    // Default block dimensions
    return dim3(32, 4);
  }
}

at::Tensor fp16_fast_gemv(at::Tensor X, at::Tensor W) {
  // X: M x K
  // W: N x K
  auto m = X.size(0);
  auto n = W.size(0);
  auto k = W.size(1);

  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
  TORCH_CHECK(W.is_cuda() && W.is_contiguous());

  // the block_dim values are swept results from the heuristics tuning script
  dim3 block_dim = get_best_block_dim(m, n, k);

  if (n % block_dim.y != 0) {
    throw std::runtime_error(
        "fp16_fast_gemv cannot run when n % block_dim.y != 0");
  } else if (k % block_dim.x != 0) {
    throw std::runtime_error(
        "fp16_fast_gemv cannot run when k % block_dim.x != 0");
  } else if ((k / block_dim.x) % 8 != 0) {
    throw std::runtime_error(
        "fp16_fast_gemv cannot run when num_per_thread % 8 != 0");
  }

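  // Launch geometry: the grid is (1, n / block_dim.y), so each thread block
  // produces block_dim.y output rows, and each thread reduces
  // num_per_thread = k / block_dim.x elements of one row. The checks above
  // keep both divisions exact and num_per_thread a multiple of 8, matching
  // the kernel's vectorized 8-half loads.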
  dim3 grid_dim(1, n / block_dim.y);
  unsigned int num_per_thread = k / block_dim.x;

  auto stream = at::cuda::getCurrentCUDAStream();

  auto out_sizes = X.sizes().vec();
  out_sizes.front() = m;
  out_sizes.back() = n;
  auto Y = at::empty(out_sizes, X.options().dtype(at::kHalf));

  gemv_fp16<<<grid_dim, block_dim, 0, stream>>>(
      (half*)W.data_ptr(), // mat
      (half*)X.data_ptr(), // vec
      (half*)Y.data_ptr(), // res
      k,
      num_per_thread);

  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

} // namespace fbgemm_gpu
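
A caller-side guard with the same constraints and heuristic table can be sketched in Python (hypothetical helper, not part of this diff):

```python
# Hypothetical Python-side mirror of get_best_block_dim and the checks in
# fp16_fast_gemv above; useful for deciding whether a shape can take the fast path.
_BLOCK_DIMS = {
    (1, 1280, 8192): (128, 4),
    (1, 8192, 1024): (32, 8),
    (1, 7168, 8192): (256, 1),
    (1, 8192, 3584): (64, 2),
}

def fast_gemv_supported(m: int, n: int, k: int) -> bool:
    bx, by = _BLOCK_DIMS.get((m, n, k), (32, 4))  # (32, 4) is the default block_dim
    return n % by == 0 and k % bx == 0 and (k // bx) % 8 == 0
```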
