
FP8 Grouped Gemm Optimization #3655

Open · wants to merge 3 commits into main
21 changes: 9 additions & 12 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row(
torch.Tensor: fp8 scaled tensor.
torch.Tensor: reciprocal scale tensor per row.
"""
assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
a_shape = a.shape
while a.dim() < 4:
a = a.unsqueeze(0)
if zero_start_index_M is not None:
# There should be one value of zero_start_index_M per NxK matrix.
zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
# Get constant values.
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
num_rows = a.numel() // a.shape[-1]
@@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row(
USE_INT64=use_int64,
)

return a_fp8, a_scale
return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])


@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2514,17 +2521,7 @@ def quantize_fp8_row(
logger.info("Triton does not support cpu, falling back to torch ops.")
use_triton = False
if use_triton:
assert (
a.dim() <= 4
), "Only up to 4 dimension input tensor is supported if use_triton is True"
a_shape = a.shape
while a.dim() < 4:
a = a.unsqueeze(0)
if zero_start_index_M is not None:
# There should be one value of zero_start_index_M per NxK matrix.
zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
# else use pytorch implementation.
if not output_device:
output_device = a.device
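The net effect of these hunks is that the unsqueeze-to-4-D and view-back bookkeeping, previously done by quantize_fp8_row before dispatching to Triton, now lives inside triton_quantize_fp8_row itself, which restores the caller's shape on return; quantize_fp8_row simply forwards its arguments. A minimal usage sketch of the resulting behavior (illustrative only: the import path follows the file path above, an FP8-capable GPU build of fbgemm_gpu is assumed, and the optional scale_ub / zero_start_index_M arguments are left at their defaults):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# 2-D activation tensor; anything up to 4-D is accepted.
a = torch.randn(128, 256, device="cuda")

# Shape handling now happens inside triton_quantize_fp8_row, so the outputs
# come back in the caller's shape: one reciprocal scale per row.
a_fp8, a_scale = quantize_fp8_row(a)
assert a_fp8.shape == (128, 256)
assert a_scale.shape == (128,)

Centralizing the reshape presumably lets other callers, such as the grouped GEMM path this PR targets, invoke triton_quantize_fp8_row directly and still get shape-preserving outputs.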
80 changes: 52 additions & 28 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -16,6 +16,16 @@
import seaborn as sns
import torch

try:
from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
except ImportError:
from contextlib import nullcontext

class profiler_or_nullcontext(nullcontext):
def __init__(self, *args, **kwargs):
super().__init__()


from .quantize_ops import get_quantize_ops, QuantizeOpBase


@@ -96,6 +106,7 @@ def benchmark_grouped(
bench_quantize: bool = False,
use_rotating_buffer_bench: bool = False,
use_cuda_graph: bool = True,
trace: bool = False,
) -> Dict[str, Any]:
num_groups = len(m)
# Create input tensors.
@@ -127,7 +138,8 @@
# Also check if the operator is supported.
if kernel_requested and quantize_op.supported:
# Get the quantized tensors for this operator.
quantized_vals = quantize_op.quantize(A, B)
preprocessed_args = quantize_op.preprocess(A, B)
quantized_vals = quantize_op.quantize(*preprocessed_args)
# Compute the output given quantized values.
output = quantize_op.compute(*quantized_vals)
# Some kernels may pad output, just take the first m values of each row.
@@ -142,20 +154,21 @@
# Now perform benchmark.
if bench_quantize:
# Benchmark both quantize and compute.
ms_runtime = quantize_op.benchmark(
A,
B,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
with profiler_or_nullcontext(enabled=trace, with_stack=True):
ms_runtime = quantize_op.benchmark(
*preprocessed_args,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
else:
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
with profiler_or_nullcontext(enabled=trace, with_stack=True):
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)

# Print out results for this op.
tflops = 0
@@ -197,6 +210,7 @@ def benchmark(
bench_quantize: bool = False,
use_rotating_buffer_bench: bool = False,
use_cuda_graph: bool = True,
trace: bool = False,
) -> Dict[str, Any]:
# Create input tensors.
if b > 1:
@@ -218,8 +232,10 @@
)
# Also check if the operator is supported.
if kernel_requested and quantize_op.supported:
# Preprocess data if needed.
preprocessed_args = quantize_op.preprocess(A, B)
# Get the quantized tensors for this operator.
quantized_vals = quantize_op.quantize(A, B)
quantized_vals = quantize_op.quantize(*preprocessed_args)
# Compute the output given quantized values.
output = quantize_op.compute(*quantized_vals)
# Compare the quantize op output to reference as a sanity check.
@@ -228,20 +244,21 @@
# Now perform benchmark.
if bench_quantize:
# Benchmark both quantize and compute.
ms_runtime = quantize_op.benchmark(
A,
B,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
with profiler_or_nullcontext(enabled=trace, with_stack=True):
ms_runtime = quantize_op.benchmark(
*preprocessed_args,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
else:
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)
with profiler_or_nullcontext(enabled=trace, with_stack=True):
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=use_cuda_graph,
)

# Print out results for this op.
tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
@@ -369,6 +386,7 @@ def main(args: Any):
args.bench_quantize,
args.use_rotating_buffer_bench,
not args.no_cuda_graph,
args.trace,
)
benchmark_results.append(quantize_measurements)
if args.export_csv:
@@ -459,6 +477,12 @@ def invoke_main() -> None:
action="store_true",
help="If set, benchmark using fixed shapes relevant to ldm workloads.",
)
parser.add_argument(
"--trace",
default=False,
action="store_true",
help="If set, produce a performance trace of the benchmark.",
)

args = parser.parse_args()
main(args)
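The new --trace flag is parsed here and threaded through main into benchmark and benchmark_grouped, where it sets enabled= on profiler_or_nullcontext around the benchmarked region. When the internal accelerators.pytorch profiler helper cannot be imported, the fallback defined at the top of the file subclasses contextlib.nullcontext and discards the enabled / with_stack keyword arguments, so the same call site still runs (just without producing a trace) in open-source builds. A hypothetical invocation, assuming the benchmark is launched as a module and the other flags keep their argparse spellings (the exact entry point may differ in your build):

python -m fbgemm_gpu.experimental.gen_ai.bench.quantize_bench --bench_quantize --trace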