From bf69f6cab2a025a1f664a9d66142f609a6907305 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sun, 2 Feb 2025 17:12:54 -0800 Subject: [PATCH] Add preprocess stage to quantize bench operators (#3648) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/724 When benchmarking quantize functions, we'd like the overhead to mimic e2e behavior as closely as possible. For example, weights should be quantized ahead of time. The current design of quantize_bench does not allow this. To accomodate it, I've added a new optional preprocess phase that allows some transformations to be applied independently from benchmarking. Here we use it to prepare data for grouped gemm benchmarks to more accurately capture the e2e behavior. Reviewed By: jiawenliu64 Differential Revision: D68964950 --- .../gen_ai/bench/quantize_bench.py | 13 +- .../experimental/gen_ai/bench/quantize_ops.py | 179 +++++++++--------- 2 files changed, 98 insertions(+), 94 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 91cea3b158..6703d2771c 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -127,7 +127,8 @@ def benchmark_grouped( # Also check if the operator is supported. if kernel_requested and quantize_op.supported: # Get the quantized tensors for this operator. - quantized_vals = quantize_op.quantize(A, B) + preprocessed_args = quantize_op.preprocess(A, B) + quantized_vals = quantize_op.quantize(*preprocessed_args) # Compute the output given quantized values. output = quantize_op.compute(*quantized_vals) # Some kernels may pad output, just take the first m values of each row. @@ -143,8 +144,7 @@ def benchmark_grouped( if bench_quantize: # Benchmark both quantize and compute. ms_runtime = quantize_op.benchmark( - A, - B, + *preprocessed_args, bench_quantize=True, use_rotating_buffer_bench=use_rotating_buffer_bench, use_cuda_graph=use_cuda_graph, @@ -218,8 +218,10 @@ def benchmark( ) # Also check if the operator is supported. if kernel_requested and quantize_op.supported: + # Preprocess data if needed. + preprocessed_args = quantize_op.preprocess(A, B) # Get the quantized tensors for this operator. - quantized_vals = quantize_op.quantize(A, B) + quantized_vals = quantize_op.quantize(*preprocessed_args) # Compute the output given quantized values. output = quantize_op.compute(*quantized_vals) # Compare the quantize op output to reference as a sanity check. @@ -229,8 +231,7 @@ def benchmark( if bench_quantize: # Benchmark both quantize and compute. ms_runtime = quantize_op.benchmark( - A, - B, + *preprocessed_args, bench_quantize=True, use_rotating_buffer_bench=use_rotating_buffer_bench, use_cuda_graph=use_cuda_graph, diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index e867de8d2f..7759de6704 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -65,6 +65,10 @@ def quantize_and_compute(self, *args, **kwargs): """Function which quantizes inputs and performs main compute operation.""" pass + def preprocess(self, *args): + """Preprocess inputs before benchmarking. These outputs will be passed to quantize.""" + return args + def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True): import copy import pickle @@ -113,8 +117,13 @@ def benchmark( ) -> float: """Benchmark runtime of this operator.""" if bench_quantize: - with torch.cuda.stream(torch.cuda.Stream()): - t = triton.testing.do_bench_cudagraph( + if use_cuda_graph: + with torch.cuda.stream(torch.cuda.Stream()): + t = triton.testing.do_bench_cudagraph( + lambda: self.quantize_and_compute(*args, **kwargs) + ) + else: + t = triton.testing.do_bench( lambda: self.quantize_and_compute(*args, **kwargs) ) else: @@ -468,57 +477,52 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase): FP8 grouped matmul with rowwise scaling. """ - def quantize_fixed_nk(self, x, w): - group_size = len(x) - m_values = [i.shape[0] for i in x] - # Inputs for fixed nk mode must be contiguous, however in the benchmark - # script they typically are not. Do a little special processing to make them - # work. In practice this wont be needed. - # Start by padding along m dimension with zeros. - max_m = max(m_values) - xq = [ - torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) - for i in x - ] - # Stack inputs into groups. - xq = torch.stack(xq).contiguous() - wq = torch.stack(w).contiguous() - # Apply quantization. - xq, x_scale = quantize_fp8_row(xq) - wq, w_scale = quantize_fp8_row(wq) - # View these unified tensors as lists of tensors. - xq = [x.squeeze() for x in xq.split(1, dim=0)] - wq = [w.squeeze() for w in wq.split(1, dim=0)] - x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)] - w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)] - - # Return processed tensors. - return ( - xq, - wq, - x_scale, - w_scale, - torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device), - ) - - def quantize(self, x, w): - assert isinstance( - x, (list, tuple) - ), "Inputs to group gemm must be a list of tensors." - + def preprocess(self, x, w): + # Apply sparsity to inputs if appropriate. # First check if N and K are fixed. m_values = [i.shape[0] for i in x] n_values = [i.shape[0] for i in w] k_values = [i.shape[1] for i in w] - # if so, do specialized version of initialization. + # If so, do specialized version of initialization. if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1: - return self.quantize_fixed_nk(x, w) - - # Otherwise handle in eager mode. - xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) + m_values = [i.shape[0] for i in x] + # Inputs for fixed nk mode must be contiguous, however in the benchmark + # script they typically are not. Do a little special processing to make them + # work. In practice this wont be needed. + # Start by padding along m dimension with zeros. + max_m = max(m_values) + x = [ + torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) + for i in x + ] + # Stack inputs into groups. + x = torch.stack(x).contiguous() + w = torch.stack(w).contiguous() + + # Preapply weight quantization. + wq, w_scale = quantize_fp8_row(w) + # Return processed tensors. + return ( + x, + wq, + w_scale, + torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), + ) + # Otherwise run without sparsity. wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) - m_values = None - return xq, wq, x_scale, w_scale, m_values + return x, wq, w_scale, None + + def quantize(self, x, wq, w_scale, m_values=None): + # Handle case where inputs are explicitly grouped and non-sparse. + if isinstance(x, (tuple, list)): + xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) + return xq, wq, x_scale, w_scale, m_values + # Otherwise inputs are unified tensors and sparse. + else: + B = x.shape[0] + xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values) + x_scale = x_scale.view(B, -1) + return xq, wq, x_scale, w_scale, m_values def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): if m_values is None: @@ -530,17 +534,23 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): kernel_name=kernel_name, ) else: + # Break tensor into groups, simulates what is done e2e. + B = xq.shape[0] + xq_group = [xq[i, :, :] for i in range(B)] + x_scale_group = [x_scale[i, :] for i in range(B)] + wq_group = [wq[i, :, :] for i in range(B)] + w_scale_group = [w_scale[i, :] for i in range(B)] return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic( - xq, - wq, - x_scale, - w_scale, + xq_group, + wq_group, + x_scale_group, + w_scale_group, zero_start_index_M=m_values, kernel_name=kernel_name, ) - def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale, m_values = self.quantize(x, w) + def quantize_and_compute(self, x, wq, w_scale, m_values=None): + xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values) return self.compute(xq, wq, x_scale, w_scale, m_values) @property @@ -565,55 +575,48 @@ class BF16GroupedGemm(QuantizeOpBase): BF16 grouped matmul implemented with CK or Cutlass. """ - def quantize_fixed_nk(self, x, w): - m_values = [i.shape[0] for i in x] - # Inputs for fixed nk mode must be contiguous, however in the benchmark - # script they typically are not. Do a little special processing to make them - # work. In practice this wont be needed. - # Start by padding along m dimension with zeros. - max_m = max(m_values) - xp = [ - torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) - for i in x - ] - # Stack inputs into groups. - x = torch.stack(xp).contiguous() - w = torch.stack(w).contiguous() - # View these unified tensors as lists of tensors. - x = [xi.squeeze() for xi in x.split(1, dim=0)] - w = [wi.squeeze() for wi in w.split(1, dim=0)] - - # Return processed tensors. - return ( - x, - w, - torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), - ) - - def quantize(self, x, w): - assert isinstance( - x, (list, tuple) - ), "Inputs to group gemm must be a list of tensors." - + def preprocess(self, x, w): + # Apply sparsity to inputs if appropriate. # First check if N and K are fixed. m_values = [i.shape[0] for i in x] n_values = [i.shape[0] for i in w] k_values = [i.shape[1] for i in w] - # if so, do specialized version of initialization. + # If so, do specialized version of initialization. if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1: - return self.quantize_fixed_nk(x, w) + m_values = [i.shape[0] for i in x] + # Inputs for fixed nk mode must be contiguous, however in the benchmark + # script they typically are not. Do a little special processing to make them + # work. In practice this wont be needed. + # Start by padding along m dimension with zeros. + max_m = max(m_values) + x = [ + torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) + for i in x + ] + # Stack inputs into groups. + x = torch.stack(x).contiguous() + w = torch.stack(w).contiguous() + return ( + x, + w, + torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), + ) + return x, w, None - m_values = None + def quantize(self, x, w, m_values=None): + # No action required. return x, w, m_values def compute(self, x, w, m_values): if m_values is None: return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w) else: + B = x.shape[0] + x = [x[i, :, :] for i in range(B)] + w = [w[i, :, :] for i in range(B)] return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values) - def quantize_and_compute(self, x, w): - x, w, m_values = self.quantize(x, w) + def quantize_and_compute(self, x, w, m_values): return self.compute(x, w, m_values) @property