Port oss bf16_fast_gemv kernel into fbcode (#3610)
Summary:
X-link: facebookresearch/FBGEMM#688


This diff:
1. Ports the OSS FastGEMV `bf16` kernel into fbcode and exposes it to Python as `torch.ops.fbgemm.bf16_fast_gemv` (step 1):
   https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14
2. Adds `bf16_oss_fast_gemv` to the quantize ops benchmark script.
3. Adds two simple tests for the custom op `torch.ops.fbgemm.bf16_fast_gemv` (see the sketch after this list) covering:
   - `torch.compile()` compatibility
   - correctness
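
A minimal sketch of what such tests could look like, assuming the op takes a bf16 activation `x` (M x K) and a bf16 weight `w` (N x K) and returns an M x N bf16 output (layout inferred from the kernel source below). The import path for op registration and the tolerances are illustrative assumptions, not the values used in the actual tests:

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- assumed entry point that registers torch.ops.fbgemm.*


def test_bf16_fast_gemv() -> None:
    # One of the four problem sizes the heuristics below are tuned for.
    x = torch.randn(1, 8192, device="cuda", dtype=torch.bfloat16)  # M x K
    w = torch.randn(1280, 8192, device="cuda", dtype=torch.bfloat16)  # N x K

    # Correctness: compare against a plain bf16 matmul reference.
    out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
    ref = x @ w.T
    torch.testing.assert_close(out, ref, atol=1.0e-1, rtol=1.6e-2)

    # torch.compile()-ability: the compiled wrapper should still match.
    def gemv(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return torch.ops.fbgemm.bf16_fast_gemv(x, w)

    compiled = torch.compile(gemv)
    torch.testing.assert_close(compiled(x, w), ref, atol=1.0e-1, rtol=1.6e-2)
```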

Perf numbers comparing `bf16_baseline`, `bf16_oss_fast_gemv`, `cuda_lite`, `marlin_bf16i4`, and `machete_bf16i4`:

### Benchmark Results on H100


| **M** | **N** | **K** | **Method** | **Elapsed Time (ms)** | **TFLOPS** | **GB/s** |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1280 | 8192 | bf16_baseline | 0.024 | 0.860 | 861.042 |
| 1 | 1280 | 8192 | bf16_oss_fast_gemv | 0.019 | 1.126 | 1127.391 |
| 1 | 1280 | 8192 | cuda_lite_fp8 | 0.015 | 1.357 | 679.032 |
| 1 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 0.768 | 192.612 |
| 1 | 1280 | 8192 | machete_bf16i4 | 0.026 | 0.810 | 203.219 |
| 1 | 8192 | 1024 | bf16_baseline | 0.018 | 0.952 | 953.176 |
| 1 | 8192 | 1024 | bf16_oss_fast_gemv | 0.015 | 1.100 | 1100.900 |
| 1 | 8192 | 1024 | cuda_lite_fp8 | 0.014 | 1.198 | 600.054 |
| 1 | 8192 | 1024 | marlin_bf16i4 | 0.015 | 1.144 | 287.150 |
| 1 | 8192 | 1024 | machete_bf16i4 | 0.014 | 1.187 | 298.096 |
| 1 | 7168 | 8192 | bf16_baseline | 0.073 | 1.609 | 1608.983 |
| 1 | 7168 | 8192 | bf16_oss_fast_gemv | 0.069 | 1.697 | 1697.308 |
| 1 | 7168 | 8192 | cuda_lite_fp8 | 0.044 | 2.679 | 1340.093 |
| 1 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 3.590 | 898.436 |
| 1 | 7168 | 8192 | machete_bf16i4 | 0.039 | 3.017 | 755.147 |
| 1 | 8192 | 3584 | bf16_baseline | 0.045 | 1.312 | 1312.239 |
| 1 | 8192 | 3584 | bf16_oss_fast_gemv | 0.041 | 1.427 | 1427.166 |
| 1 | 8192 | 3584 | cuda_lite_fp8 | 0.026 | 2.271 | 1136.151 |
| 1 | 8192 | 3584 | marlin_bf16i4 | 0.021 | 2.808 | 703.164 |
| 1 | 8192 | 3584 | machete_bf16i4 | 0.024 | 2.460 | 615.990 |
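
For reference, the TFLOPS and GB/s columns are consistent with the usual GEMV accounting: `2*M*N*K` FLOPs, and `M*K + N*K + M*N` elements of memory traffic at each method's element width (2 bytes for the bf16 rows, 1 byte for the fp8 rows). A small sketch of that arithmetic, checked against the first bf16 row; this is a reading of the table, and the actual benchmark script may account slightly differently:

```python
def gemv_metrics(m: int, n: int, k: int, elapsed_ms: float, bytes_per_elem: float = 2.0):
    flops = 2 * m * n * k  # one multiply + one add per weight element
    bytes_moved = bytes_per_elem * (m * k + n * k + m * n)  # X, W, and Y traffic
    seconds = elapsed_ms * 1e-3
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9


# M=1, N=1280, K=8192 at 0.024 ms -> ~0.87 TFLOPS, ~875 GB/s, in line with the
# 0.860 / 861.042 reported above (the small gap is timing granularity).
print(gemv_metrics(1, 1280, 8192, 0.024))
```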

Note that `fast_gemv` (bf16) and `cuda_lite` (fp8) currently run at different precisions, so fp8 support for the fast gemv kernel is needed for a fairer comparison.

Also, the `no_cuda_graph` flag was enabled when running `quantize_bench`.

Heuristic sweep results for the four problem sizes we care about (see the timing sketch below):
P1722806148
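
The sweep itself lives in `sweep_utils.py` (referenced in the kernel source below) and is not shown on this page. Purely as an illustration of the measurement loop, a CUDA-event timing sketch over the same four shapes; the real script may measure differently:

```python
import torch

SHAPES = [(1, 1280, 8192), (1, 8192, 1024), (1, 7168, 8192), (1, 8192, 3584)]


def bench_ms(fn, iters: int = 100) -> float:
    # CUDA events bracket the timed region; warm up once first.
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / iters


for m, n, k in SHAPES:
    x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
    ms = bench_ms(lambda: torch.ops.fbgemm.bf16_fast_gemv(x, w))
    print(f"m={m} n={n} k={k}: {ms:.3f} ms")
```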


**Next step:**
Add fp8 mixed-precision support to the fast gemv kernel, which is what we ultimately want.

Differential Revision: D68470488
YUNQIUGUO authored and facebook-github-bot committed Feb 4, 2025
1 parent bdcce9c commit d3aecb7
Showing 8 changed files with 747 additions and 2 deletions.
5 changes: 4 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -33,7 +33,10 @@ endif()
# CUDA-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
src/quantize/cutlass_extensions/*.cu
src/quantize/cutlass_extensions/**/*.cu)
src/quantize/cutlass_extensions/**/*.cu
src/quantize/fast_gemv/*.cu
src/quantize/fast_gemv/**/*.cu
src/quantize/fast_gemv/**/*.cuh)

# HIP-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -370,6 +370,38 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16OSSFastGemv(QuantizeOpBase):
    """
    BF16 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        # dummy quantize
        return x, w

    def compute(self, x, w):
        out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
        return out

    def quantize_and_compute(self, x, w):
        x, w = self.quantize(x, w)
        return self.compute(x, w)

    @property
    def name(self) -> str:
        return "bf16_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This implementation is CUDA-specific (custom fast_gemv kernel).
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
@@ -0,0 +1,138 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

// The heuristics are derived by sweeping over the 4 problem sizes we care
// about and selecting the best elapsed time/bandwidth combination. See more in
// deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
dim3 get_best_block_dim(int m, int n, int k) {
  if (m == 1 && n == 1280 && k == 8192) {
    return dim3(128, 4);
  } else if (m == 1 && n == 8192 && k == 1024) {
    return dim3(32, 8);
  } else if (m == 1 && n == 7168 && k == 8192) {
    return dim3(256, 1);
  } else if (m == 1 && n == 8192 && k == 3584) {
    return dim3(64, 2);
  } else {
    // Default block dimensions
    return dim3(32, 4);
  }
}

at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
  // X: M x K
  // W: N x K
  auto m = X.size(0);
  auto n = W.size(0);
  auto k = W.size(1);

  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
  TORCH_CHECK(W.is_cuda() && W.is_contiguous());

  dim3 block_dim = get_best_block_dim(m, n, k);

  TORCH_CHECK(
      n % block_dim.y == 0,
      "Invalid block dimensions: n (",
      n,
      ") must be divisible by block_dim.y (",
      block_dim.y,
      "). Received n: ",
      n,
      ", block_dim.y: ",
      block_dim.y,
      ". Please either use an `n` that is divisible by `block_dim.y`, or update "
      "`get_best_block_dim()` heuristics to choose another `block_dim.y`. "
      "All current params - m: ",
      m,
      ", n: ",
      n,
      ", k: ",
      k,
      ", block_dim.x: ",
      block_dim.x,
      ", block_dim.y: ",
      block_dim.y,
      ".");
  TORCH_CHECK(
      k % block_dim.x == 0,
      "Invalid block dimensions: k (",
      k,
      ") must be divisible by block_dim.x (",
      block_dim.x,
      "). Received k: ",
      k,
      ", block_dim.x: ",
      block_dim.x,
      ". Please either use a `k` that is divisible by `block_dim.x`, or update "
      "`get_best_block_dim()` heuristics to choose another `block_dim.x`. "
      "All current params - m: ",
      m,
      ", n: ",
      n,
      ", k: ",
      k,
      ", block_dim.x: ",
      block_dim.x,
      ", block_dim.y: ",
      block_dim.y,
      ".");
  TORCH_CHECK(
      (k / block_dim.x) % 8 == 0,
      "Invalid num_per_thread: (",
      k / block_dim.x,
      ") must be divisible by 8. Received k: ",
      k,
      ", block_dim.x: ",
      block_dim.x,
      ". Please either use a `k` such that `k / block_dim.x` is divisible by 8, "
      "or update `get_best_block_dim()` heuristics to choose another "
      "`block_dim.x`. All current params - m: ",
      m,
      ", n: ",
      n,
      ", k: ",
      k,
      ", block_dim.x: ",
      block_dim.x,
      ", block_dim.y: ",
      block_dim.y,
      ".");

  dim3 grid_dim(1, n / block_dim.y);
  unsigned int num_per_thread = k / block_dim.x;
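  // Launch-geometry note (inferred from the config above, since the kernel
  // body lives in fast_gemv.cuh): the grid spans all n rows of W, with each
  // block covering block_dim.y rows; along k, the block_dim.x threads assigned
  // to a row each reduce num_per_thread elements, so
  // block_dim.x * num_per_thread == k (enforced by the checks above).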

  auto stream = at::cuda::getCurrentCUDAStream();

  auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));

  gemv_bf16<<<grid_dim, block_dim, 0, stream>>>(
      reinterpret_cast<__nv_bfloat16*>(W.data_ptr()), // mat
      reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
      reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
      k,
      num_per_thread);

  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

} // namespace fbgemm_gpu
