Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test out D68193920 #3606

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,40 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class FP8LiteGemm(QuantizeOpBase):
"""
FP8 lite matmul for memory bound.
"""

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale * w_scale)

@property
def name(self) -> str:
return "cuda_lite"

@property
def hip(self) -> bool:
# Need to add support for better quantize kernel.
# Also may have an issue with cuda graphs.
return False

@property
def cuda(self) -> bool:
return True


@register_quantize_op
class TritonFP8RowwiseGemm(QuantizeOpBase):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/numeric_conversion.h>
#include <cub/cub.cuh>

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

using SizeType32 = std::size_t;

struct Params {
void const* act;
void const* weight;
void const* alpha;
void* output;
SizeType32 m, n, k;

Params(
void const* _act,
void const* _weight,
void const* _alpha,
void* _output,
SizeType32 _m,
SizeType32 _n,
SizeType32 _k)
: act(_act),
weight(_weight),
alpha(_alpha),
output(_output),
m(_m),
n(_n),
k(_k) {}
};

template <
typename InputType,
typename OutputType,
SizeType32 TILE_M,
SizeType32 TILE_N,
SizeType32 BLOCK_SIZE>
__global__ void cudaCoreGemm(
InputType const* __restrict__ act,
InputType const* __restrict__ weight,
float const* alpha,
OutputType* __restrict__ output,
SizeType32 m,
SizeType32 n,
SizeType32 k) {
using VecType = int4;
static constexpr SizeType32 kStepK =
static_cast<SizeType32>(128 / (8 * sizeof(InputType)));
static constexpr SizeType32 kTileK = kStepK * BLOCK_SIZE;
auto tileIdM = static_cast<SizeType32>(blockIdx.x * TILE_M);
auto tileIdN = static_cast<SizeType32>(blockIdx.y * TILE_N);
auto tid = static_cast<SizeType32>(threadIdx.x);
float tile_a[kStepK], tile_w[TILE_N * kStepK];
float acc[TILE_M * TILE_N];

static_assert(kStepK % 4 == 0);
using CvtInputType = cutlass::float_e4m3_t;
using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 4>;
using CvtSrcType = typename Converter::source_type;
using CvtResType = typename Converter::result_type;
static constexpr SizeType32 kCvtCount =
static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i) {
acc[i] = 0;
}

act += tileIdM * k;
weight += tileIdN * k;
output += tileIdM * n + tileIdN;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif

for (SizeType32 idxK = tid * kStepK; idxK < k; idxK += kTileK) {
for (SizeType32 i = 0; i < TILE_N; ++i) {
auto tile_w_quantized =
reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
}
}
// Remove the pragma unroll that used to be here
for (SizeType32 i = 0; i < TILE_M; ++i) {
auto tile_a_quantized =
reinterpret_cast<VecType const*>(act + i * k + idxK)[0];

#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
}

#pragma unroll
for (SizeType32 j = 0; j < TILE_N; ++j) {
#pragma unroll
for (SizeType32 l = 0; l < kStepK; ++l) {
acc[i * TILE_N + j] =
fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
}
}
}
}

typedef cub::WarpReduce<float> WarpReduce;

static constexpr SizeType32 kWarpSize = 32;
static constexpr SizeType32 kWarpNum = BLOCK_SIZE / kWarpSize;
SizeType32 warpId = tid / kWarpSize, laneId = tid % kWarpSize;
__shared__ float shmem[TILE_M * TILE_N * kWarpNum];
__shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
for (SizeType32 mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
for (SizeType32 ni = 0; ni < TILE_N; ++ni) {
float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
if (laneId == 0) {
shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
}
}
}

__syncthreads();

