diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index bf8d52cb1e..4a6111510c 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row( torch.Tensor: fp8 scaled tensor. torch.Tensor: reciprocal scale tensor per row. """ + assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor." + a_shape = a.shape + while a.dim() < 4: + a = a.unsqueeze(0) + if zero_start_index_M is not None: + # There should be one value of zero_start_index_M per NxK matrix. + zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) # Get constant values. pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants() num_rows = a.numel() // a.shape[-1] @@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row( USE_INT64=use_int64, ) - return a_fp8, a_scale + return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=()) @@ -2514,17 +2521,7 @@ def quantize_fp8_row( logger.info("Triton does not support cpu, falling back to torch ops.") use_triton = False if use_triton: - assert ( - a.dim() <= 4 - ), "Only up to 4 dimension input tensor is supported if use_triton is True" - a_shape = a.shape - while a.dim() < 4: - a = a.unsqueeze(0) - if zero_start_index_M is not None: - # There should be one value of zero_start_index_M per NxK matrix. - zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) - a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) - return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) + return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) # else use pytorch implementation. if not output_device: output_device = a.device diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 91cea3b158..7eeff8a8c0 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -16,6 +16,16 @@ import seaborn as sns import torch +try: + from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext +except ImportError: + from contextlib import nullcontext + + class profiler_or_nullcontext(nullcontext): + def __init__(self, *args, **kwargs): + super().__init__() + + from .quantize_ops import get_quantize_ops, QuantizeOpBase @@ -96,6 +106,7 @@ def benchmark_grouped( bench_quantize: bool = False, use_rotating_buffer_bench: bool = False, use_cuda_graph: bool = True, + trace: bool = False, ) -> Dict[str, Any]: num_groups = len(m) # Create input tensors. @@ -127,7 +138,8 @@ def benchmark_grouped( # Also check if the operator is supported. if kernel_requested and quantize_op.supported: # Get the quantized tensors for this operator. - quantized_vals = quantize_op.quantize(A, B) + preprocessed_args = quantize_op.preprocess(A, B) + quantized_vals = quantize_op.quantize(*preprocessed_args) # Compute the output given quantized values. output = quantize_op.compute(*quantized_vals) # Some kernels may pad output, just take the first m values of each row. @@ -142,20 +154,21 @@ def benchmark_grouped( # Now perform benchmark. if bench_quantize: # Benchmark both quantize and compute. 
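# Sketch (not part of the patch; the helper name is hypothetical). With the
# changes above, one-time preprocessing runs outside the timed region and
# tracing is opt-in via the new --trace flag. Condensed flow for a single shape,
# where `quantize_op` is any QuantizeOpBase instance:
def bench_one_shape(quantize_op, A, B, trace=False):
    preprocessed_args = quantize_op.preprocess(A, B)           # untimed, e.g. weight quant
    quantized_vals = quantize_op.quantize(*preprocessed_args)  # activation quantization
    output = quantize_op.compute(*quantized_vals)              # correctness / warmup pass
    with profiler_or_nullcontext(enabled=trace, with_stack=True):
        # Timed region: quantize + gemm together when bench_quantize is set.
        ms_runtime = quantize_op.benchmark(*preprocessed_args, bench_quantize=True)
    return output, ms_runtime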
- ms_runtime = quantize_op.benchmark( - A, - B, - bench_quantize=True, - use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=use_cuda_graph, - ) + with profiler_or_nullcontext(enabled=trace, with_stack=True): + ms_runtime = quantize_op.benchmark( + *preprocessed_args, + bench_quantize=True, + use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, + ) else: - ms_runtime = quantize_op.benchmark( - *quantized_vals, - bench_quantize=False, - use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=use_cuda_graph, - ) + with profiler_or_nullcontext(enabled=trace, with_stack=True): + ms_runtime = quantize_op.benchmark( + *quantized_vals, + bench_quantize=False, + use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, + ) # Print out results for this op. tflops = 0 @@ -197,6 +210,7 @@ def benchmark( bench_quantize: bool = False, use_rotating_buffer_bench: bool = False, use_cuda_graph: bool = True, + trace: bool = False, ) -> Dict[str, Any]: # Create input tensors. if b > 1: @@ -218,8 +232,10 @@ def benchmark( ) # Also check if the operator is supported. if kernel_requested and quantize_op.supported: + # Preprocess data if needed. + preprocessed_args = quantize_op.preprocess(A, B) # Get the quantized tensors for this operator. - quantized_vals = quantize_op.quantize(A, B) + quantized_vals = quantize_op.quantize(*preprocessed_args) # Compute the output given quantized values. output = quantize_op.compute(*quantized_vals) # Compare the quantize op output to reference as a sanity check. @@ -228,20 +244,21 @@ def benchmark( # Now perform benchmark. if bench_quantize: # Benchmark both quantize and compute. - ms_runtime = quantize_op.benchmark( - A, - B, - bench_quantize=True, - use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=use_cuda_graph, - ) + with profiler_or_nullcontext(enabled=trace, with_stack=True): + ms_runtime = quantize_op.benchmark( + *preprocessed_args, + bench_quantize=True, + use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, + ) else: - ms_runtime = quantize_op.benchmark( - *quantized_vals, - bench_quantize=False, - use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=use_cuda_graph, - ) + with profiler_or_nullcontext(enabled=trace, with_stack=True): + ms_runtime = quantize_op.benchmark( + *quantized_vals, + bench_quantize=False, + use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, + ) # Print out results for this op. 
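# Sketch (not part of the patch). The throughput reported below uses the
# standard GEMM FLOP count: an (m x k) @ (k x n) matmul performs m*n*k
# multiply-accumulates, i.e. 2*m*n*k FLOPs, scaled by the batch size b; dividing
# by the runtime in seconds and by 1e12 gives TFLOP/s. Standalone check of the
# same arithmetic (hypothetical helper):
def tflops_estimate(b, m, n, k, ms_runtime):
    return 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12

# A 1024^3 GEMM finishing in 1 ms sustains ~2.147 TFLOP/s.
assert abs(tflops_estimate(1, 1024, 1024, 1024, 1.0) - 2.147483648) < 1e-9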
tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12 @@ -369,6 +386,7 @@ def main(args: Any): args.bench_quantize, args.use_rotating_buffer_bench, not args.no_cuda_graph, + args.trace, ) benchmark_results.append(quantize_measurements) if args.export_csv: @@ -459,6 +477,12 @@ def invoke_main() -> None: action="store_true", help="If set, benchmark using fixed shapes relevant to ldm workloads.", ) + parser.add_argument( + "--trace", + default=False, + action="store_true", + help="If set, produce a performance trace of the benchmark.", + ) args = parser.parse_args() main(args) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index e867de8d2f..b83d0fe6d3 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -65,6 +65,10 @@ def quantize_and_compute(self, *args, **kwargs): """Function which quantizes inputs and performs main compute operation.""" pass + def preprocess(self, *args): + """Preprocess inputs before benchmarking. These outputs will be passed to quantize.""" + return args + def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True): import copy import pickle @@ -113,8 +117,13 @@ def benchmark( ) -> float: """Benchmark runtime of this operator.""" if bench_quantize: - with torch.cuda.stream(torch.cuda.Stream()): - t = triton.testing.do_bench_cudagraph( + if use_cuda_graph: + with torch.cuda.stream(torch.cuda.Stream()): + t = triton.testing.do_bench_cudagraph( + lambda: self.quantize_and_compute(*args, **kwargs) + ) + else: + t = triton.testing.do_bench( lambda: self.quantize_and_compute(*args, **kwargs) ) else: @@ -402,20 +411,26 @@ class FP8RowwiseGemm(QuantizeOpBase): FP8 matmul with rowwise scaling. """ - def quantize(self, x, w): + def preprocess(self, x, w): + # Prequantize weights. + if isinstance(w, (list, tuple)): + wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) + else: + wq, w_scale = quantize_fp8_row(w) + if wq.dim() == 3: + w_scale = w_scale.view(wq.size(0), -1) + return x, wq, w_scale + + def quantize(self, x, wq, w_scale): # Quantize both input tensors. # Handle both grouped and standard gemm. if isinstance(x, (list, tuple)): xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) - wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) else: xq, x_scale = quantize_fp8_row(x) - wq, w_scale = quantize_fp8_row(w) # Set proper batch dimension shapes. if xq.dim() == 3: x_scale = x_scale.view(xq.size(0), -1) - if wq.dim() == 3: - w_scale = w_scale.view(wq.size(0), -1) return xq, wq, x_scale, w_scale def compute(self, xq, wq, x_scale, w_scale): @@ -442,8 +457,8 @@ def compute(self, xq, wq, x_scale, w_scale): # Otherwise return normal gemm result. return torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale) - def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale = self.quantize(x, w) + def quantize_and_compute(self, x, wq, w_scale): + xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale) return self.compute(xq, wq, x_scale, w_scale) @property @@ -468,57 +483,52 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase): FP8 grouped matmul with rowwise scaling. """ - def quantize_fixed_nk(self, x, w): - group_size = len(x) - m_values = [i.shape[0] for i in x] - # Inputs for fixed nk mode must be contiguous, however in the benchmark - # script they typically are not. Do a little special processing to make them - # work. In practice this wont be needed. - # Start by padding along m dimension with zeros. 
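# Sketch (not part of the patch; shapes illustrative, import path assumed from
# the fp8_gemm.py file touched above). With preprocess()/quantize() split as
# above, rowwise weight quantization happens once per shape and only the
# activation is quantized inside the timed loop:
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
wq, w_scale = quantize_fp8_row(w)  # done once in preprocess()
xq, x_scale = quantize_fp8_row(x)  # done per call in quantize()
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
assert y.shape == (128, 256)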
- max_m = max(m_values) - xq = [ - torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) - for i in x - ] - # Stack inputs into groups. - xq = torch.stack(xq).contiguous() - wq = torch.stack(w).contiguous() - # Apply quantization. - xq, x_scale = quantize_fp8_row(xq) - wq, w_scale = quantize_fp8_row(wq) - # View these unified tensors as lists of tensors. - xq = [x.squeeze() for x in xq.split(1, dim=0)] - wq = [w.squeeze() for w in wq.split(1, dim=0)] - x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)] - w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)] - - # Return processed tensors. - return ( - xq, - wq, - x_scale, - w_scale, - torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device), - ) - - def quantize(self, x, w): - assert isinstance( - x, (list, tuple) - ), "Inputs to group gemm must be a list of tensors." - + def preprocess(self, x, w): + # Apply sparsity to inputs if appropriate. # First check if N and K are fixed. m_values = [i.shape[0] for i in x] n_values = [i.shape[0] for i in w] k_values = [i.shape[1] for i in w] - # if so, do specialized version of initialization. + # If so, do specialized version of initialization. if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1: - return self.quantize_fixed_nk(x, w) - - # Otherwise handle in eager mode. - xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) + m_values = [i.shape[0] for i in x] + # Inputs for fixed nk mode must be contiguous, however in the benchmark + # script they typically are not. Do a little special processing to make them + # work. In practice this wont be needed. + # Start by padding along m dimension with zeros. + max_m = max(m_values) + x = [ + torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) + for i in x + ] + # Stack inputs into groups. + x = torch.stack(x).contiguous() + w = torch.stack(w).contiguous() + + # Preapply weight quantization. + wq, w_scale = quantize_fp8_row(w) + # Return processed tensors. + return ( + x, + wq, + w_scale, + torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), + ) + # Otherwise run without sparsity. wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) - m_values = None - return xq, wq, x_scale, w_scale, m_values + return x, wq, w_scale, None + + def quantize(self, x, wq, w_scale, m_values=None): + # Handle case where inputs are explicitly grouped and non-sparse. + if isinstance(x, (tuple, list)): + xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) + return xq, wq, x_scale, w_scale, m_values + # Otherwise inputs are unified tensors and sparse. + else: + B = x.shape[0] + xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values) + x_scale = x_scale.view(B, -1) + return xq, wq, x_scale, w_scale, m_values def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): if m_values is None: @@ -530,6 +540,7 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): kernel_name=kernel_name, ) else: + # Break tensor into groups, simulates what is done e2e. 
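# Sketch (not part of the patch; shapes illustrative). In the fixed-NK path the
# groups are stacked into single [G, M, K] / [G, N, K] tensors and padded rows
# are masked with zero_start_index_M, which then feeds the dynamic kernel:
G, max_M, N, K = 4, 128, 256, 512
x = torch.randn(G, max_M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)
m_values = torch.tensor([128, 64, 96, 32], dtype=torch.int64, device="cuda")
wq, w_scale = quantize_fp8_row(w)                               # preprocess()
xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)  # quantize()
y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq, wq, x_scale.view(G, -1), w_scale.view(G, -1), m_values
)
# y is [G, max_M, N]; rows at or past m_values[i] should remain zero because
# the output buffer is zero-initialized and those rows are never written.
assert y.shape == (G, max_M, N)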
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic( xq, wq, @@ -539,8 +550,8 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): kernel_name=kernel_name, ) - def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale, m_values = self.quantize(x, w) + def quantize_and_compute(self, x, wq, w_scale, m_values=None): + xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values) return self.compute(xq, wq, x_scale, w_scale, m_values) @property @@ -565,55 +576,48 @@ class BF16GroupedGemm(QuantizeOpBase): BF16 grouped matmul implemented with CK or Cutlass. """ - def quantize_fixed_nk(self, x, w): - m_values = [i.shape[0] for i in x] - # Inputs for fixed nk mode must be contiguous, however in the benchmark - # script they typically are not. Do a little special processing to make them - # work. In practice this wont be needed. - # Start by padding along m dimension with zeros. - max_m = max(m_values) - xp = [ - torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) - for i in x - ] - # Stack inputs into groups. - x = torch.stack(xp).contiguous() - w = torch.stack(w).contiguous() - # View these unified tensors as lists of tensors. - x = [xi.squeeze() for xi in x.split(1, dim=0)] - w = [wi.squeeze() for wi in w.split(1, dim=0)] - - # Return processed tensors. - return ( - x, - w, - torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), - ) - - def quantize(self, x, w): - assert isinstance( - x, (list, tuple) - ), "Inputs to group gemm must be a list of tensors." - + def preprocess(self, x, w): + # Apply sparsity to inputs if appropriate. # First check if N and K are fixed. m_values = [i.shape[0] for i in x] n_values = [i.shape[0] for i in w] k_values = [i.shape[1] for i in w] - # if so, do specialized version of initialization. + # If so, do specialized version of initialization. if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1: - return self.quantize_fixed_nk(x, w) + m_values = [i.shape[0] for i in x] + # Inputs for fixed nk mode must be contiguous, however in the benchmark + # script they typically are not. Do a little special processing to make them + # work. In practice this wont be needed. + # Start by padding along m dimension with zeros. + max_m = max(m_values) + x = [ + torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) + for i in x + ] + # Stack inputs into groups. + x = torch.stack(x).contiguous() + w = torch.stack(w).contiguous() + return ( + x, + w, + torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), + ) + return x, w, None - m_values = None + def quantize(self, x, w, m_values=None): + # No action required. 
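# Sketch (not part of the patch; shapes illustrative). The BF16 path mirrors the
# FP8 one but needs no quantization, hence the no-op quantize() below; compute()
# unbinds the stacked tensors into per-group views before dispatching:
G, max_M, N, K = 4, 128, 256, 512
x = torch.randn(G, max_M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)
m_values = torch.tensor([128, 64, 96, 32], dtype=torch.int64, device="cuda")
y = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(
    [x[i] for i in range(G)], [w[i] for i in range(G)], m_values
)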
return x, w, m_values def compute(self, x, w, m_values): if m_values is None: return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w) else: + B = x.shape[0] + x = [x[i, :, :] for i in range(B)] + w = [w[i, :, :] for i in range(B)] return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values) - def quantize_and_compute(self, x, w): - x, w, m_values = self.quantize(x, w) + def quantize_and_compute(self, x, w, m_values): return self.compute(x, w, m_values) @property diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu index 437d3d8bf5..bcdbaac526 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu @@ -303,6 +303,7 @@ __global__ void set_dynamic_kernel_args_kernel( } template < + typename InputType, int TB_M, int TB_N, int TB_K, @@ -310,17 +311,26 @@ template < int TBS_N, int TBS_K, bool PONG> -std::tuple> f8f8bf16_rowwise_grouped_impl( - at::TensorList XQ, // FP8 - at::TensorList WQ, // FP8 - at::TensorList x_scale, - at::TensorList w_scale, +at::Tensor f8f8bf16_rowwise_grouped_impl( + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, at::Tensor output, std::optional zero_start_index_M) { - int group_count = XQ.size(); - TORCH_CHECK(WQ.size() == group_count); + int group_count; + at::TensorOptions options; + if constexpr (std::is_same_v) { + group_count = XQ.size(); + options = XQ[0].options(); + TORCH_CHECK(WQ.size() == group_count); + } else { + group_count = XQ.size(0); + options = XQ.options(); + TORCH_CHECK(WQ.size(0) == group_count); + } if (group_count == 0) { - return {at::Tensor(), std::vector()}; + return at::Tensor(); } using GroupedGemmConfigs = GroupedGemmArgs:: GroupedGemmConfigs; @@ -328,8 +338,7 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( int64_t total_output_size = 0; std::vector output_sizes; output_sizes.reserve(group_count); - at::Tensor output_args = - at::empty({group_count}, XQ[0].options().dtype(at::kLong)); + at::Tensor output_args = at::empty({group_count}, options.dtype(at::kLong)); const int64_t problem_shape_size = group_count * ((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape)); @@ -341,7 +350,7 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( // number at::Tensor input_args = at::empty( {group_count * 4 + problem_shape_size + stride_size * 3 + 1000}, - XQ[0].options().dtype(at::kLong)); + options.dtype(at::kLong)); int xq_ptr_offset = 0; int wq_ptr_offset = group_count * sizeof(int64_t); @@ -351,10 +360,18 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( int stride_buf_offset = group_count * 4 * sizeof(int64_t) + problem_shape_size; - for (int i = 0; i < group_count; ++i) { - const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); - total_output_size += output_size; - output_sizes.push_back(output_size); + if constexpr (std::is_same_v) { + for (int i = 0; i < group_count; ++i) { + const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); + total_output_size += output_size; + output_sizes.push_back(output_size); + } + } else { + // When inputs are pregrouped, output has G * M * N elements. 
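// Sketch (editorial comment, not part of the patch). For stacked inputs
// XQ = [G, M, K] and WQ = [G, N, K] every group writes an M x N tile, so with
// G = 4, M = 128, N = 256 the flat bf16 output buffer holds
// 4 * 128 * 256 = 131072 elements and each output_sizes entry is
// 128 * 256 = 32768; the per-group sizes are only needed later when the
// TensorList path splits the buffer back into per-group views.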
+ total_output_size = group_count * XQ.size(1) * WQ.size(1); + for (int i = 0; i < group_count; ++i) { + output_sizes.push_back(XQ.size(1) * WQ.size(1)); + } } int blockSize = 256; @@ -369,17 +386,39 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( // Set arguments for (int i = 0; i < group_count; ++i) { - int N = WQ[i].size(0); - int K = XQ[i].size(1); - TORCH_CHECK_EQ(WQ[i].size(1), K); + int M, N, K; + int64_t xq_ptr, wq_ptr, x_scale_ptr, w_scale_ptr; + // Compute buffer pointers based on input type. + if constexpr (std::is_same_v) { + M = XQ[i].size(0); + N = WQ[i].size(0); + K = XQ[i].size(1); + TORCH_CHECK_EQ(WQ[i].size(1), K); + // Calculate data pointer for this group. + xq_ptr = reinterpret_cast(XQ[i].data_ptr()); + wq_ptr = reinterpret_cast(WQ[i].data_ptr()); + x_scale_ptr = reinterpret_cast(x_scale[i].data_ptr()); + w_scale_ptr = reinterpret_cast(w_scale[i].data_ptr()); + } else { + M = XQ.size(1); + N = WQ.size(1); + K = XQ.size(2); + // Calculate data pointer for this group. + xq_ptr = reinterpret_cast(XQ.data_ptr()) + + (i * M * K * sizeof(GroupedGemmArgs::ElementInputA)); + wq_ptr = reinterpret_cast(WQ.data_ptr()) + + (i * N * K * sizeof(GroupedGemmArgs::ElementInputB)); + x_scale_ptr = reinterpret_cast(x_scale.data_ptr()) + + (i * M * sizeof(GroupedGemmArgs::ElementAccumulator)); + w_scale_ptr = reinterpret_cast(w_scale.data_ptr()) + + (i * N * sizeof(GroupedGemmArgs::ElementAccumulator)); + } if (zero_start_index_M.has_value() == true) { set_dynamic_kernel_args_kernel<<>>( - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast( - x_scale[i].data_ptr()), - reinterpret_cast( - w_scale[i].data_ptr()), + xq_ptr, + wq_ptr, + x_scale_ptr, + w_scale_ptr, input_args.data_ptr(), output_args.data_ptr(), output.data_ptr(), @@ -398,14 +437,11 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( N, K); } else { - int M = XQ[i].size(0); set_kernel_args_kernel<<>>( - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast( - x_scale[i].data_ptr()), - reinterpret_cast( - w_scale[i].data_ptr()), + xq_ptr, + wq_ptr, + x_scale_ptr, + w_scale_ptr, input_args.data_ptr(), output_args.data_ptr(), output.data_ptr(), @@ -508,33 +544,44 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( C10_CUDA_KERNEL_LAUNCH_CHECK(); - std::vector output_group = output.split(output_sizes); - for (int i = 0; i < group_count; ++i) { - output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)}); - } - // Return two views of the same underlying tensor. - return {output, output_group}; + return output; } // FP8 Tensorwise grouped cutlass kernel dispatch. 
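// Sketch (editorial comment, not part of the patch). In the stacked branch
// above, group i's buffers are located by plain byte offsets into the
// contiguous tensors: with M = 128, K = 256 and 1-byte FP8 elements, xq_ptr for
// group 2 is XQ.data_ptr() + 2 * 128 * 256 = 65536 bytes in, while the fp32
// rowwise scales advance by i * M * sizeof(float) = i * 512 bytes per group.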
-std::tuple> dispatch_fp8_grouped_kernel( - at::TensorList XQ, // FP8 - at::TensorList WQ, // FP8 - at::TensorList x_scale, - at::TensorList w_scale, +template +at::Tensor dispatch_fp8_grouped_kernel( + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, at::Tensor output, std::optional zero_start_index_M) { KernelMode kernel = get_grouped_kernel_mode(XQ, WQ); if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_grouped_impl<64, 128, 128, 2, 1, 1, true>( - XQ, WQ, x_scale, w_scale, output, zero_start_index_M); + return f8f8bf16_rowwise_grouped_impl< + InputType, + 64, + 128, + 128, + 2, + 1, + 1, + true>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M); } else { - return f8f8bf16_rowwise_grouped_impl<128, 256, 64, 2, 1, 1, false>( - XQ, WQ, x_scale, w_scale, output, zero_start_index_M); + return f8f8bf16_rowwise_grouped_impl< + InputType, + 128, + 256, + 64, + 2, + 1, + 1, + false>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M); } } -std::vector f8f8bf16_rowwise_grouped( +template +OutputType _f8f8bf16_rowwise_grouped( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 at::TensorList x_scale, @@ -543,6 +590,7 @@ std::vector f8f8bf16_rowwise_grouped( std::optional kernel_name = std::nullopt) { at::Tensor Y; int group_count = XQ.size(); + std::vector output_sizes; if (output.has_value()) { std::vector output_; output_ = output.value(); @@ -561,11 +609,11 @@ std::vector f8f8bf16_rowwise_grouped( TORCH_CHECK( output_[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); + output_sizes.push_back(out_M * out_N); } Y = at::stack(output.value(), 0); } else { int64_t total_output_size = 0; - std::vector output_sizes; for (int i = 0; i < group_count; ++i) { const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); total_output_size += output_size; @@ -573,38 +621,65 @@ std::vector f8f8bf16_rowwise_grouped( } Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16)); } - // Return grouped view of output. - return std::get<1>( - dispatch_fp8_grouped_kernel(XQ, WQ, x_scale, w_scale, Y, std::nullopt)); + + // Run kernel. + at::Tensor g_out = dispatch_fp8_grouped_kernel( + XQ, WQ, x_scale, w_scale, Y, std::nullopt); + + // Return proper output type. + if constexpr (std::is_same_v>) { + // Return grouped view of output. + std::vector output_group = g_out.split(output_sizes); + for (int i = 0; i < group_count; ++i) { + output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)}); + } + return output_group; + } else { + // When stacked we assume N is fixed. 
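// Sketch (editorial comment, not part of the patch). The stacked return path
// above assumes a fixed N across groups, so the flat buffer of total_M * N
// bf16 values can simply be viewed as {total_M, N}: groups with
// M = (128, 64, 96) and N = 256 concatenate into one {288, 256} tensor, whereas
// the TensorList path splits the same buffer into per-group {M_i, N_i} views.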
+ return g_out.view({-1, WQ[0].size(0)}); + } } -at::Tensor f8f8bf16_rowwise_grouped_dynamic( +std::vector f8f8bf16_rowwise_grouped( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, + std::optional> output = std::nullopt, + std::optional kernel_name = std::nullopt) { + return _f8f8bf16_rowwise_grouped>( + XQ, WQ, x_scale, w_scale); +} + +at::Tensor f8f8bf16_rowwise_grouped_stacked( + at::TensorList XQ, // FP8 + at::TensorList WQ, // FP8 + at::TensorList x_scale, + at::TensorList w_scale, + std::optional> output = std::nullopt, + std::optional kernel_name = std::nullopt) { + return _f8f8bf16_rowwise_grouped(XQ, WQ, x_scale, w_scale); +} + +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M, std::optional kernel_name = std::nullopt) { at::Tensor Y; - int group_count = XQ.size(); - int64_t total_output_size = 0; - for (int i = 0; i < group_count; ++i) { - const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); - total_output_size += output_size; - } - Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16)); + int group_count = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = XQ.size(0); + int total_output_size = group_count * M * N; + Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16)); // Return continuous view of output. - at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel( - XQ, WQ, x_scale, w_scale, Y, zero_start_index_M)); + at::Tensor output = dispatch_fp8_grouped_kernel( + XQ, WQ, x_scale, w_scale, Y, zero_start_index_M); // View as proper shape. - // When zero_start_index_M is provided, we can view as [G, M, N] - if (zero_start_index_M.has_value()) { - output = output.view({-1, XQ[0].size(0), WQ[0].size(0)}); - // Otherwise we view as {total_M, N}. 
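// Sketch (editorial comment, not part of the patch). After this change the
// dynamic entry point takes stacked tensors directly: XQ = [G, M, K] and
// WQ = [G, N, K] in FP8, x_scale = [G, M] and w_scale = [G, N] in fp32, plus an
// int64 zero_start_index_M of length G holding each group's valid row count.
// The zero-initialized output is always returned as a [G, M, N] bf16 tensor,
// e.g. [4, 128, 256] for the shapes used in the Python sketches above.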
- } else { - output = output.view({-1, WQ[0].size(0)}); - } - return output; + return output.view({-1, M, N}); } #else @@ -620,6 +695,17 @@ std::vector f8f8bf16_rowwise_grouped( "CUDA version is older than 12.0"); // requires CUDA>=12 } +at::Tensor f8f8bf16_rowwise_grouped_stacked( + at::TensorList XQ, // FP8 + at::TensorList WQ, // FP8 + at::TensorList x_scale, + at::TensorList w_scale, + std::optional> output = std::nullopt, + std::optional kernel_name = std::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + at::Tensor f8f8bf16_rowwise_grouped_dynamic( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index aa8faa11c8..43d9beefc2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -94,12 +94,19 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList w_scale, std::optional> output = std::nullopt, std::optional kernel_name = std::nullopt); -at::Tensor f8f8bf16_rowwise_grouped_dynamic( +at::Tensor f8f8bf16_rowwise_grouped_stacked( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, + std::optional> output = std::nullopt, + std::optional kernel_name = std::nullopt); +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M, std::optional kernel_name = std::nullopt); std::vector get_f8f8bf16_rowwise_grouped_kernels(); at::Tensor f8f8bf16_blockwise( @@ -210,7 +217,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]"); m.def( - "f8f8bf16_rowwise_grouped_dynamic(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor? zero_start_index_M=None, str? kernel_name=None) -> Tensor"); + "f8f8bf16_rowwise_grouped_stacked(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor"); + m.def( + "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M, str? 
kernel_name=None) -> Tensor"); m.def( "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor"); m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor"); @@ -252,6 +261,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f8f8bf16_rowwise_out", f8f8bf16_rowwise_out); m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped); + m.impl("f8f8bf16_rowwise_grouped_stacked", f8f8bf16_rowwise_grouped_stacked); m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); m.impl("quantize_fp8_per_row", quantize_fp8_per_row); @@ -278,6 +288,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise); m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped); + m.impl("f8f8bf16_rowwise_grouped_stacked", f8f8bf16_rowwise_grouped_stacked); m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); m.impl("quantize_fp8_per_row", quantize_fp8_per_row); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 51fbcc6db8..5853657597 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -737,7 +737,7 @@ def fp8_loopover_bmm( K=st.sampled_from([512, 3584]), use_cudagraph=st.booleans(), use_padding_zeros=st.booleans(), - use_dynamic=st.booleans(), + return_stacked=st.booleans(), ) def test_fp8_grouped_gemm( self, @@ -747,7 +747,7 @@ def test_fp8_grouped_gemm( K: int, use_cudagraph: bool, use_padding_zeros: bool, - use_dynamic: bool, + return_stacked: bool, ) -> None: ms = ( torch.randint( @@ -759,7 +759,7 @@ def test_fp8_grouped_gemm( * 64 ) # When using padding or the dynamic kernel, Ns and Ks should be fixed. - if use_padding_zeros or use_dynamic: + if use_padding_zeros or return_stacked: ns = [N] * G ks = [K] * G # Otherwise, any value is supported. 
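# Sketch (not part of the patch; shapes illustrative). The registrations above
# expose three grouped FP8 entry points: f8f8bf16_rowwise_grouped (list in,
# list of [M_i, N_i] out), f8f8bf16_rowwise_grouped_stacked (list in, one
# [total_M, N] out, fixed N), and f8f8bf16_rowwise_grouped_dynamic (stacked
# [G, M, K] in, [G, M, N] out). Example of the plain list variant, two groups:
N, K = 256, 512
xs = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in (64, 128)]
ws = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(2)]
xq, x_scale = zip(*[quantize_fp8_row(t) for t in xs])
wq, w_scale = zip(*[quantize_fp8_row(t) for t in ws])
ys = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
    list(xq), list(wq), list(x_scale), list(w_scale)
)
# ys[0] has shape [64, 256], ys[1] has shape [128, 256].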
@@ -800,10 +800,10 @@ def test_fp8_grouped_gemm( if use_padding_zeros: x_group = torch.unbind(torch.stack(x_group, dim=0).contiguous()) w_group = torch.unbind(torch.stack(w_group, dim=0).contiguous()) - xq_group = torch.unbind(torch.stack(xq_group, dim=0).contiguous()) - wq_group = torch.unbind(torch.stack(wq_group, dim=0).contiguous()) - x_scale_group = torch.unbind(torch.stack(x_scale_group, dim=0).contiguous()) - w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous()) + xq_group = torch.stack(xq_group, dim=0).contiguous() + wq_group = torch.stack(wq_group, dim=0).contiguous() + x_scale_group = torch.stack(x_scale_group, dim=0).contiguous() + w_scale_group = torch.stack(w_scale_group, dim=0).contiguous() # FP8 grouped gemm kernel fp8_args = ( @@ -812,15 +812,19 @@ def test_fp8_grouped_gemm( wq_group, x_scale_group, w_scale_group, - zero_start_index_M if use_padding_zeros else None, + zero_start_index_M, ] - if use_dynamic + if use_padding_zeros else [xq_group, wq_group, x_scale_group, w_scale_group] ) fp8_op = ( torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic - if use_dynamic - else torch.ops.fbgemm.f8f8bf16_rowwise_grouped + if use_padding_zeros + else ( + torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked + if return_stacked + else torch.ops.fbgemm.f8f8bf16_rowwise_grouped + ) ) if use_cudagraph: # warmup @@ -843,12 +847,12 @@ def test_fp8_grouped_gemm( # BF16 grouped gemm kernel bf16_args = ( [x_group, w_group, zero_start_index_M if use_padding_zeros else None] - if use_dynamic + if return_stacked else [x_group, w_group] ) bf16_op = ( torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic - if use_dynamic + if return_stacked else torch.ops.fbgemm.bf16bf16bf16_grouped ) if use_cudagraph:
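# Sketch (not part of the patch). Restating the kernel selection the test now
# exercises, keyed off the input layout rather than a separate use_dynamic flag:
def pick_fp8_grouped_op(use_padding_zeros: bool, return_stacked: bool):
    if use_padding_zeros:
        # Stacked [G, M, K] inputs plus zero_start_index_M.
        return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
    if return_stacked:
        # List inputs, single stacked [total_M, N] output.
        return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked
    # List inputs, list of per-group [M_i, N_i] outputs.
    return torch.ops.fbgemm.f8f8bf16_rowwise_grouped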