Update bf16i4 gemm with new cutlass version
Summary: Repro for cutlass team showing accuracy loss in mixed input gemm.

Differential Revision: D68693984
jwfromm authored and facebook-github-bot committed Jan 29, 2025
1 parent 9059770 commit d640504
Showing 2 changed files with 95 additions and 103 deletions.
11 changes: 8 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -972,8 +972,8 @@ def _int4_row_quantize(

# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
-scales = scales.view(x.shape[0], -1).t().contiguous()
-zeros = zeros.view(x.shape[0], -1).t().contiguous()
+scales = scales.view(x.shape[0], -1)
+zeros = zeros.view(x.shape[0], -1)

return out, scales, zeros
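For context, a minimal sketch of the `_int4_row_quantize` helper this hunk edits, reconstructed around the changed lines; the min/max quantization details and the default `group_size` are assumptions, not the verbatim source:

    import torch

    def _int4_row_quantize(x: torch.Tensor, group_size: int = 128):
        # Assumed scheme: asymmetric int4 quantization with one scale and
        # one zero point per group of `group_size` elements along K.
        n_bit = 4
        to_quant = x.reshape(-1, group_size).to(torch.float)
        max_val = to_quant.amax(dim=1, keepdim=True)
        min_val = to_quant.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        scales = (max_val - min_val).clamp(min=1e-6) / max_int
        zeros = min_val + scales * (2 ** (n_bit - 1))
        out = to_quant.sub(min_val).div(scales).round().clamp_(0, max_int)
        out = out.to(torch.uint8).reshape(x.shape)
        # With this change, scales and zeros stay row major: one row per
        # output channel, one column per K-group.
        scales = scales.view(x.shape[0], -1)
        zeros = zeros.view(x.shape[0], -1)
        return out, scales, zeros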

@@ -1030,7 +1030,12 @@ def quantize(self, x, w):
wq, w_scale, w_zp = self._int4_row_quantize(w)
# Pack int4 values together.
wq = self._pack_int4(wq)
-return x.to(torch.bfloat16), wq, w_scale, w_zp
+return (
+    x.to(torch.bfloat16),
+    wq,
+    w_scale.to(torch.bfloat16),
+    w_zp.to(torch.bfloat16),
+)

def compute(self, x, wq, w_scale, w_zp):
return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)
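An end-to-end sketch of how this quantize/compute pair is exercised after the change; the wrapper class name `BF16I4RowwiseGemm` and the tensor shapes are illustrative assumptions:

    import torch

    M, N, K = 128, 256, 512
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

    op = BF16I4RowwiseGemm()  # hypothetical name for the op wrapper edited above
    xq, wq, w_scale, w_zp = op.quantize(x, w)
    # Scales and zero points are now cast to bfloat16 here, matching the
    # tightened dtype check in the C++ entry point below.
    assert w_scale.dtype == torch.bfloat16 and w_zp.dtype == torch.bfloat16
    y = op.compute(xq, wq, w_scale, w_zp)  # torch.ops.fbgemm.bf16i4bf16_rowwise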
@@ -32,8 +32,7 @@ template <
int TBS_M,
int TBS_N,
int TBS_K,
-bool PONG,
-typename WEIGHT_SCALE_DTYPE>
+bool PONG>
at::Tensor bf16i4bf16_rowwise_impl(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
@@ -42,42 +41,54 @@ at::Tensor bf16i4bf16_rowwise_impl(
int M = X.size(0);
int N = WQ.size(0);
int K = X.size(1);

-int num_groups = w_scale.size(0);
+int scale_k = w_scale.size(1);

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous());
TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous());
-TORCH_CHECK(K >= num_groups && K % num_groups == 0);
+TORCH_CHECK(K >= scale_k && K % scale_k == 0);

-int group_size = K / num_groups;
+int group_size = K / scale_k;

auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16));

-using ElementInputA = cutlass::bfloat16_t;
-using LayoutInputA = cutlass::layout::ColumnMajor;
-constexpr int AlignmentInputA =
+using MmaType = cutlass::bfloat16_t;
+using QuantType = cutlass::int4b_t;
+// TODO Is this really needed?
+constexpr int TileShapeK = 128 * 8 / cutlass::sizeof_bits<MmaType>::value;
+
+using ElementA = MmaType;
+using LayoutA = cutlass::layout::RowMajor;
+constexpr int AlignmentA =
128 /
cutlass::sizeof_bits<
-ElementInputA>::value; // Memory access granularity/alignment of A
-// matrix in units of elements (up to 16 bytes)
+ElementA>::value; // Memory access granularity/alignment of A
+// matrix in units of elements (up to 16 bytes)

-using ElementInputB = cutlass::int4b_t;
-using LayoutInputB = cutlass::layout::RowMajor;
-constexpr int AlignmentInputB =
+using ElementB = QuantType;
+using LayoutB = cutlass::layout::ColumnMajor;
+constexpr int AlignmentB =
128 /
cutlass::sizeof_bits<
-ElementInputB>::value; // Memory access granularity/alignment of B
-// matrix in units of elements (up to 16 bytes)
+ElementB>::value; // Memory access granularity/alignment of B
+// matrix in units of elements (up to 16 bytes)

-using ElementScale = WEIGHT_SCALE_DTYPE;
-using ElementZeroPoint = WEIGHT_SCALE_DTYPE;
-using ElementComputeEpilogue = float;
+// We transpose and swap inputs.
+using LayoutA_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+using LayoutB_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+
+using LayoutScale = cutlass::layout::RowMajor;
+
+using ElementScale = MmaType;
+using ElementZero = MmaType;
+using ElementCompute = float;
using ElementAccumulator = float;

using ElementOutput = cutlass::bfloat16_t;
-using LayoutOutput = cutlass::layout::ColumnMajor;
+using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
@@ -90,20 +101,25 @@ at::Tensor bf16i4bf16_rowwise_impl(
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
-cute::Int<TB_K>>; // Threadblock-level
-// tile size
+cute::Int<TileShapeK>>; // Threadblock-level
+// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
-using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
+using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+using DefaultEpiSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+using PongEpiSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
-using MainLoopSchedule =
+using KernelSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+// TODO Possible that only cooperative schedule works.
+using EpilogueSchedule = DefaultEpiSchedule;
+// using EpilogueSchedule =
+//     cute::conditional_t<PONG, PongEpiSchedule, DefaultEpiSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -114,67 +130,73 @@ at::Tensor bf16i4bf16_rowwise_impl(
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
+// Transpose layout of D here since we use explicit swap + transpose
+// the void type for C tells the builder to allocate 0 smem for the C
+// matrix. We can enable this if beta == 0 by changing ElementC to
+// void below.
ElementOutput,
-LayoutOutput,
+typename cutlass::layout::LayoutTranspose<LayoutOutput>::type,
AlignmentOutput,
ElementOutput,
-LayoutOutput,
+typename cutlass::layout::LayoutTranspose<LayoutOutput>::type,
AlignmentOutput,
-EpilogueSchedule>::CollectiveOp;
+EpilogueSchedule // This is the only epi supporting the required swap
+// + transpose.
+>::CollectiveOp;

-using CollectiveMainloop =
+using CollectiveMainloopScaleWithZeroPoint =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
-cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>,
-LayoutInputB,
-AlignmentInputB,
-ElementInputA,
-LayoutInputA,
-AlignmentInputA,
+cute::tuple<ElementB, ElementScale, ElementZero>,
+LayoutB_Transpose,
+AlignmentB,
+ElementA,
+LayoutA_Transpose,
+AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
-MainLoopSchedule>::CollectiveOp;
+KernelSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-cute::Shape<int, int, int>,
-CollectiveMainloop,
+cute::Shape<int, int, int, int>, // Indicates ProblemShape
+CollectiveMainloopScaleWithZeroPoint,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

-using StrideInputA = typename Gemm::GemmKernel::StrideA;
-using StrideInputB = typename Gemm::GemmKernel::StrideB;
+using StrideA = typename Gemm::GemmKernel::StrideA;
+using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
-using StrideS = typename CollectiveMainloop::StrideScale;
+using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale;

-StrideInputA stride_a = cutlass::make_cute_packed_stride(
-    StrideInputA{}, cute::make_shape(M, K, 1));
-StrideInputB stride_b = cutlass::make_cute_packed_stride(
-    StrideInputB{}, cute::make_shape(N, K, 1));
+StrideA stride_A =
+    cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
+StrideB stride_B =
+    cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, 1));
StrideS stride_S = cutlass::make_cute_packed_stride(
-StrideS{}, cute::make_shape(N, num_groups, 1));
+StrideS{}, cute::make_shape(N, scale_k, 1));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
-{N, M, K},
-{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
-stride_b,
-reinterpret_cast<ElementInputA*>(X.data_ptr()),
-stride_a,
+{N, M, K, 1},
+{reinterpret_cast<ElementB*>(WQ.data_ptr()),
+stride_B,
+reinterpret_cast<ElementA*>(X.data_ptr()),
+stride_A,
reinterpret_cast<ElementScale*>(w_scale.data_ptr()),
stride_S,
group_size,
-reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())},
-{{1.0, 0.0},
-(ElementOutput*)Y.data_ptr<at::BFloat16>(),
+reinterpret_cast<ElementZero*>(w_zp.data_ptr())},
+{{},
+reinterpret_cast<ElementOutput*>(Y.data_ptr()),
stride_output,
-(ElementOutput*)Y.data_ptr<at::BFloat16>(),
+reinterpret_cast<ElementOutput*>(Y.data_ptr()),
stride_output}};

Gemm gemm;
@@ -210,43 +232,21 @@ at::Tensor bf16i4bf16_rowwise_impl(
return Y;
}

-template <typename WEIGHT_SCALE_DTYPE>
at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
KernelMode kernel = get_kernel_mode(X, WQ);
if (kernel == KernelMode::Small) {
-return bf16i4bf16_rowwise_impl<
-    64,
-    128,
-    128,
-    2,
-    1,
-    1,
-    true,
-    WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+return bf16i4bf16_rowwise_impl<64, 128, 128, 2, 1, 1, true>(
+    X, WQ, w_scale, w_zp);
} else if (kernel == KernelMode::Large) {
-return bf16i4bf16_rowwise_impl<
-    128,
-    128,
-    128,
-    2,
-    1,
-    1,
-    true,
-    WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+return bf16i4bf16_rowwise_impl<128, 128, 128, 2, 1, 1, true>(
+    X, WQ, w_scale, w_zp);
} else {
-return bf16i4bf16_rowwise_impl<
-    128,
-    128,
-    128,
-    2,
-    1,
-    1,
-    false,
-    WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+return bf16i4bf16_rowwise_impl<128, 128, 128, 2, 1, 1, false>(
+    X, WQ, w_scale, w_zp);
}
}
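The mainloop above feeds the quantized B operand and the BF16 A operand in swapped order with transposed layouts, and writes D transposed, per the "explicit swap + transpose" comments in the kernel. This works because of the identity (W @ X^T)^T = X @ W^T. A small PyTorch check of that identity, as an illustration only:

    import torch

    M, N, K = 64, 96, 128
    x = torch.randn(M, K)  # activations, row major
    w = torch.randn(N, K)  # (dequantized) weights, row major

    y_direct = x @ w.t()            # Y[M, N] = X @ W^T
    y_swapped = (w @ x.t()).t()     # swapped operands, transposed result

    torch.testing.assert_close(y_direct, y_swapped)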

@@ -257,23 +257,10 @@ at::Tensor bf16i4bf16_rowwise(
at::Tensor w_zp) {
// Check datatypes.
TORCH_CHECK(
-    (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
-        (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
-        (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
-    "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same .");
+    (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
+    "Weight scale and zero point tensors must be bfloat16 and dtype of weight scale and zero point tensors must be the same.");

-if (w_scale.dtype() == at::kFloat) {
-    return dispatch_bf16i4bf16_rowwise_kernel<float>(X, WQ, w_scale, w_zp);
-} else if (w_scale.dtype() == at::kHalf) {
-    return dispatch_bf16i4bf16_rowwise_kernel<cutlass::half_t>(
-        X, WQ, w_scale, w_zp);
-} else if (w_scale.dtype() == at::kBFloat16) {
-    return dispatch_bf16i4bf16_rowwise_kernel<cutlass::bfloat16_t>(
-        X, WQ, w_scale, w_zp);
-} else {
-    throw std::runtime_error(
-        "Weight scale and zero point data type not supported in bf16i4bf16_rowwise");
-}
+return dispatch_bf16i4bf16_rowwise_kernel(X, WQ, w_scale, w_zp);
}
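Since the summary frames this commit as an accuracy repro, a full-precision reference for what the kernel should compute is useful for comparison. A hedged sketch, assuming the unsigned int4 encoding with an offset of 8 used in the quantizer sketched earlier:

    import torch

    def dequant_ref(wq, w_scale, w_zp, group_size):
        # wq: [N, K] int4 codes in 0..15 stored as uint8 (before packing).
        # w_scale, w_zp: [N, K // group_size], bfloat16.
        n, k = wq.shape
        w = (wq.to(torch.float32) - 8.0).reshape(n, k // group_size, group_size)
        w = w * w_scale.float().unsqueeze(-1) + w_zp.float().unsqueeze(-1)
        return w.reshape(n, k)

    def bf16i4_ref(x, wq, w_scale, w_zp, group_size=128):
        # Reference Y = X @ dequant(WQ)^T accumulated in fp32.
        w = dequant_ref(wq, w_scale, w_zp, group_size)
        return (x.float() @ w.t()).to(torch.bfloat16)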

#else