for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
SizeType32 mid = ii / TILE_N, nid = ii % TILE_N;
float val = 0;
#pragma unroll
for (SizeType32 jj = 0; jj < kWarpNum; ++jj) {
val += shmem[jj * TILE_M * TILE_N + ii];
}
output[mid * n + nid] = static_cast<OutputType>(val * *alpha);
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <
typename InputType,
typename OutputType,
SizeType32 TILE_M,
SizeType32 TILE_N,
SizeType32 BLOCK_SIZE>
void cudaCoreGemmKernel(Params const& params, cudaStream_t stream) {
dim3 block(BLOCK_SIZE);
dim3 grid(params.m / TILE_M, params.n / TILE_N);

cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>
<<<grid, block, 0, stream>>>(
reinterpret_cast<InputType const*>(params.act),
reinterpret_cast<InputType const*>(params.weight),
reinterpret_cast<float const*>(params.alpha),
reinterpret_cast<OutputType*>(params.output),
params.m,
params.n,
params.k);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <
typename InputType,
typename OutputType,
int TILE_M,
int TILE_N,
int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(Params const& params, cudaStream_t stream) {
constexpr int cudaCoreGemmTemplateMaxM = 128;
if (params.m == TILE_M) {
cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
params, stream);
return true;
}
if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
return cudaCoreGemmTemplateCaller<
InputType,
OutputType,
TILE_M + 1,
TILE_N,
BLOCK_SIZE>(params, stream);
}
return false;
}

template <typename InputType, typename OutputType>
bool cudaCoreGemmLauncher(Params const& params, cudaStream_t stream) {
return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 128>(
params, stream);
}

at::Tensor f8f8bf16_lite(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
bool dispatched = true;
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
TORCH_CHECK(XQ.size(-1) == K);

if (M > 128) {
throw std::runtime_error("f8f8bf16_lite cannot run when M > 128");
} else if (N % 2 != 0) {
throw std::runtime_error("f8f8bf16_lite cannot run when N % 2 != 0");
} else if (K % 16 != 0) {
throw std::runtime_error("f8f8bf16_lite cannot run when K % 16 != 0");
}

auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
at::Tensor Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));

Params params{
XQ.data_ptr(),
WQ.data_ptr(),
scale.data_ptr(),
Y.data_ptr(),
(SizeType32)M,
(SizeType32)N,
(SizeType32)K};
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dispatched = cudaCoreGemmLauncher<cutlass::float_e4m3_t, __nv_bfloat16>(
params, stream);
if (!dispatched) {
throw std::runtime_error("f8f8bf16_lite cannot run");
}
return Y;
}

#else

at::Tensor f8f8bf16_lite(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
12 changes: 12 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ at::Tensor f8f8bf16_tensorwise(
at::Tensor WQ,
double scale,
bool use_fast_accum = true);
at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
std::vector<at::Tensor> f8f8bf16_grouped(
at::TensorList XQ,
at::TensorList WQ,
Expand Down Expand Up @@ -187,6 +188,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
"f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
m.def(
"bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
Expand Down Expand Up @@ -268,6 +270,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
m.impl("f8f8bf16_lite", f8f8bf16_lite);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
Expand Down Expand Up @@ -295,6 +298,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
m.impl("f8f8bf16_lite", f8f8bf16_lite);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
Expand Down Expand Up @@ -415,6 +419,13 @@ at::Tensor f8f8bf16_tensorwise_meta(
return Y;
}

at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor f8i4bf16_rowwise_meta(
at::Tensor XQ, // FP8
at::Tensor WQ, // INT4
Expand Down Expand Up @@ -533,6 +544,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
#endif
}

Expand Down
24 changes: 24 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,30 @@ def test_quantize_zero_input(self, K) -> None:
torch.testing.assert_close(w.shape, wq.shape)
torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)

@unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
@settings(deadline=None)
@given(
M=st.sampled_from([1, 5, 16]),
N=st.sampled_from([1024, 6144]),
K=st.sampled_from([512, 3584]),
CudaGraph=st.sampled_from([True, False]),
)
def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
if CudaGraph:
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
g.replay()
else:
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
zq_ref = (x @ w.T).to(torch.bfloat16)
torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)


if __name__ == "__main__":
unittest.main()
Loading