Skip to content

Commit 606449f

Browse files
jwfromm authored and facebook-github-bot committed
FP8 Grouped Gemm Optimization
Summary: X-link: facebookresearch/FBGEMM#731 While optimizing MOE, we found that small overheads were a major bottleneck for grouped gemm performance. This diff tackles a few of them, specifically overhead from torch.dynamo wrapping `quantize_fp8_row` and having to slice input tensors before calling `f8f8bf16_rowwise_grouped`. To fix the former, we enable `triton_quantize_fp8_row` to be directly called, skipping dynamo compatibility. In cases where AOTI isnt needed, this removes a bit of overhead. To fix the latter, we templatize f8f8fbf16_rowwise_grouped_dynamic to accept at::Tensor instead of lists. We introduce a new wrapper called f8f8bf16_rowwise_grouped_stacked to maintain the behavior where zero_start_index_M isnt provided but a user wants a single contiguous output tensor. In microbenchmarks, we've found these seemingly small changes can improve TFLOPs by 2X for small workloads. Reviewed By: jiawenliu64 Differential Revision: D69072529
1 parent 00c43b4 commit 606449f

File tree

5 files changed

+200
-107
lines changed

5 files changed

+200
-107
lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row(
24472447
torch.Tensor: fp8 scaled tensor.
24482448
torch.Tensor: reciprocal scale tensor per row.
24492449
"""
2450+
assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
2451+
a_shape = a.shape
2452+
while a.dim() < 4:
2453+
a = a.unsqueeze(0)
2454+
if zero_start_index_M is not None:
2455+
# There should be one value of zero_start_index_M per NxK matrix.
2456+
zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
24502457
# Get constant values.
24512458
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
24522459
num_rows = a.numel() // a.shape[-1]
@@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row(
24842491
USE_INT64=use_int64,
24852492
)
24862493

2487-
return a_fp8, a_scale
2494+
return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
24882495

24892496

24902497
@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2514,17 +2521,7 @@ def quantize_fp8_row(
25142521
logger.info("Triton does not support cpu, falling back to torch ops.")
25152522
use_triton = False
25162523
if use_triton:
2517-
assert (
2518-
a.dim() <= 4
2519-
), "Only up to 4 dimension input tensor is supported if use_triton is True"
2520-
a_shape = a.shape
2521-
while a.dim() < 4:
2522-
a = a.unsqueeze(0)
2523-
if zero_start_index_M is not None:
2524-
# There should be one value of zero_start_index_M per NxK matrix.
2525-
zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
2526-
a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
2527-
return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
2524+
return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
25282525
# else use pytorch implementation.
25292526
if not output_device:
25302527
output_device = a.device

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -541,16 +541,11 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
541541
)
542542
else:
543543
# Break tensor into groups, simulates what is done e2e.
544-
B = xq.shape[0]
545-
xq_group = [xq[i, :, :] for i in range(B)]
546-
x_scale_group = [x_scale[i, :] for i in range(B)]
547-
wq_group = [wq[i, :, :] for i in range(B)]
548-
w_scale_group = [w_scale[i, :] for i in range(B)]
549544
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
550-
xq_group,
551-
wq_group,
552-
x_scale_group,
553-
w_scale_group,
545+
xq,
546+
wq,
547+
x_scale,
548+
w_scale,
554549
zero_start_index_M=m_values,
555550
kernel_name=kernel_name,
556551
)

0 commit comments

Comments (0)