diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index 6703d2771..7eeff8a8c 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -16,6 +16,16 @@
 import seaborn as sns
 
 import torch
 
+try:
+    from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
+except ImportError:
+    from contextlib import nullcontext
+
+    class profiler_or_nullcontext(nullcontext):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+
+
 from .quantize_ops import get_quantize_ops, QuantizeOpBase
 
@@ -96,6 +106,7 @@ def benchmark_grouped(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     num_groups = len(m)
     # Create input tensors.
@@ -143,19 +154,21 @@ def benchmark_grouped(
     # Now perform benchmark.
     if bench_quantize:
         # Benchmark both quantize and compute.
-        ms_runtime = quantize_op.benchmark(
-            *preprocessed_args,
-            bench_quantize=True,
-            use_rotating_buffer_bench=use_rotating_buffer_bench,
-            use_cuda_graph=use_cuda_graph,
-        )
+        with profiler_or_nullcontext(enabled=trace, with_stack=True):
+            ms_runtime = quantize_op.benchmark(
+                *preprocessed_args,
+                bench_quantize=True,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
+            )
     else:
-        ms_runtime = quantize_op.benchmark(
-            *quantized_vals,
-            bench_quantize=False,
-            use_rotating_buffer_bench=use_rotating_buffer_bench,
-            use_cuda_graph=use_cuda_graph,
-        )
+        with profiler_or_nullcontext(enabled=trace, with_stack=True):
+            ms_runtime = quantize_op.benchmark(
+                *quantized_vals,
+                bench_quantize=False,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
+            )
 
     # Print out results for this op.
     tflops = 0
@@ -197,6 +210,7 @@ def benchmark(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     # Create input tensors.
     if b > 1:
@@ -230,19 +244,21 @@ def benchmark(
     # Now perform benchmark.
     if bench_quantize:
         # Benchmark both quantize and compute.
-        ms_runtime = quantize_op.benchmark(
-            *preprocessed_args,
-            bench_quantize=True,
-            use_rotating_buffer_bench=use_rotating_buffer_bench,
-            use_cuda_graph=use_cuda_graph,
-        )
+        with profiler_or_nullcontext(enabled=trace, with_stack=True):
+            ms_runtime = quantize_op.benchmark(
+                *preprocessed_args,
+                bench_quantize=True,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
+            )
     else:
-        ms_runtime = quantize_op.benchmark(
-            *quantized_vals,
-            bench_quantize=False,
-            use_rotating_buffer_bench=use_rotating_buffer_bench,
-            use_cuda_graph=use_cuda_graph,
-        )
+        with profiler_or_nullcontext(enabled=trace, with_stack=True):
+            ms_runtime = quantize_op.benchmark(
+                *quantized_vals,
+                bench_quantize=False,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
+            )
 
     # Print out results for this op.
     tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
@@ -370,6 +386,7 @@ def main(args: Any):
                 args.bench_quantize,
                 args.use_rotating_buffer_bench,
                 not args.no_cuda_graph,
+                args.trace,
             )
             benchmark_results.append(quantize_measurements)
    if args.export_csv:
@@ -460,6 +477,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to ldm workloads.",
     )
+    parser.add_argument(
+        "--trace",
+        default=False,
+        action="store_true",
+        help="If set, produce a performance trace of the benchmark.",
+    )
 
     args = parser.parse_args()
     main(args)
diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index 7759de670..22ae34226 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -411,20 +411,26 @@ class FP8RowwiseGemm(QuantizeOpBase):
     FP8 matmul with rowwise scaling.
     """
 
-    def quantize(self, x, w):
+    def preprocess(self, x, w):
+        # Prequantize weights.
+        if isinstance(w, (list, tuple)):
+            wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
+        else:
+            wq, w_scale = quantize_fp8_row(w)
+            if wq.dim() == 3:
+                w_scale = w_scale.view(wq.size(0), -1)
+        return x, wq, w_scale
+
+    def quantize(self, x, wq, w_scale):
         # Quantize both input tensors.
         # Handle both grouped and standard gemm.
         if isinstance(x, (list, tuple)):
             xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
-            wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
         else:
             xq, x_scale = quantize_fp8_row(x)
-            wq, w_scale = quantize_fp8_row(w)
             # Set proper batch dimension shapes.
             if xq.dim() == 3:
                 x_scale = x_scale.view(xq.size(0), -1)
-            if wq.dim() == 3:
-                w_scale = w_scale.view(wq.size(0), -1)
         return xq, wq, x_scale, w_scale
 
     def compute(self, xq, wq, x_scale, w_scale):
@@ -451,8 +457,8 @@ def compute(self, xq, wq, x_scale, w_scale):
             # Otherwise return normal gemm result.
             return torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
+    def quantize_and_compute(self, x, wq, w_scale):
+        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
         return self.compute(xq, wq, x_scale, w_scale)
 
     @property
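For context, a minimal standalone sketch of the behavior the fallback import above provides in environments where accelerators.pytorch.lib is not importable: profiler_or_nullcontext degrades to a no-op context manager, so the benchmark call sites stay identical whether or not the internal profiler helper is available. The run_once helper below is a hypothetical stand-in for quantize_op.benchmark and is not part of this patch.

    from contextlib import nullcontext

    class profiler_or_nullcontext(nullcontext):
        # Fallback used when the internal torch_profiler helper is unavailable:
        # accept and ignore the profiler arguments so callers need no changes.
        def __init__(self, *args, **kwargs):
            super().__init__()

    def run_once(trace: bool = False) -> None:
        # Hypothetical stand-in for quantize_op.benchmark(...). With trace=False
        # (the default) the surrounding context manager does nothing.
        with profiler_or_nullcontext(enabled=trace, with_stack=True):
            pass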