FP8 Grouped Gemm Optimization #3655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 1 commit
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (9 additions, 12 deletions)
@@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row(
         torch.Tensor: fp8 scaled tensor.
         torch.Tensor: reciprocal scale tensor per row.
     """
+    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+    a_shape = a.shape
+    while a.dim() < 4:
+        a = a.unsqueeze(0)
+    if zero_start_index_M is not None:
+        # There should be one value of zero_start_index_M per NxK matrix.
+        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
     # Get constant values.
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
@@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row(
         USE_INT64=use_int64,
     )

-    return a_fp8, a_scale
+    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])


 @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2514,17 +2521,7 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        assert (
-            a.dim() <= 4
-        ), "Only up to 4 dimension input tensor is supported if use_triton is True"
-        a_shape = a.shape
-        while a.dim() < 4:
-            a = a.unsqueeze(0)
-        if zero_start_index_M is not None:
-            # There should be one value of zero_start_index_M per NxK matrix.
-            zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
-        a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
-        return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
+        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
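For reference, the shape handling that this change moves into triton_quantize_fp8_row can be sketched in plain PyTorch. This is a minimal illustration of the unsqueeze-to-4D-then-view-back pattern, not the Triton kernel itself; the name rowwise_fp8_reference and its default constants are assumptions for the sketch (the real code takes its dtype, max value, and epsilon from get_fp8_constants()).

import torch

def rowwise_fp8_reference(a: torch.Tensor, max_fp8: float = 448.0, eps: float = 1e-12):
    # Mirror the normalization now done inside triton_quantize_fp8_row:
    # remember the caller's shape, unsqueeze to 4D, quantize per row, view back.
    assert a.dim() <= 4, "only up to 4D inputs are handled"
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    # One scale per row along the last dimension.
    row_max = a.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = max_fp8 / row_max
    a_fp8 = (a * scale).clamp(-max_fp8, max_fp8).to(torch.float8_e4m3fn)
    # Return the quantized tensor and the reciprocal scale per row,
    # both restored to the caller's original shape.
    return a_fp8.view(a_shape), (1.0 / scale).squeeze(-1).view(a_shape[:-1])

Because the dimension padding and the final view now happen inside the quantization routine, callers such as quantize_fp8_row no longer need to reshape inputs or outputs themselves.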
fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py (0 additions, 103 deletions)

This file was deleted.

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (8 additions, 14 deletions)
@@ -19,6 +19,7 @@
     quantize_fp8_block,
     quantize_fp8_row,
     scale_fp8_row,
+    triton_quantize_fp8_row,
 )
 from tinygemm.utils import group_quantize_tensor

@@ -553,38 +554,31 @@ def preprocess(self, x, w):
     def quantize(self, x, wq, w_scale, m_values=None):
         # Handle case where inputs are explicitly grouped and non-sparse.
         if isinstance(x, (tuple, list)):
-            xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
+            xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
             return xq, wq, x_scale, w_scale, m_values
         # Otherwise inputs are unified tensors and sparse.
         else:
             B = x.shape[0]
-            xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)
+            xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
             x_scale = x_scale.view(B, -1)
             return xq, wq, x_scale, w_scale, m_values

-    def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
+    def compute(self, xq, wq, x_scale, w_scale, m_values):
         if m_values is None:
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
                 xq,
                 wq,
                 x_scale,
                 w_scale,
-                kernel_name=kernel_name,
             )
         else:
-            # Break tensor into groups, simulates what is done e2e.
-            B = xq.shape[0]
-            xq_group = [xq[i, :, :] for i in range(B)]
-            x_scale_group = [x_scale[i, :] for i in range(B)]
-            wq_group = [wq[i, :, :] for i in range(B)]
-            w_scale_group = [w_scale[i, :] for i in range(B)]
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                xq_group,
-                wq_group,
-                x_scale_group,
-                w_scale_group,
+                xq,
+                wq,
+                x_scale,
+                w_scale,
                 zero_start_index_M=m_values,
-                kernel_name=kernel_name,
             )

     def quantize_and_compute(self, x, wq, w_scale, m_values=None):
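A rough end-to-end sketch of how the updated benchmark path is exercised: quantize the stacked activations row-wise (zero_start_index_M supplies one valid-row count per NxK matrix), then pass the unified tensors straight to the dynamic grouped kernel, with no per-group list splitting. The shapes, dtypes, and random data below are illustrative, the import path is assumed from the location of fp8_gemm.py, and a CUDA build of fbgemm_gpu with these ops registered is assumed; the call pattern itself follows the diff above.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row

G, M, N, K = 8, 128, 256, 512  # illustrative group count and matrix sizes
x = torch.randn(G, M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)
# One value per NxK matrix: rows at or beyond this index are treated as padding
# (integer dtype assumed for the sketch).
m_values = torch.randint(0, M + 1, (G,), device="cuda", dtype=torch.int32)

# Row-wise FP8 quantization of the stacked tensors.
xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
wq, w_scale = triton_quantize_fp8_row(w)

# Unified tensors go directly to the dynamic grouped GEMM.
y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq, wq, x_scale, w_scale, zero_start_index_M=m_values
)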