Commit 8d6c0bf

YUNQIUGUO authored and facebook-github-bot committed
Port oss bf16_fast_gemv kernel into fbcode (#3610)
Summary:
X-link: facebookresearch/FBGEMM#688
Pull Request resolved: #3610

This diff includes:
1. As step 1, port the OSS FastGEMV `bf16` kernel into fbcode and expose it to Python as `torch.ops.fbgemm.bf16_fast_gemv` (ported from https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14).
2. Add `bf16_oss_fast_gemv` to the quantize ops benchmark script.
3. Add two simple tests for the custom op `torch.ops.fbgemm.bf16_fast_gemv`, covering:
   - `torch.compile()` compatibility
   - correctness

Perf numbers compared across `bf16_baseline`, `bf16_oss_fast_gemv`, `cuda_lite`, `marlin_bf16i4`, and `machete_bf16i4`:

### Benchmark Results on H100

| **M** | **N** | **K** | **Method** | **Elapsed Time (ms)** | **TFLOPS** | **GB/s** |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1280 | 8192 | bf16_baseline | 0.024 | 0.860 | 861.042 |
| 1 | 1280 | 8192 | bf16_oss_fast_gemv | 0.019 | 1.126 | 1127.391 |
| 1 | 1280 | 8192 | cuda_lite_fp8 | 0.015 | 1.357 | 679.032 |
| 1 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 0.768 | 192.612 |
| 1 | 1280 | 8192 | machete_bf16i4 | 0.026 | 0.810 | 203.219 |
| 1 | 8192 | 1024 | bf16_baseline | 0.018 | 0.952 | 953.176 |
| 1 | 8192 | 1024 | bf16_oss_fast_gemv | 0.015 | 1.100 | 1100.900 |
| 1 | 8192 | 1024 | cuda_lite_fp8 | 0.014 | 1.198 | 600.054 |
| 1 | 8192 | 1024 | marlin_bf16i4 | 0.015 | 1.144 | 287.150 |
| 1 | 8192 | 1024 | machete_bf16i4 | 0.014 | 1.187 | 298.096 |
| 1 | 7168 | 8192 | bf16_baseline | 0.073 | 1.609 | 1608.983 |
| 1 | 7168 | 8192 | bf16_oss_fast_gemv | 0.069 | 1.697 | 1697.308 |
| 1 | 7168 | 8192 | cuda_lite_fp8 | 0.044 | 2.679 | 1340.093 |
| 1 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 3.590 | 898.436 |
| 1 | 7168 | 8192 | machete_bf16i4 | 0.039 | 3.017 | 755.147 |
| 1 | 8192 | 3584 | bf16_baseline | 0.045 | 1.312 | 1312.239 |
| 1 | 8192 | 3584 | bf16_oss_fast_gemv | 0.041 | 1.427 | 1427.166 |
| 1 | 8192 | 3584 | cuda_lite_fp8 | 0.026 | 2.271 | 1136.151 |
| 1 | 8192 | 3584 | marlin_bf16i4 | 0.021 | 2.808 | 703.164 |
| 1 | 8192 | 3584 | machete_bf16i4 | 0.024 | 2.460 | 615.990 |

Note that the numerical precision of the `fast_gemv` kernel and `cuda_lite` does not match yet, so fp8 support is needed for a fairer comparison. The `no_cuda_graph` flag was enabled when running `quantize_bench`. Heuristic sweep results for the 4 problem sizes we care about: P1722806148

**Next step:** Add fp8 mixed-precision support for the fast gemv kernel.

Reviewed By: ipiszy

Differential Revision: D68470488

fbshipit-source-id: ac2e6b857c0b9984ee25b5df2a90cce6d1a93dd8
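For illustration only (not part of the commit), a minimal Python sketch of how the newly exposed op is expected to be called and checked against a plain bf16 matmul, in the spirit of the correctness test mentioned above. The `fbgemm_gpu.experimental.gen_ai` import and the tolerances are assumptions, not values taken from this diff:

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed: this import registers torch.ops.fbgemm.*)

# One of the swept problem sizes from the table above: M=1, N=1280, K=8192.
M, N, K = 1, 1280, 8192
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)   # activation, M x K
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)   # weight, N x K

# bf16 fast gemv: returns an M x N bf16 tensor.
y = torch.ops.fbgemm.bf16_fast_gemv(x, w)

# Reference bf16 GEMM; tolerances here are illustrative, not the test's values.
y_ref = x @ w.t()
torch.testing.assert_close(y, y_ref, atol=1e-1, rtol=5e-2)
```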
1 parent 864db21 commit 8d6c0bf

File tree

8 files changed: +747 −2 lines

fbgemm_gpu/experimental/gen_ai/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@ endif()
 # CUDA-specific sources
 file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
     src/quantize/cutlass_extensions/*.cu
-    src/quantize/cutlass_extensions/**/*.cu)
+    src/quantize/cutlass_extensions/**/*.cu
+    src/quantize/fast_gemv/*.cu
+    src/quantize/fast_gemv/**/*.cu
+    src/quantize/fast_gemv/**/*.cuh)

 # HIP-specific sources
 file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 32 additions & 0 deletions
@@ -370,6 +370,38 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16OSSFastGemv(QuantizeOpBase):
+    """
+    BF16 OSS fast gemv kernel.
+    """
+
+    def quantize(self, x, w):
+        # dummy quantize
+        return x, w
+
+    def compute(self, x, w):
+        out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
+        return out
+
+    def quantize_and_compute(self, x, w):
+        x, w = self.quantize(x, w)
+        return self.compute(x, w)
+
+    @property
+    def name(self) -> str:
+        return "bf16_oss_fast_gemv"
+
+    @property
+    def hip(self) -> bool:
+        # This implementation is specific to cublas.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8CublasRowwiseGemm(QuantizeOpBase):
     """
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/core/ScalarType.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_bf16.h>
+
+#include "include/fast_gemv.cuh"
+
+namespace fbgemm_gpu {
+
+// The heuristics are derived by sweeping over the 4 different
+// problem sizes we care about and selecting the best elapsed time/bw
+// combination. See more in
+// deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
+dim3 get_best_block_dim(int m, int n, int k) {
+  if (m == 1 && n == 1280 && k == 8192) {
+    return dim3(128, 4);
+  } else if (m == 1 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 1 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 1 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else {
+    // Default block dimensions
+    return dim3(32, 4);
+  }
+}
+
+at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
+  // X: M x K
+  // W: N x K
+  auto m = X.size(0);
+  auto n = W.size(0);
+  auto k = W.size(1);
+
+  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
+  TORCH_CHECK(W.is_cuda() && W.is_contiguous());
+
+  dim3 block_dim = get_best_block_dim(m, n, k);
+
+  TORCH_CHECK(
+      n % block_dim.y == 0,
+      "Invalid block dimensions: n (",
+      n,
+      ") must be divisible by block_dim.y (",
+      block_dim.y,
+      "). Received n: ",
+      n,
+      ", block_dim.y: ",
+      block_dim.y,
+      " Please either use an `n` which is divisible by `block_dim.y`, or update "
+      "`get_best_block_dim()` heuristics to choose another `block_dim.y`. "
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      k % block_dim.x == 0,
+      "Invalid block dimensions: k (",
+      k,
+      ") must be divisible by block_dim.x (",
+      block_dim.x,
+      "). Received k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      " Please either use a `k` which is divisible by `block_dim.x`, or update "
+      "`get_best_block_dim()` heuristics to choose another `block_dim.x`."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      (k / block_dim.x) % 8 == 0,
+      "Invalid num_per_thread: (",
+      k / block_dim.x,
+      ") must be divisible by 8.",
+      " Received k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      " Please either use a `k` such that `k / block_dim.x` is divisible by 8, or update "
+      "`get_best_block_dim()` heuristics to choose another `block_dim.x`."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+
+  dim3 grid_dim(1, n / block_dim.y);
+  unsigned int num_per_thread = k / block_dim.x;
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));
+
+  gemv_bf16<<<grid_dim, block_dim, 0, stream>>>(
+      reinterpret_cast<__nv_bfloat16*>(W.data_ptr()), // mat
+      reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
+      reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
+      k,
+      num_per_thread);
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return Y;
+}
+
+} // namespace fbgemm_gpu
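To make the launch-configuration logic above concrete, here is a small Python mirror (not part of the commit) of `get_best_block_dim()` plus the three divisibility checks, showing the grid size and per-thread work that the kernel launch derives:

```python
# Illustrative Python mirror of the launch-config logic above.
def get_best_block_dim(m, n, k):
    # (block_dim.x, block_dim.y) per swept problem size; default otherwise.
    table = {
        (1, 1280, 8192): (128, 4),
        (1, 8192, 1024): (32, 8),
        (1, 7168, 8192): (256, 1),
        (1, 8192, 3584): (64, 2),
    }
    return table.get((m, n, k), (32, 4))

def launch_config(m, n, k):
    bx, by = get_best_block_dim(m, n, k)
    assert n % by == 0, "n must be divisible by block_dim.y"
    assert k % bx == 0, "k must be divisible by block_dim.x"
    assert (k // bx) % 8 == 0, "k / block_dim.x must be divisible by 8"
    grid = (1, n // by)           # dim3 grid_dim(1, n / block_dim.y)
    num_per_thread = k // bx      # elements along K handled per thread
    return grid, (bx, by), num_per_thread

# e.g. the M=1, N=1280, K=8192 shape from the benchmark table:
print(launch_config(1, 1280, 8192))  # ((1, 320), (128, 4), 64)
```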
