Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port oss f16_fast_gemv into fbcode #3610

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ endif()
# CUDA-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
src/quantize/cutlass_extensions/*.cu
src/quantize/cutlass_extensions/**/*.cu)
src/quantize/cutlass_extensions/**/*.cu
src/quantize/fast_gemv/*.cu
src/quantize/fast_gemv/**/*.cu
src/quantize/fast_gemv/**/*.cuh)

# HIP-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip
Expand Down
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,38 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class BF16OSSFastGemv(QuantizeOpBase):
"""
BF16 OSS fast gemv kernel.
"""

def quantize(self, x, w):
# dummy quantize
return x, w

def compute(self, x, w):
out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
return out

def quantize_and_compute(self, x, w):
x, w = self.quantize(x, w)
return self.compute(x, w)

@property
def name(self) -> str:
return "bf16_oss_fast_gemv"

@property
def hip(self) -> bool:
# This implementation is specific to cublas.
return False

@property
def cuda(self) -> bool:
return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

// The heuristics are derived by sweeping over 4 different
// problem sizes we care about and selected the best elapsed time/bw
// combination. See more in
// deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
dim3 get_best_block_dim(int m, int n, int k) {
if (m == 1 && n == 1280 && k == 8192) {
return dim3(128, 4);
} else if (m == 1 && n == 8192 && k == 1024) {
return dim3(32, 8);
} else if (m == 1 && n == 7168 && k == 8192) {
return dim3(256, 1);
} else if (m == 1 && n == 8192 && k == 3584) {
return dim3(64, 2);
} else {
// Default block dimensions
return dim3(32, 4);
}
}

at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
// X: M x K
// W: N x K
auto m = X.size(0);
auto n = W.size(0);
auto k = W.size(1);

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(W.is_cuda() && W.is_contiguous());

dim3 block_dim = get_best_block_dim(m, n, k);

TORCH_CHECK(
n % block_dim.y == 0,
"Invalid block dimensions: n (",
n,
") must be divisible by block_dim.y (",
block_dim.y,
"). Received n: ",
n,
", block_dim.y: ",
block_dim.y,
" Please either use a `n` which is divisible by `block_dim.y`, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.y`. "
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");
TORCH_CHECK(
k % block_dim.x == 0,
"Invalid block dimensions: k (",
k,
") must be divisible by block_dim.x (",
block_dim.x,
"). Received k: ",
k,
", block_dim.x: ",
block_dim.x,
" Please either use a `k` which is divisible by `block_dim.x`, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.x`."
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");
TORCH_CHECK(
(k / block_dim.x) % 8 == 0,
"Invalid num_per_thread: (",
k / block_dim.x,
") must be divisible by 8.",
" Received k: ",
k,
", block_dim.x: ",
block_dim.x,
" Please either use a `k` that `k / block_dim.x` that is divisble by 8, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.x`."
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");

dim3 grid_dim(1, n / block_dim.y);
unsigned int num_per_thread = k / block_dim.x;

auto stream = at::cuda::getCurrentCUDAStream();

auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));

gemv_bf16<<<grid_dim, block_dim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(W.data_ptr()), // mat
reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
num_per_thread);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

} // namespace fbgemm_gpu
Loading
Loading