diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index 91cea3b158..6703d2771c 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -127,7 +127,8 @@ def benchmark_grouped(
         # Also check if the operator is supported.
         if kernel_requested and quantize_op.supported:
             # Get the quantized tensors for this operator.
-            quantized_vals = quantize_op.quantize(A, B)
+            preprocessed_args = quantize_op.preprocess(A, B)
+            quantized_vals = quantize_op.quantize(*preprocessed_args)
             # Compute the output given quantized values.
             output = quantize_op.compute(*quantized_vals)
             # Some kernels may pad output, just take the first m values of each row.
@@ -143,8 +144,7 @@ def benchmark_grouped(
             if bench_quantize:
                 # Benchmark both quantize and compute.
                 ms_runtime = quantize_op.benchmark(
-                    A,
-                    B,
+                    *preprocessed_args,
                     bench_quantize=True,
                     use_rotating_buffer_bench=use_rotating_buffer_bench,
                     use_cuda_graph=use_cuda_graph,
@@ -218,8 +218,10 @@ def benchmark(
         )
         # Also check if the operator is supported.
         if kernel_requested and quantize_op.supported:
+            # Preprocess data if needed.
+            preprocessed_args = quantize_op.preprocess(A, B)
             # Get the quantized tensors for this operator.
-            quantized_vals = quantize_op.quantize(A, B)
+            quantized_vals = quantize_op.quantize(*preprocessed_args)
             # Compute the output given quantized values.
             output = quantize_op.compute(*quantized_vals)
             # Compare the quantize op output to reference as a sanity check.
@@ -229,8 +231,7 @@ def benchmark(
             if bench_quantize:
                 # Benchmark both quantize and compute.
                 ms_runtime = quantize_op.benchmark(
-                    A,
-                    B,
+                    *preprocessed_args,
                     bench_quantize=True,
                     use_rotating_buffer_bench=use_rotating_buffer_bench,
                     use_cuda_graph=use_cuda_graph,
diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index e867de8d2f..7759de6704 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -65,6 +65,10 @@ def quantize_and_compute(self, *args, **kwargs):
         """Function which quantizes inputs and performs main compute operation."""
         pass
 
+    def preprocess(self, *args):
+        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
+        return args
+
     def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True):
         import copy
         import pickle
@@ -113,8 +117,13 @@ def benchmark(
     ) -> float:
         """Benchmark runtime of this operator."""
         if bench_quantize:
-            with torch.cuda.stream(torch.cuda.Stream()):
-                t = triton.testing.do_bench_cudagraph(
+            if use_cuda_graph:
+                with torch.cuda.stream(torch.cuda.Stream()):
+                    t = triton.testing.do_bench_cudagraph(
+                        lambda: self.quantize_and_compute(*args, **kwargs)
+                    )
+            else:
+                t = triton.testing.do_bench(
                     lambda: self.quantize_and_compute(*args, **kwargs)
                 )
         else:
@@ -468,57 +477,52 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase):
     FP8 grouped matmul with rowwise scaling.
    """
 
-    def quantize_fixed_nk(self, x, w):
-        group_size = len(x)
-        m_values = [i.shape[0] for i in x]
-        # Inputs for fixed nk mode must be contiguous, however in the benchmark
-        # script they typically are not. Do a little special processing to make them
-        # work. In practice this wont be needed.
-        # Start by padding along m dimension with zeros.
-        max_m = max(m_values)
-        xq = [
-            torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
-            for i in x
-        ]
-        # Stack inputs into groups.
-        xq = torch.stack(xq).contiguous()
-        wq = torch.stack(w).contiguous()
-        # Apply quantization.
-        xq, x_scale = quantize_fp8_row(xq)
-        wq, w_scale = quantize_fp8_row(wq)
-        # View these unified tensors as lists of tensors.
-        xq = [x.squeeze() for x in xq.split(1, dim=0)]
-        wq = [w.squeeze() for w in wq.split(1, dim=0)]
-        x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)]
-        w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)]
-
-        # Return processed tensors.
-        return (
-            xq,
-            wq,
-            x_scale,
-            w_scale,
-            torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device),
-        )
-
-    def quantize(self, x, w):
-        assert isinstance(
-            x, (list, tuple)
-        ), "Inputs to group gemm must be a list of tensors."
-
+    def preprocess(self, x, w):
+        # Apply sparsity to inputs if appropriate.
         # First check if N and K are fixed.
         m_values = [i.shape[0] for i in x]
         n_values = [i.shape[0] for i in w]
         k_values = [i.shape[1] for i in w]
-        # if so, do specialized version of initialization.
+        # If so, do specialized version of initialization.
         if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
-            return self.quantize_fixed_nk(x, w)
-
-        # Otherwise handle in eager mode.
-        xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
+            m_values = [i.shape[0] for i in x]
+            # Inputs for fixed nk mode must be contiguous, however in the benchmark
+            # script they typically are not. Do a little special processing to make them
+            # work. In practice this wont be needed.
+            # Start by padding along m dimension with zeros.
+            max_m = max(m_values)
+            x = [
+                torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
+                for i in x
+            ]
+            # Stack inputs into groups.
+            x = torch.stack(x).contiguous()
+            w = torch.stack(w).contiguous()
+
+            # Preapply weight quantization.
+            wq, w_scale = quantize_fp8_row(w)
+            # Return processed tensors.
+            return (
+                x,
+                wq,
+                w_scale,
+                torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
+            )
+        # Otherwise run without sparsity.
         wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
-        m_values = None
-        return xq, wq, x_scale, w_scale, m_values
+        return x, wq, w_scale, None
+
+    def quantize(self, x, wq, w_scale, m_values=None):
+        # Handle case where inputs are explicitly grouped and non-sparse.
+        if isinstance(x, (tuple, list)):
+            xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
+            return xq, wq, x_scale, w_scale, m_values
+        # Otherwise inputs are unified tensors and sparse.
+        else:
+            B = x.shape[0]
+            xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)
+            x_scale = x_scale.view(B, -1)
+            return xq, wq, x_scale, w_scale, m_values
 
     def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
         if m_values is None:
@@ -530,17 +534,23 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
                 kernel_name=kernel_name,
             )
         else:
+            # Break tensor into groups, simulates what is done e2e.
+            B = xq.shape[0]
+            xq_group = [xq[i, :, :] for i in range(B)]
+            x_scale_group = [x_scale[i, :] for i in range(B)]
+            wq_group = [wq[i, :, :] for i in range(B)]
+            w_scale_group = [w_scale[i, :] for i in range(B)]
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                xq,
-                wq,
-                x_scale,
-                w_scale,
+                xq_group,
+                wq_group,
+                x_scale_group,
+                w_scale_group,
                 zero_start_index_M=m_values,
                 kernel_name=kernel_name,
             )
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, m_values = self.quantize(x, w)
+    def quantize_and_compute(self, x, wq, w_scale, m_values=None):
+        xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
         return self.compute(xq, wq, x_scale, w_scale, m_values)
 
     @property
@@ -565,55 +575,48 @@ class BF16GroupedGemm(QuantizeOpBase):
     BF16 grouped matmul implemented with CK or Cutlass.
     """
 
-    def quantize_fixed_nk(self, x, w):
-        m_values = [i.shape[0] for i in x]
-        # Inputs for fixed nk mode must be contiguous, however in the benchmark
-        # script they typically are not. Do a little special processing to make them
-        # work. In practice this wont be needed.
-        # Start by padding along m dimension with zeros.
-        max_m = max(m_values)
-        xp = [
-            torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
-            for i in x
-        ]
-        # Stack inputs into groups.
-        x = torch.stack(xp).contiguous()
-        w = torch.stack(w).contiguous()
-        # View these unified tensors as lists of tensors.
-        x = [xi.squeeze() for xi in x.split(1, dim=0)]
-        w = [wi.squeeze() for wi in w.split(1, dim=0)]
-
-        # Return processed tensors.
-        return (
-            x,
-            w,
-            torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
-        )
-
-    def quantize(self, x, w):
-        assert isinstance(
-            x, (list, tuple)
-        ), "Inputs to group gemm must be a list of tensors."
-
+    def preprocess(self, x, w):
+        # Apply sparsity to inputs if appropriate.
         # First check if N and K are fixed.
         m_values = [i.shape[0] for i in x]
         n_values = [i.shape[0] for i in w]
         k_values = [i.shape[1] for i in w]
-        # if so, do specialized version of initialization.
+        # If so, do specialized version of initialization.
         if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
-            return self.quantize_fixed_nk(x, w)
+            m_values = [i.shape[0] for i in x]
+            # Inputs for fixed nk mode must be contiguous, however in the benchmark
+            # script they typically are not. Do a little special processing to make them
+            # work. In practice this wont be needed.
+            # Start by padding along m dimension with zeros.
+            max_m = max(m_values)
+            x = [
+                torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
+                for i in x
+            ]
+            # Stack inputs into groups.
+            x = torch.stack(x).contiguous()
+            w = torch.stack(w).contiguous()
+            return (
+                x,
+                w,
+                torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
+            )
+        return x, w, None
 
-        m_values = None
+    def quantize(self, x, w, m_values=None):
+        # No action required.
         return x, w, m_values
 
     def compute(self, x, w, m_values):
         if m_values is None:
             return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
         else:
+            B = x.shape[0]
+            x = [x[i, :, :] for i in range(B)]
+            w = [w[i, :, :] for i in range(B)]
             return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)
 
-    def quantize_and_compute(self, x, w):
-        x, w, m_values = self.quantize(x, w)
+    def quantize_and_compute(self, x, w, m_values):
        return self.compute(x, w, m_values)
 
     @property
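
Note (not part of the patch): a minimal sketch of the call sequence the benchmark harness follows after this change. The names op, A, and B are hypothetical stand-ins for a QuantizeOpBase subclass instance and its grouped activation/weight tensor lists; only the method and keyword names come from the patch.

# Hypothetical usage sketch of the new preprocess -> quantize -> compute flow.
preprocessed_args = op.preprocess(A, B)            # one-time padding/stacking/weight quantization
quantized_vals = op.quantize(*preprocessed_args)   # per-call activation quantization
output = op.compute(*quantized_vals)               # grouped GEMM kernel
# When timing quantize + compute together, the preprocessed args are passed
# straight to benchmark(), which calls quantize_and_compute internally.
ms_runtime = op.benchmark(
    *preprocessed_args,
    bench_quantize=True,
    use_rotating_buffer_bench=False,
    use_cuda_graph=True,
)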
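
As a small, self-contained illustration of the fixed-NK padding and stacking that the new preprocess methods perform (toy CPU shapes, no FP8 quantization; values are arbitrary):

import torch

# Three groups with M = 3, 1, 2 and a shared K = 4 are padded to max_m and
# stacked into one (G, max_m, K) tensor; m_values keeps the true row counts so
# the grouped kernels (via zero_start_index_M) can ignore the padded rows.
x = [torch.randn(m, 4) for m in (3, 1, 2)]
m_values = [i.shape[0] for i in x]
max_m = max(m_values)
x = [torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) for i in x]
x = torch.stack(x).contiguous()
m_values = torch.tensor(m_values, dtype=torch.int64)
print(x.shape, m_values)  # torch.Size([3, 3, 4]) tensor([3, 1, 2])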
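
The benchmark() change also adds a plain triton.testing.do_bench fallback when CUDA graphs are disabled. A minimal sketch of that pattern, assuming fn is the quantize_and_compute closure and time_ms is a hypothetical helper:

import torch
import triton

def time_ms(fn, use_cuda_graph: bool) -> float:
    if use_cuda_graph:
        # Graph capture cannot run on the default stream, hence the side stream.
        with torch.cuda.stream(torch.cuda.Stream()):
            return triton.testing.do_bench_cudagraph(fn)
    # Regular timing loop without CUDA graph capture.
    return triton.testing.do_bench(fn)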