From e2dd52b03a4f78f7e0b28f767fe4529eae1122fe Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Thu, 6 Feb 2025 17:14:46 -0800
Subject: [PATCH] FP8 Grouped Gemm Optimization (#3655)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3655

X-link: https://github.com/facebookresearch/FBGEMM/pull/731

While optimizing MOE, we found that small overheads were a major bottleneck for grouped gemm performance. This diff tackles a few of them, specifically the overhead of torch.dynamo wrapping `quantize_fp8_row` and of having to slice input tensors before calling `f8f8bf16_rowwise_grouped`.

To fix the former, we allow `triton_quantize_fp8_row` to be called directly, skipping the dynamo compatibility wrapper. In cases where AOTI isn't needed, this removes a bit of overhead.

To fix the latter, we templatize f8f8bf16_rowwise_grouped_dynamic to accept at::Tensor inputs instead of lists. We introduce a new wrapper called f8f8bf16_rowwise_grouped_stacked to maintain the behavior where zero_start_index_M isn't provided but a user wants a single contiguous output tensor.

In microbenchmarks, we've found these seemingly small changes can improve TFLOPs by 2X for small workloads.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D69072529
---
 .../experimental/gemm/triton_gemm/fp8_gemm.py | 21 +-
 .../gen_ai/bench/profile_grouped_gemm.py | 103 --
 .../experimental/gen_ai/bench/quantize_ops.py | 22 +-
 .../fp8_rowwise_grouped_gemm.hip | 469 +++----
 ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 43 +-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++-
 ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++-
 ...8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 137 ++-
 ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 137 ++-
 ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 137 ++-
 ...4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 137 ++-
...4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 137 ++- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 137 ++- ...x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip | 137 ++- ...x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip | 137 ++- ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 137 ++- ...4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 137 ++- ...8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip | 137 ++- ...8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip | 137 ++- ...8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip | 137 ++- ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 137 ++- ...4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 137 ++- ...4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 137 ++- ...8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip | 137 ++- ...8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip | 137 ++- ...x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip | 137 ++- ...x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip | 137 ++- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 137 ++- ...16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip | 43 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 137 ++- ...32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 43 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 137 ++- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 137 ++- ...4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 137 ++- ...4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 137 ++- .../kernels/fp8_rowwise_grouped_common.h | 69 +- .../fp8_rowwise_grouped_kernel_manifest.h | 1086 ++++++++--------- .../f8f8bf16_rowwise_grouped.cu | 290 +++-- .../gen_ai/src/quantize/quantize.cpp | 29 +- .../gen_ai/test/quantize/quantize_test.py | 37 +- 79 files changed, 6720 insertions(+), 4714 deletions(-) delete mode 100644 fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index bf8d52cb1e..4a6111510c 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row( torch.Tensor: fp8 scaled tensor. torch.Tensor: reciprocal scale tensor per row. """ + assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor." + a_shape = a.shape + while a.dim() < 4: + a = a.unsqueeze(0) + if zero_start_index_M is not None: + # There should be one value of zero_start_index_M per NxK matrix. + zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) # Get constant values. 
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants() num_rows = a.numel() // a.shape[-1] @@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row( USE_INT64=use_int64, ) - return a_fp8, a_scale + return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=()) @@ -2514,17 +2521,7 @@ def quantize_fp8_row( logger.info("Triton does not support cpu, falling back to torch ops.") use_triton = False if use_triton: - assert ( - a.dim() <= 4 - ), "Only up to 4 dimension input tensor is supported if use_triton is True" - a_shape = a.shape - while a.dim() < 4: - a = a.unsqueeze(0) - if zero_start_index_M is not None: - # There should be one value of zero_start_index_M per NxK matrix. - zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1]) - a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) - return a_fp8.view(a_shape), a_scale.view(a_shape[:-1]) + return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M) # else use pytorch implementation. if not output_device: output_device = a.device diff --git a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py b/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py deleted file mode 100644 index 82019a769c..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -from typing import Any - -import pandas as pd -import torch - -from .quantize_ops import FP8RowwiseGroupedGemm - - -def main(args: Any): - # Extract and format shape arguments. - M = [int(m) for m in args.M.strip().split(",")] - N = [int(n) for n in args.N.strip().split(",")] - K = [int(k) for k in args.K.strip().split(",")] - assert len(M) == len(N) == len(K), "M, N, and K must have the same length." - - # initialize tensors for benchmarking. - A = [] - B = [] - num_groups = len(M) - for i in range(num_groups): - A.append(torch.randn(M[i], K[i], device="cuda")) - B.append(torch.randn(N[i], K[i], device="cuda")) - - # Get quantized tensors. - group_gemm_op = FP8RowwiseGroupedGemm() - quantized_vals = group_gemm_op.quantize(A, B) - # Iterate over kernels to find the most performant one. - benchmark_results = [] - for kernel_name in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels(): - # Do a warmup run of the kernel. - output = group_gemm_op.compute(*quantized_vals, kernel_name=kernel_name) - # Benchmark this kernel implementation. - ms_runtime = group_gemm_op.benchmark( - *quantized_vals, use_cuda_graph=True, kernel_name=kernel_name - ) - # Compute statistics for this kernel. - tflops = 0 - gbps = 0 - for i in range(num_groups): - tflops += 2 * M[i] * N[i] * K[i] / (ms_runtime / 1e3) / 1e12 - gbps += ( - ( - quantized_vals[0][i].numel() * quantized_vals[0][i].element_size() - + quantized_vals[1][i].numel() * quantized_vals[1][i].element_size() - + output[i].numel() * output[i].element_size() - ) - / (ms_runtime / 1e3) - / 1e9 - ) - # Record results. - print(f"Kernel: {kernel_name}, ms: {ms_runtime:.4f}, TFLOPS: {tflops:.2f}") - benchmark_results.append( - { - "kernel_name": kernel_name, - "ms_runtime": ms_runtime, - "tflops": tflops, - "gbps": gbps, - } - ) - # Report best kernel. 
- best_kernel = min(benchmark_results, key=lambda x: x["ms_runtime"]) - print( - f"Best kernel for this shape: {best_kernel['kernel_name']}: {best_kernel['tflops']:.2f} TFLOPS" - ) - - # If specified, save all results. - if args.export_csv: - df = pd.DataFrame(benchmark_results) - df.to_csv("grouped_gemm_benchmark.csv", index=False) - - -def invoke_main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--export_csv", - action="store_true", - help="Export results to a CSV file.", - ) - parser.add_argument( - "--M", - required=True, - help="Comma separated list of M values of each group to benchmark.", - ) - parser.add_argument( - "--N", - required=True, - help="Comma separated list of N values of each group to benchmark", - ) - parser.add_argument( - "--K", - required=True, - help="Comma separated list of K values of each group to benchmark.", - ) - - args = parser.parse_args() - main(args) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index fcc2141526..97a490c5b3 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -19,6 +19,7 @@ quantize_fp8_block, quantize_fp8_row, scale_fp8_row, + triton_quantize_fp8_row, ) from tinygemm.utils import group_quantize_tensor @@ -553,38 +554,31 @@ def preprocess(self, x, w): def quantize(self, x, wq, w_scale, m_values=None): # Handle case where inputs are explicitly grouped and non-sparse. if isinstance(x, (tuple, list)): - xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) + xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x]) return xq, wq, x_scale, w_scale, m_values # Otherwise inputs are unified tensors and sparse. else: B = x.shape[0] - xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values) + xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values) x_scale = x_scale.view(B, -1) return xq, wq, x_scale, w_scale, m_values - def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): + def compute(self, xq, wq, x_scale, w_scale, m_values): if m_values is None: return torch.ops.fbgemm.f8f8bf16_rowwise_grouped( xq, wq, x_scale, w_scale, - kernel_name=kernel_name, ) else: # Break tensor into groups, simulates what is done e2e. - B = xq.shape[0] - xq_group = [xq[i, :, :] for i in range(B)] - x_scale_group = [x_scale[i, :] for i in range(B)] - wq_group = [wq[i, :, :] for i in range(B)] - w_scale_group = [w_scale[i, :] for i in range(B)] return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic( - xq_group, - wq_group, - x_scale_group, - w_scale_group, + xq, + wq, + x_scale, + w_scale, zero_start_index_M=m_values, - kernel_name=kernel_name, ) def quantize_and_compute(self, x, wq, w_scale, m_values=None): diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index c82cd46410..ca3afc48fb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -27,6 +27,15 @@ namespace fbgemm_gpu { +template +using RowwiseGroupedKernel = std::function; + // Define useful types that are needed for various kernels. 
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>; @@ -37,48 +46,49 @@ using D1DataType = float; using DsDataType = ck::Tuple; using EDataType = ck::bhalf_t; -RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { +template +RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { // We use shape heuristics to find the best kernel. // To do this, we divide by the size of M and find the best // option within that grouping. if (M <= 16) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; + return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; } if (K <= 8192) { - return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; + return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; } - return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2; } if (M <= 32) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; } if (K <= 8192) { - return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; } - return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2; + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2; } if (M <= 64) { - return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 128) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 256) { - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 512) { if (K <= 8192) { - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; } - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } // Default kernel for all other shapes. 
- return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; } __global__ void set_kernel_args_kernel( @@ -102,75 +112,60 @@ __global__ void set_kernel_args_kernel( } } +template void set_static_kernel_args( at::Tensor kernel_args, at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::vector output) { + OutputType output) { // Get current cuda stream. auto stream = at::cuda::getCurrentHIPStream().stream(); int group_count = XQ.size(); + // Declare variables for loop. + EDataType* output_ptr; + int output_offset = 0; // When group count is large, we can more efficiently initialize // by doing host setup and a memcpy. This is only viable if cuda // graphs arent being used. - if (group_count >= 16 && stream == 0) { - std::vector ggemm_kargs; - ggemm_kargs.reserve(group_count); + // Iterate over inputs and get group information. + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int K = XQ[i].size(1); + int N = WQ[i].size(0); - // Iterate over inputs and get group information. - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int K = XQ[i].size(1); - int N = WQ[i].size(0); - KernelArguments group_args = { - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - {reinterpret_cast(w_scale[i].data_ptr()), - reinterpret_cast(x_scale[i].data_ptr())}, - reinterpret_cast(output[i].data_ptr()), - M, - N, - K, - K, - K, - {0, 0}, - N}; - ggemm_kargs.push_back(group_args); + // Compute proper output pointer. + if constexpr (std::is_same_v>) { + // Output is a list of tensors and we can access each individually. + output_ptr = reinterpret_cast(output[i].data_ptr()); + } else{ + // Output is a single contiguous tensor and must be accessed via offset. + output_ptr = reinterpret_cast(output.data_ptr()) + output_offset; + output_offset += M * N; } - // Copy data onto device. - hipMemcpy( - kernel_args.data_ptr(), // Destination - ggemm_kargs.data(), // Source - sizeof(KernelArguments) * group_count, // Number of bytes - hipMemcpyHostToDevice); // Copy Type - } else { + // We use the smallest reasonable block size since we effectively need only // 1 thread. - int blockSize = 32; - int numBlocks = 1; + const int blockSize = 32; + const int numBlocks = 1; // Launch a kernel for each group to set kernel memory on device. // Using multiple kernels this way allows us to support arbitrary M,N,K. // For some reason, this approach is faster than using hipmemcpy. - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int K = XQ[i].size(1); - int N = WQ[i].size(0); - // Launch kernel to set kernel arguments. - set_kernel_args_kernel<<>>( - reinterpret_cast( - reinterpret_cast(kernel_args.data_ptr()) + - (i * sizeof(KernelArguments))), - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast(w_scale[i].data_ptr()), - reinterpret_cast(x_scale[i].data_ptr()), - reinterpret_cast(output[i].data_ptr()), - M, - N, - K); - } + // Launch kernel to set kernel arguments. 
+ set_kernel_args_kernel<<>>( + reinterpret_cast( + reinterpret_cast(kernel_args.data_ptr()) + + (i * sizeof(KernelArguments))), + reinterpret_cast(XQ[i].data_ptr()), + reinterpret_cast(WQ[i].data_ptr()), + reinterpret_cast(w_scale[i].data_ptr()), + reinterpret_cast(x_scale[i].data_ptr()), + output_ptr, + M, + N, + K); + } } @@ -239,18 +234,18 @@ __global__ void set_kernel_args_fixed_nk_kernel( void set_dynamic_kernel_args( at::Tensor kernel_args, - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector output, + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, at::Tensor zero_start_index_M) { // Get current cuda stream. auto stream = at::cuda::getCurrentHIPStream().stream(); - int group_count = XQ.size(); + int group_count = XQ.size(0); // Confirm M is on the proper device. TORCH_CHECK( - XQ[0].device() == zero_start_index_M.device(), + XQ.device() == zero_start_index_M.device(), "zero_start_index_M and inputs must be on the same device."); TORCH_CHECK( zero_start_index_M.size(0) == group_count, @@ -261,35 +256,9 @@ void set_dynamic_kernel_args( // We assume that M, N, and K are fixed across groups. // The actual m values are sstored in the passed M tensor. - int M = XQ[0].size(0); - int K = XQ[0].size(1); - int N = WQ[0].size(0); - - // Make sure that inputs are allocated in sequential memory as required by - // this mode. - for (int i = 1; i < group_count; i++) { - // Check that all inputs are allocated directly following preceding input. - TORCH_CHECK( - XQ[i].data_ptr() == - (reinterpret_cast(XQ[i - 1].data_ptr()) + (M * K)), - "Inputs must be sequential in memory to support dynamic M, but XQ is not."); - TORCH_CHECK( - WQ[i].data_ptr() == - (reinterpret_cast(WQ[i - 1].data_ptr()) + (N * K)), - "Inputs must be sequential in memory to support dynamic M, but WQ is not."); - TORCH_CHECK( - x_scale[i].data_ptr() == - (reinterpret_cast(x_scale[i - 1].data_ptr()) + (M)), - "Inputs must be sequential in memory to support dynamic M, but x_scale is not."); - TORCH_CHECK( - w_scale[i].data_ptr() == - (reinterpret_cast(w_scale[i - 1].data_ptr()) + (N)), - "Inputs must be sequential in memory to support dynamic M, but w_scale is not."); - TORCH_CHECK( - output[i].data_ptr() == - (reinterpret_cast(output[i - 1].data_ptr()) + (M * N)), - "Inputs must be sequential in memory to support dynamic M, but output is not."); - } + int M = XQ.size(1); + int K = XQ.size(2); + int N = WQ.size(1); // Launch a kernel that sets kernel argument memory. 
const int BLOCK_SIZE = 8; @@ -298,11 +267,11 @@ void set_dynamic_kernel_args( int numBlocks = (block_factor + blockSize - 1) / blockSize; set_kernel_args_fixed_nk_kernel<<>>( reinterpret_cast(kernel_args.data_ptr()), - reinterpret_cast(XQ[0].data_ptr()), - reinterpret_cast(WQ[0].data_ptr()), - reinterpret_cast(w_scale[0].data_ptr()), - reinterpret_cast(x_scale[0].data_ptr()), - reinterpret_cast(output[0].data_ptr()), + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr()), + reinterpret_cast(output.data_ptr()), reinterpret_cast(zero_start_index_M.data_ptr()), M, N, @@ -311,50 +280,13 @@ void set_dynamic_kernel_args( BLOCK_SIZE); } -at::Tensor get_grouped_kernel_args( +template +OutputType _f8f8bf16_rowwise_grouped( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M, - std::vector output) { - int group_count = XQ.size(); - // Get space on device for the kernel argument tensor. - at::Tensor kernel_args = at::empty( - {static_cast(group_count * sizeof(KernelArguments))}, - XQ[0].options().dtype(at::kByte)); - - // There are two different modes for this kernel. - // When zero_start_index_M is provided, we assume that data is sequential and - // that N and K are constants. This allows a more efficient kernel - // launch and is best suited to MOE use cases where M is truly dynamic. - // When zero_start_index_M is not provided, we assume M, N, and K can all vary - // and set them for each group. It is important to note that this does not - // work well with cuda graphs and runtime dynamism so if possible we recommend - // using zero_start_index_M. - - if (zero_start_index_M.has_value()) { - set_dynamic_kernel_args( - kernel_args, - XQ, - WQ, - x_scale, - w_scale, - output, - zero_start_index_M.value()); - } else { - set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, output); - } - return kernel_args; -} - -std::vector f8f8bf16_rowwise_grouped( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional output = std::nullopt) { // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. TORCH_CHECK( @@ -387,47 +319,60 @@ std::vector f8f8bf16_rowwise_grouped( TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32."); } - std::vector Y; - if (output.has_value()) { - Y = output.value(); - TORCH_CHECK( - Y.size() == group_count, - "Output and input must have same number of groups."); - // Check that output shapes are correct. - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); - int out_M = Y[i].size(0); - int out_N = Y[i].size(1); - TORCH_CHECK( - M == out_M && N == out_N, - "Output tensors do not have the expected shape."); + OutputType Y; + // Need to handle different output modes separately. + // First handle tensor list output. + if constexpr (std::is_same_v>) { + if (output.has_value()) { + Y = output.value(); TORCH_CHECK( - Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); + Y.size() == group_count, + "Output and input must have same number of groups."); + // Check that output shapes are correct. 
+ for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int N = WQ[i].size(0); + int out_M = Y[i].size(0); + int out_N = Y[i].size(1); + TORCH_CHECK( + M == out_M && N == out_N, + "Output tensors do not have the expected shape."); + TORCH_CHECK( + Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); + } + } else { + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int N = WQ[i].size(0); + Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16))); + } } + // Now handle single tensor output. } else { + // Compute total M across groups. + int total_M = 0; + int N = WQ[0].size(0); for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); - Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16))); + total_M += XQ[i].size(0); + // Also make sure N is constant across shapes. + TORCH_CHECK(WQ[i].size(0) == N, "N must be constant across groups for stacked output."); + } + if (output.has_value()) { + Y = output.value(); + // Check that shape is expected. + TORCH_CHECK(Y.size(0) == total_M && Y.size(1) == N, "Preallocated output should have size [total_M, N]."); + } + else { + Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16)); } } // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = - get_grouped_kernel_args(XQ, WQ, x_scale, w_scale, std::nullopt, Y); + at::Tensor kernel_args = at::empty( + {static_cast(group_count * sizeof(KernelArguments))}, + XQ[0].options().dtype(at::kByte)); + set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, Y); - // If provided a specific kernel implementation, dispatch to it. - if (kernel_name.has_value()) { - auto it = kernel_name_map.find(kernel_name.value()); - // If not found, raise an error. - TORCH_CHECK( - it != kernel_name_map.end(), - "Could not find kernel " + kernel_name.value()); - // If found, always use requested kernel. - return it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y); - } - // Otherwise, use heuristics to find the best kernel options. // We use the largest of each shape for heuristics. int MaxM = 0; int MaxN = 0; @@ -437,127 +382,87 @@ std::vector f8f8bf16_rowwise_grouped( MaxN = max(MaxN, WQ[i].size(0)); MaxK = max(MaxK, XQ[i].size(1)); } - RowwiseGroupedKernel selected_kernel = - rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + RowwiseGroupedKernel selected_kernel = + rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); } -at::Tensor f8f8bf16_rowwise_grouped_dynamic( +// Wrapper function for list input list output. +std::vector f8f8bf16_rowwise_grouped( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional> output = std::nullopt) { + return _f8f8bf16_rowwise_grouped>(XQ, WQ, x_scale, w_scale, output); +} + +// Wrapper function for list input single tensor output. +at::Tensor f8f8bf16_rowwise_grouped_stacked( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + std::optional output = std::nullopt) { + return _f8f8bf16_rowwise_grouped(XQ, WQ, x_scale, w_scale, output); +} + +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M) { // Check that input datatypes are valid. 
// First confirm that there are the same number of groups in all inputs. + int group_count = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = XQ.size(2); TORCH_CHECK( - XQ.size() == WQ.size() && XQ.size() == x_scale.size() && - XQ.size() == w_scale.size(), + WQ.size(0) == group_count && x_scale.numel() / group_count == M && + w_scale.numel() / group_count == N, "All inputs must have the same number of groups."); - int group_count = XQ.size(); // Iterate over inputs and check they are valid. - for (at::Tensor x : XQ) { - TORCH_CHECK(x.is_cuda() && x.is_contiguous()); - TORCH_CHECK(x.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - x.dtype() == at::kFloat8_e4m3fnuz, - "Inputs must be type float8_e4m3fnuz."); - } - for (at::Tensor w : WQ) { - TORCH_CHECK(w.is_cuda() && w.is_contiguous()); - TORCH_CHECK(w.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - w.dtype() == at::kFloat8_e4m3fnuz, - "Inputs must be type float8_e4m3fnuz."); - TORCH_CHECK( - w.size(0) >= 512 && w.size(1) >= 512, - "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); - } - for (at::Tensor xs : x_scale) { - TORCH_CHECK(xs.dtype() == at::kFloat, "Scales must be float32."); - } - for (at::Tensor ws : x_scale) { - TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32."); - } + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(XQ.dim() == 3, "Input XQ must be 3D (G,M,K)."); + TORCH_CHECK( + XQ.dtype() == at::kFloat8_e4m3fnuz, + "Input XQ must be type float8_e4m3fnuz."); - at::Tensor Y_full; - std::vector Y; - int M = XQ[0].size(0); - int N = WQ[0].size(0); - int K = XQ[0].size(1); - // When run with padding, we return an output with shape [G, M, N] since - // M and N are consistent across groups. - if (zero_start_index_M.has_value()) { - // Allocate an empty output array. We will set its values to zero as part - // of kernel setup. - Y_full = - at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16)); - // Split the output into groups. - Y = at::unbind(Y_full, 0); - // When padding isnt provided, we return a tensor with shape [Total_M, N] - // since viewing as groups isnt possible due to variable M. - } else { - int total_M = 0; - std::vector group_sizes = {}; - for (int i = 0; i < group_count; i++) { - TORCH_CHECK( - XQ[i].size(1) == K && WQ[i].size(0) == N && WQ[i].size(1) == K, - "Dynamic grouped gemm requires fixed N and K."); - int group_M = XQ[i].size(0); - total_M += group_M; - group_sizes.push_back(group_M); - } - // Allocate continuous array for all groups. - Y_full = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16)); - // Split full array into groups for downstream handling. - int offset = 0; - for (int size : group_sizes) { - Y.push_back(Y_full.narrow(0, offset, size)); - offset += size; - } - } + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(WQ.dim() == 3, "Input WQ must be 3D (G,N,K)."); + TORCH_CHECK( + WQ.dtype() == at::kFloat8_e4m3fnuz, + "Input WQ must be type float8_e4m3fnuz."); + TORCH_CHECK( + WQ.size(1) >= 512 && WQ.size(2) >= 512, + "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); - // Prepare kernel arguments by copying them to the proper device location. 
- at::Tensor kernel_args = - get_grouped_kernel_args(XQ, WQ, x_scale, w_scale, zero_start_index_M, Y); + TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32."); + TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32."); - // If provided a specific kernel implementation, dispatch to it. - if (kernel_name.has_value()) { - auto it = kernel_name_map.find(kernel_name.value()); - // If not found, raise an error. - TORCH_CHECK( - it != kernel_name_map.end(), - "Could not find kernel " + kernel_name.value()); - // If found, always use requested kernel. - it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y); - return Y_full; - } - // Otherwise, use heuristics to find the best kernel options. - // We use the largest of each shape for heuristics. - int MaxM = 0; - int MaxN = 0; - int MaxK = 0; - for (int i = 0; i < group_count; i++) { - MaxM = max(MaxM, XQ[i].size(0)); - MaxN = max(MaxN, WQ[i].size(0)); - MaxK = max(MaxK, XQ[i].size(1)); - } - RowwiseGroupedKernel selected_kernel = - rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); - // Run kernel to populate Y. - selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); - // Return unified view of Y_full. - return Y_full; -} + // Allocate an empty output array. We will set its values to zero as part + // of kernel setup. + at::Tensor Y = + at::empty({group_count, M, N}, XQ.options().dtype(at::kBFloat16)); -std::vector get_f8f8bf16_rowwise_grouped_kernels() { - /* Helper function to get the names of avaialable grouped gemm kernels.*/ - std::vector kernel_names; - for (const auto& pair : kernel_name_map) { - kernel_names.push_back(pair.first); - } - return kernel_names; + // Prepare kernel arguments by copying them to the proper device location. + at::Tensor kernel_args = at::empty( + {static_cast(group_count * sizeof(KernelArguments))}, + XQ.options().dtype(at::kByte)); + set_dynamic_kernel_args( + kernel_args, + XQ, + WQ, + x_scale, + w_scale, + Y, + zero_start_index_M); + + RowwiseGroupedKernel selected_kernel = + rowwise_grouped_heuristic_dispatch(M, N, K); + return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 4c99757b80..050a7c4dbb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index 21e92b3997..8e688a8afa 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index e24f8b6fd3..8742cbfdcf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // A kernel that works well on small but not super tiny shapes. using DeviceGemmInstance = DeviceGemmHelper< 128, @@ -35,5 +36,33 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 3e5d0f249c..f2db563f83 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 9c871f64c9..dd21b3ca82 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index f2a90ac2b0..aa5e48af27 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 64f6311648..7a9a68b32f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index b9ed2888b3..c7d290f959 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 34a7367f5b..496c536bdd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 0637fde603..773d5256fa 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 868cd8275e..ff5009dbb4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index 8c324b7989..0c13cb3add 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index eeb3ecfa99..7e12cfd440 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 2d4faaaa7a..7890c90eaf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index f5c9aa7795..1ee23d2480 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index 6f8da7b7e7..b810c78317 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 123638a331..0c8e5070a7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 28a14967cc..8e996b2d6d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 0a20e29cd7..d390371d63 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index fe794f8533..3304f518c4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 6ef2ed503b..e3c54acbf1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index 3cd1a219ce..2ab8c05133 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip index a3b3b37ec3..45539c10c1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip index fac1cd90b6..4fb9a5ee52 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 5059ba7370..7439d38bba 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip index 1750fe9158..e0a8497931 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index cd5cc29f61..007087a94e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index 67d7d7d778..f1448f9f63 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip index 63bb549f4f..7cf04ce97e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 329e3a0e06..bb77629561 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index 9d7908f441..74bc32a07f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 7330ad30c5..395b0bb6e9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 5c07c9b991..72b63c1f15 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index 270c4b6b47..9bafe6abc8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index c90498392b..18a22e68d0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 9282b9d898..2d0612e928 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 1a09ef2a40..7456c4a272 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
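
Each .hip file now ends with the same three explicit instantiations, which are the only concrete entry points the host code links against: TensorList inputs with per-group outputs, fully stacked at::Tensor inputs and output, and TensorList inputs written into one contiguous output (the f8f8bf16_rowwise_grouped_stacked wrapper). A sketch of the corresponding function-pointer shapes as a caller sees them; the alias names are illustrative, not taken from the manifest header:

#include <ATen/ATen.h>
#include <vector>

// TensorList in, one output tensor per group.
using GroupedListKernel = std::vector<at::Tensor> (*)(
    at::TensorList, at::TensorList, at::TensorList, at::TensorList,
    at::Tensor, std::vector<at::Tensor>);

// Stacked 3D tensors in, single contiguous output.
using GroupedStackedKernel = at::Tensor (*)(
    at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor);

// TensorList in, single preallocated contiguous output.
using GroupedListToStackedKernel = at::Tensor (*)(
    at::TensorList, at::TensorList, at::TensorList, at::TensorList,
    at::Tensor, at::Tensor);
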
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index e05304e317..e01d070f81 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 256, - 64, - 32, - 32, - 2, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 256, - 64, - 32, - 32, - 2, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 005b98e6f2..169610d9aa 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip index 43440d9f09..6343d07c67 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
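
For readers scanning the long list of tile variants, each file and function name simply spells out the DeviceGemmHelper argument list visible above. A decoded example for the kernel just modified; the values are read directly off the template arguments, while the field labels follow the usual composable_kernel naming and are descriptive guesses rather than authoritative:

// fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2
//   256          -> block size (threads per workgroup)
//   16x256x128   -> M, N, K tile computed per block
//   16x16        -> per-XDL (MFMA) tile in M and N
//   1x4          -> XDL repeats per wave in M and N
//   8x16x1 (x2)  -> A and B block-transfer thread clusters, S<8, 16, 1>
//   1x16x1x16    -> C-shuffle thread cluster, S<1, 16, 1, 16>
//   4x4x1        -> C-shuffle scalars per vector, S<4, 4, 1>
//   1x1          -> C-shuffle XDL-per-wave counts in M and N
//   interwave_v2 -> block GEMM pipeline scheduler and version
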
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip index 59892b1829..55a7df51dc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 47be5bd05b..59bfe98ad8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index b4c7f93443..daf19b198a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 128, - 64, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
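
The at::Tensor instantiation is what removes the per-group slicing overhead called out in the summary: the host can hand the stacked quantized activations, weights, and scales straight to a kernel such as this 256x256x128x64 interwave_v1 instance. A hedged sketch of such a call; the wrapper name, the assumed [G, M, K] / [G, N, K] layouts, and the scale shapes are illustrative, not the actual host code:

#include <ATen/ATen.h>

// Declaration matching the templated kernel in this file (repeated so the
// sketch is self-contained).
template <typename InputType, typename OutputType>
OutputType
fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1(
    InputType XQ,
    InputType WQ,
    InputType x_scale,
    InputType w_scale,
    at::Tensor kernel_args,
    OutputType Y);

at::Tensor launch_stacked_example(
    at::Tensor XQ,          // assumed [G, M, K] fp8
    at::Tensor WQ,          // assumed [G, N, K] fp8
    at::Tensor x_scale,     // assumed [G, M] row scales
    at::Tensor w_scale,     // assumed [G, N] row scales
    at::Tensor kernel_args,
    at::Tensor Y) {         // assumed [G, M, N] bf16 output
  // InputType and OutputType both deduce to at::Tensor, selecting the second
  // explicit instantiation; no TensorList slicing happens on the host.
  return fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1(
      XQ, WQ, x_scale, w_scale, kernel_args, Y);
}
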
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 128, - 64, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip index c27587b65d..0f885162ed 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip index 5410a90837..629296c380 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip index f728fd9cfe..4604430f93 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
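
For completeness, the TensorList instantiation is the counterpart used when groups stay as separate tensors with preallocated per-group outputs. A sketch under the same caveats as the stacked example earlier (assumed shapes, illustrative wrapper name):

#include <ATen/ATen.h>
#include <vector>

template <typename InputType, typename OutputType>
OutputType
fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2(
    InputType XQ,
    InputType WQ,
    InputType x_scale,
    InputType w_scale,
    at::Tensor kernel_args,
    OutputType Y);

std::vector<at::Tensor> launch_list_example(
    const std::vector<at::Tensor>& XQ,      // G tensors, each [M_g, K]
    const std::vector<at::Tensor>& WQ,      // G tensors, each [N, K]
    const std::vector<at::Tensor>& x_scale, // G tensors, each [M_g]
    const std::vector<at::Tensor>& w_scale, // G tensors, each [N]
    at::Tensor kernel_args,
    std::vector<at::Tensor> Y) {            // G preallocated [M_g, N] outputs
  // InputType deduces to at::TensorList and OutputType to std::vector<at::Tensor>,
  // matching the first explicit instantiation emitted for every kernel.
  return fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2(
      at::TensorList(XQ),
      at::TensorList(WQ),
      at::TensorList(x_scale),
      at::TensorList(w_scale),
      kernel_args,
      Y);
}
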
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 165d3cb906..942b849cae 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index fd63d9befd..7ce351f1d0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 872c8d6756..8b1cdb3b91 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip index 6a391e1da0..c34b647725 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip index 3d051451a6..816fffd74f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip index 7815fb9334..fbaf38d847 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip index fc8bdf60d0..673a6904c5 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 35700d8098..81e91d08ce 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. 
-*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index 8714c5d8e7..b0c156cd4e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 749d0a3c94..1cd68d42dd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 76388b8bba..6903129661 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 5cc809bef8..dfc6e5aa44 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip index bf8b430f0e..da245ef71a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // Secret kernel that seems good with small M but large N and K. using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -36,5 +37,33 @@ fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intr ck::BlockGemmPipelineVersion::v1, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index 3468432a90..e9dfcad9ee 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 1f2339131e..0e4a7e8be9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 2f28b951ea..93597dfa51 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
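With the templated signature, the K-alignment check can branch at compile time: a stacked input is taken to be laid out as `[num_groups, M, K]`, so `K` comes from `XQ.size(2)`, while the tensor-list path still inspects each group. A small sketch of that check, assuming (as in the pre-change list code) that padding should be requested when any group's K is not a multiple of the K tile:

```cpp
// Hedged sketch of the compile-time padding check. Assumes stacked inputs are
// laid out as [num_groups, M, K]; the OR-accumulation mirrors the pre-change
// list behavior, where any misaligned group triggers K padding.
#include <ATen/ATen.h>
#include <type_traits>

template <typename InputType>
bool needs_k_padding(const InputType& XQ, int64_t k_per_block) {
  if constexpr (std::is_same_v<InputType, at::Tensor>) {
    // Stacked [G, M, K] tensor: every group shares the same K.
    return XQ.size(2) % k_per_block != 0;
  } else {
    // at::TensorList: pad if any group's K is misaligned.
    bool pad = false;
    for (size_t i = 0; i < XQ.size(); i++) {
      pad |= (XQ[i].size(1) % k_per_block != 0);
    }
    return pad;
  }
}
```

The hunks write the list branch as a plain assignment inside the loop; the OR-accumulation here only makes the any-group intent explicit, so the sketch is not a line-for-line copy.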
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 673d19a0f4..ccc4a81bf4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
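Each kernel translation unit now closes with explicit instantiations for the three supported type combinations, so the template definition can stay in the `.hip` file while other translation units link against it through declarations only. A sketch of the pattern with a hypothetical kernel name:

```cpp
// Explicit-instantiation pattern, shown with a hypothetical kernel name.
#include <ATen/ATen.h>
#include <vector>

template <typename InputType, typename OutputType>
OutputType example_grouped_kernel(
    InputType XQ,
    InputType WQ,
    InputType x_scale,
    InputType w_scale,
    at::Tensor kernel_args,
    OutputType Y) {
  // Trivial body for illustration; the real kernels build a DeviceGemmInstance.
  return Y;
}

// List inputs, one output tensor per group.
template std::vector<at::Tensor> example_grouped_kernel(
    at::TensorList, at::TensorList, at::TensorList, at::TensorList,
    at::Tensor, std::vector<at::Tensor>);

// Fully stacked inputs and output.
template at::Tensor example_grouped_kernel(
    at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor);

// List inputs written into a single contiguous output tensor.
template at::Tensor example_grouped_kernel(
    at::TensorList, at::TensorList, at::TensorList, at::TensorList,
    at::Tensor, at::Tensor);
```

The three instantiations correspond to per-group outputs, fully stacked tensors in and out, and list inputs written into one contiguous output.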
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index a832dcd1e3..df36c52ceb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // The smallest kernel we have available. Works well for memory bound shapes. using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -36,5 +37,33 @@ fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_inte ck::BlockGemmPipelineVersion::v2, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
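The long kernel names mirror the `DeviceGemmHelper` arguments; for the 64x16x16x512 instance above, a hedged reading of the convention (descriptive labels, not CK's official terminology):

```cpp
// Hedged decoding of the kernel-name convention against the DeviceGemmHelper
// arguments visible in the hunk above; field names are descriptive guesses.
struct TileConfig {
  int block_size;   // 64   -> leading "64"
  int m_per_block;  // 16   -> "x16"
  int n_per_block;  // 16   -> "x16"
  int k_per_block;  // 512  -> "x512"
  int wave_tile_m;  // 16   -> "16x16"
  int wave_tile_n;  // 16
  int m_waves;      // 1    -> "1x1"
  int n_waves;      // 1
};
constexpr TileConfig k64x16x16x512{64, 16, 16, 512, 16, 16, 1, 1};
// The remaining name fields map to the S<...> shapes ("32x2x1" A/B
// block-transfer clusters, "1x16x1x4" and "4x4x1" C-shuffle parameters) and
// to the pipeline suffix: "interwave_v2" selects
// ck::BlockGemmPipelineScheduler::Interwave with BlockGemmPipelineVersion::v2.
```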
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index b9d76db772..a9ef13ed69 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 18236d9db4..3b9b2e0ed4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
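Because `GemmSpecialization` is a template parameter of `DeviceGemmHelper`, the padded and unpadded configurations are distinct types, which is why every kernel body carries two nearly identical branches rather than a single runtime switch. A compressed sketch of that dispatch, with hypothetical `launch`/`dispatch` helpers:

```cpp
// Sketch of the two-branch dispatch; `launch` and `dispatch` are hypothetical.
#include "fp8_rowwise_grouped_common.h"  // provides the ck:: types used below

template <ck::tensor_operation::device::GemmSpecialization Spec>
void launch(/* XQ, WQ, scales, kernel_args, Y */) {
  // Build DeviceGemmHelper<..., Spec> and call f8f8bf16_rowwise_grouped_impl.
}

void dispatch(bool pad) {
  using ck::tensor_operation::device::GemmSpecialization;
  if (pad) {
    launch<GemmSpecialization::KPadding>();
  } else {
    launch<GemmSpecialization::Default>();
  }
}
```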
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 17f17cee6f..e079dada52 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index e9efd64a56..aa1a75e1ec 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 6b7d7553d0..a673233688 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 3de229cd4d..5cbc3fd60b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,103 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h index 806adabca2..9ea8df849a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h @@ -122,16 +122,22 @@ using DeviceGemmHelper = PIPELINE_VERSION, ComputeType>; -template -std::vector f8f8bf16_rowwise_grouped_impl( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, +// Templated kernel launch to accommodate different input and output types. +template +OutputType f8f8bf16_rowwise_grouped_impl( + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // Get input information. - int group_count = XQ.size(); + int group_count; + if constexpr (std::is_same_v) { + group_count = XQ.size(0); + } else { + group_count = XQ.size(); + } using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>; using GemmDesc = ck::tensor_operation::device::GemmDesc; @@ -148,21 +154,50 @@ std::vector f8f8bf16_rowwise_grouped_impl( B_args.reserve(group_count); C_args.reserve(group_count); D_args.reserve(group_count); + int M; + int K; + int N; + // Declare pointers to input and output buffers. + ADataType* a_ptr; + BDataType* b_ptr; + EDataType* c_ptr; + D0DataType* d0_ptr; + D1DataType* d1_ptr; // Populate arguments. for (int i = 0; i < group_count; i++) { + // Compute appropriate data pointers. // Set the shape arguments for this gemm. 
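For the stacked path, the loop completed in the next hunk no longer reads `data_ptr()` from each list element; it offsets into the stacked buffers instead. A sketch of that arithmetic, assuming contiguous row-major inputs shaped `XQ = [G, M, K]`, `WQ = [G, N, K]`, `x_scale = [G, M]`, `w_scale = [G, N]`, and `Y = [G, M, N]`, with a hypothetical `group_pointers` helper that byte-offsets rather than casting to the CK element types:

```cpp
// Sketch of the per-group pointer arithmetic for stacked tensors.
#include <ATen/ATen.h>
#include <cstdint>

struct GroupPtrs {
  const void* a;   // XQ block for group i
  const void* b;   // WQ block for group i
  void* c;         // Y block for group i
  const void* d0;  // w_scale row for group i
  const void* d1;  // x_scale row for group i
};

inline GroupPtrs group_pointers(
    const at::Tensor& XQ,
    const at::Tensor& WQ,
    const at::Tensor& x_scale,
    const at::Tensor& w_scale,
    at::Tensor& Y,
    int64_t i) {
  const int64_t M = XQ.size(1);
  const int64_t K = XQ.size(2);
  const int64_t N = WQ.size(1);
  return GroupPtrs{
      static_cast<const char*>(XQ.data_ptr()) + i * M * K * XQ.element_size(),
      static_cast<const char*>(WQ.data_ptr()) + i * N * K * WQ.element_size(),
      static_cast<char*>(Y.data_ptr()) + i * M * N * Y.element_size(),
      static_cast<const char*>(w_scale.data_ptr()) + i * N * w_scale.element_size(),
      static_cast<const char*>(x_scale.data_ptr()) + i * M * x_scale.element_size()};
}
```

All groups then share one base allocation per operand, with only integer offsets computed per group.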
- int M = XQ[i].size(0); - int K = XQ[i].size(1); - int N = WQ[i].size(0); + if constexpr (std::is_same_v) { + M = XQ.size(1); + N = WQ.size(1); + K = XQ.size(2); + a_ptr = reinterpret_cast(XQ.data_ptr()) + (i * M * K); + b_ptr = reinterpret_cast(WQ.data_ptr()) + (i * N * K); + d0_ptr = reinterpret_cast(w_scale.data_ptr()) + (i * N); + d1_ptr = reinterpret_cast(x_scale.data_ptr()) + (i * M); + } else { + M = XQ[i].size(0); + N = WQ[i].size(0); + K = XQ[i].size(1); + a_ptr = reinterpret_cast(XQ[i].data_ptr()); + b_ptr = reinterpret_cast(WQ[i].data_ptr()); + d0_ptr = reinterpret_cast(w_scale[i].data_ptr()); + d1_ptr = reinterpret_cast(x_scale[i].data_ptr()); + } + if constexpr (std::is_same_v) { + c_ptr = reinterpret_cast(Y.data_ptr()) + (i * M * N); + } else { + c_ptr = reinterpret_cast(Y[i].data_ptr()); + } + GemmDesc gemm_desc = {M, N, K, K, K, N, {0, 0}}; gemm_descs.push_back(gemm_desc); + // Set pointers to inputs and outputs. - A_args.push_back(reinterpret_cast(XQ[i].data_ptr())); - B_args.push_back(reinterpret_cast(WQ[i].data_ptr())); - C_args.push_back(reinterpret_cast(Y[i].data_ptr())); - D_args.emplace_back(std::array{ - reinterpret_cast(w_scale[i].data_ptr()), - reinterpret_cast(x_scale[i].data_ptr())}); + A_args.push_back(a_ptr); + B_args.push_back(b_ptr); + C_args.push_back(c_ptr); + D_args.emplace_back(std::array{d0_ptr, d1_ptr}); } // Create gemm launcher and arguments. diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h index ff7de71e75..181d8b27f9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h @@ -6,804 +6,714 @@ * LICENSE file in the root directory of this source tree. 
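The manifest header rewritten in the next hunk drops its fixed `std::function`-based kernel alias and name-map macro; with every entry templated on the same `InputType`/`OutputType` pair, any lookup table has to carry those parameters as well. A hedged sketch of one way such an alias could be re-typed (not the FBGEMM code):

```cpp
// Hedged sketch: a templated kernel alias in place of the removed fixed
// std::function type. Names here are illustrative only.
#include <ATen/ATen.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

template <typename InputType, typename OutputType>
using RowwiseGroupedKernelT = std::function<OutputType(
    InputType, InputType, InputType, InputType, at::Tensor, OutputType)>;

// Example: a name-keyed table for the fully stacked interface.
using StackedKernelMap = std::unordered_map<
    std::string,
    RowwiseGroupedKernelT<at::Tensor, at::Tensor>>;
```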
*/ -#include -#include - #include -#define KERNEL_NAME_MAP_ENTRY(name) \ - { #name, name } - -using RowwiseGroupedKernel = std::function( - at::TensorList, - at::TensorList, - at::TensorList, - at::TensorList, - at::Tensor, - std::vector)>; - -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( 
- at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, 
- at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - 
at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - 
at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + 
InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType 
WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + 
InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y); - -// Map function for string name to kernel implementation for manual -// specification. -static const std::unordered_map kernel_name_map = { - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - 
fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - 
KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2), -}; + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu index 437d3d8bf5..8cd4baf894 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu @@ -303,6 +303,7 @@ __global__ void set_dynamic_kernel_args_kernel( } template < + typename InputType, int TB_M, int TB_N, int TB_K, @@ -310,17 +311,26 @@ template < int TBS_N, int TBS_K, bool PONG> -std::tuple> f8f8bf16_rowwise_grouped_impl( - at::TensorList XQ, // FP8 - at::TensorList WQ, // FP8 - at::TensorList x_scale, - at::TensorList w_scale, +at::Tensor f8f8bf16_rowwise_grouped_impl( + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, at::Tensor output, std::optional zero_start_index_M) { - int group_count = XQ.size(); - TORCH_CHECK(WQ.size() == group_count); + int group_count; + at::TensorOptions options; + if constexpr (std::is_same_v) { + group_count = XQ.size(); + options = XQ[0].options(); + TORCH_CHECK(WQ.size() == group_count); + } else { + group_count = XQ.size(0); + options = XQ.options(); + TORCH_CHECK(WQ.size(0) == group_count); + } if (group_count == 0) { - return {at::Tensor(), std::vector()}; + return at::Tensor(); } using GroupedGemmConfigs = GroupedGemmArgs:: GroupedGemmConfigs; @@ -328,8 +338,7 @@ std::tuple> 
f8f8bf16_rowwise_grouped_impl( int64_t total_output_size = 0; std::vector output_sizes; output_sizes.reserve(group_count); - at::Tensor output_args = - at::empty({group_count}, XQ[0].options().dtype(at::kLong)); + at::Tensor output_args = at::empty({group_count}, options.dtype(at::kLong)); const int64_t problem_shape_size = group_count * ((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape)); @@ -341,7 +350,7 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( // number at::Tensor input_args = at::empty( {group_count * 4 + problem_shape_size + stride_size * 3 + 1000}, - XQ[0].options().dtype(at::kLong)); + options.dtype(at::kLong)); int xq_ptr_offset = 0; int wq_ptr_offset = group_count * sizeof(int64_t); @@ -351,10 +360,18 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( int stride_buf_offset = group_count * 4 * sizeof(int64_t) + problem_shape_size; - for (int i = 0; i < group_count; ++i) { - const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); - total_output_size += output_size; - output_sizes.push_back(output_size); + if constexpr (std::is_same_v) { + for (int i = 0; i < group_count; ++i) { + const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); + total_output_size += output_size; + output_sizes.push_back(output_size); + } + } else { + // When inputs are pregrouped, output has G * M * N elements. + total_output_size = group_count * XQ.size(1) * WQ.size(1); + for (int i = 0; i < group_count; ++i) { + output_sizes.push_back(XQ.size(1) * WQ.size(1)); + } } int blockSize = 256; @@ -369,17 +386,39 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( // Set arguments for (int i = 0; i < group_count; ++i) { - int N = WQ[i].size(0); - int K = XQ[i].size(1); - TORCH_CHECK_EQ(WQ[i].size(1), K); + int M, N, K; + int64_t xq_ptr, wq_ptr, x_scale_ptr, w_scale_ptr; + // Compute buffer pointers based on input type. + if constexpr (std::is_same_v) { + M = XQ[i].size(0); + N = WQ[i].size(0); + K = XQ[i].size(1); + TORCH_CHECK_EQ(WQ[i].size(1), K); + // Calculate data pointer for this group. + xq_ptr = reinterpret_cast(XQ[i].data_ptr()); + wq_ptr = reinterpret_cast(WQ[i].data_ptr()); + x_scale_ptr = reinterpret_cast(x_scale[i].data_ptr()); + w_scale_ptr = reinterpret_cast(w_scale[i].data_ptr()); + } else { + M = XQ.size(1); + N = WQ.size(1); + K = XQ.size(2); + // Calculate data pointer for this group. 
+ xq_ptr = reinterpret_cast(XQ.data_ptr()) + + (i * M * K * sizeof(GroupedGemmArgs::ElementInputA)); + wq_ptr = reinterpret_cast(WQ.data_ptr()) + + (i * N * K * sizeof(GroupedGemmArgs::ElementInputB)); + x_scale_ptr = reinterpret_cast(x_scale.data_ptr()) + + (i * M * sizeof(GroupedGemmArgs::ElementAccumulator)); + w_scale_ptr = reinterpret_cast(w_scale.data_ptr()) + + (i * N * sizeof(GroupedGemmArgs::ElementAccumulator)); + } if (zero_start_index_M.has_value() == true) { set_dynamic_kernel_args_kernel<<>>( - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast( - x_scale[i].data_ptr()), - reinterpret_cast( - w_scale[i].data_ptr()), + xq_ptr, + wq_ptr, + x_scale_ptr, + w_scale_ptr, input_args.data_ptr(), output_args.data_ptr(), output.data_ptr(), @@ -398,14 +437,11 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( N, K); } else { - int M = XQ[i].size(0); set_kernel_args_kernel<<>>( - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast( - x_scale[i].data_ptr()), - reinterpret_cast( - w_scale[i].data_ptr()), + xq_ptr, + wq_ptr, + x_scale_ptr, + w_scale_ptr, input_args.data_ptr(), output_args.data_ptr(), output.data_ptr(), @@ -508,103 +544,157 @@ std::tuple> f8f8bf16_rowwise_grouped_impl( C10_CUDA_KERNEL_LAUNCH_CHECK(); - std::vector output_group = output.split(output_sizes); - for (int i = 0; i < group_count; ++i) { - output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)}); - } - // Return two views of the same underlying tensor. - return {output, output_group}; + return output; } // FP8 Tensorwise grouped cutlass kernel dispatch. -std::tuple> dispatch_fp8_grouped_kernel( - at::TensorList XQ, // FP8 - at::TensorList WQ, // FP8 - at::TensorList x_scale, - at::TensorList w_scale, +template +at::Tensor dispatch_fp8_grouped_kernel( + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, at::Tensor output, std::optional zero_start_index_M) { KernelMode kernel = get_grouped_kernel_mode(XQ, WQ); if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_grouped_impl<64, 128, 128, 2, 1, 1, true>( - XQ, WQ, x_scale, w_scale, output, zero_start_index_M); + return f8f8bf16_rowwise_grouped_impl< + InputType, + 64, + 128, + 128, + 2, + 1, + 1, + true>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M); } else { - return f8f8bf16_rowwise_grouped_impl<128, 256, 64, 2, 1, 1, false>( - XQ, WQ, x_scale, w_scale, output, zero_start_index_M); + return f8f8bf16_rowwise_grouped_impl< + InputType, + 128, + 256, + 64, + 2, + 1, + 1, + false>(XQ, WQ, x_scale, w_scale, output, zero_start_index_M); } } -std::vector f8f8bf16_rowwise_grouped( +template +OutputType _f8f8bf16_rowwise_grouped( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 at::TensorList x_scale, at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional output = std::nullopt) { at::Tensor Y; int group_count = XQ.size(); + std::vector output_sizes; if (output.has_value()) { - std::vector output_; - output_ = output.value(); - TORCH_CHECK( - output_.size() == group_count, - "Output and input must have same number of groups."); - // Check that output shapes are correct. - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); - int out_M = output_[i].size(0); - int out_N = output_[i].size(1); + // Handle initialization check for output list. 
+ if constexpr (std::is_same_v>) { + std::vector output_; + output_ = output.value(); TORCH_CHECK( - M == out_M && N == out_N, - "Output tensors do not have the expected shape."); - TORCH_CHECK( - output_[i].dtype() == at::kBFloat16, - "Output dtype must be bfloat16."); + output_.size() == group_count, + "Output and input must have same number of groups."); + // Check that output shapes are correct. + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int N = WQ[i].size(0); + int out_M = output_[i].size(0); + int out_N = output_[i].size(1); + TORCH_CHECK( + M == out_M && N == out_N, + "Output tensors do not have the expected shape."); + TORCH_CHECK( + output_[i].dtype() == at::kBFloat16, + "Output dtype must be bfloat16."); + output_sizes.push_back(out_M * out_N); + } + Y = at::stack(output.value(), 0); + // Handle initialization check for output tensor. + } else { + at::Tensor output_ = output.value(); + // Check that output shapes are correct. + int N = WQ[0].size(0); + for (int i = 0; i < group_count; i++) { + TORCH_CHECK( + WQ[i].size(0) == N, + "N must be static for single output tensor mode."); + TORCH_CHECK( + output_.size(0) == group_count && + output_.size(1) == XQ[i].size(0) && output_.size(2) == N, + "Output must be shape [G, M, N]"); + } + TORCH_CHECK(output_.dtype() == at::kBFloat16, "Output must be bfloat16.") + Y = output_.view(-1); } - Y = at::stack(output.value(), 0); } else { int64_t total_output_size = 0; - std::vector output_sizes; for (int i = 0; i < group_count; ++i) { const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); total_output_size += output_size; output_sizes.push_back(output_size); } - Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16)); + Y = at::empty(total_output_size, XQ[0].options().dtype(at::kBFloat16)); + } + + // Run kernel. + at::Tensor g_out = dispatch_fp8_grouped_kernel( + XQ, WQ, x_scale, w_scale, Y, std::nullopt); + + // Return proper output type. + if constexpr (std::is_same_v>) { + // Return grouped view of output. + std::vector output_group = g_out.split(output_sizes); + for (int i = 0; i < group_count; ++i) { + output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)}); + } + return output_group; + } else { + // When stacked we assume N is fixed. + return g_out.view({-1, WQ[0].size(0)}); } - // Return grouped view of output. 
- return std::get<1>( - dispatch_fp8_grouped_kernel(XQ, WQ, x_scale, w_scale, Y, std::nullopt)); } -at::Tensor f8f8bf16_rowwise_grouped_dynamic( +std::vector f8f8bf16_rowwise_grouped( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional> output = std::nullopt) { + return _f8f8bf16_rowwise_grouped>( + XQ, WQ, x_scale, w_scale); +} + +at::Tensor f8f8bf16_rowwise_grouped_stacked( + at::TensorList XQ, // FP8 + at::TensorList WQ, // FP8 + at::TensorList x_scale, + at::TensorList w_scale, + std::optional output = std::nullopt) { + return _f8f8bf16_rowwise_grouped(XQ, WQ, x_scale, w_scale); +} + +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M) { at::Tensor Y; - int group_count = XQ.size(); - int64_t total_output_size = 0; - for (int i = 0; i < group_count; ++i) { - const int64_t output_size = XQ[i].size(0) * WQ[i].size(0); - total_output_size += output_size; - } - Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16)); + int group_count = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = XQ.size(0); + int total_output_size = group_count * M * N; + Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16)); // Return continuous view of output. - at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel( - XQ, WQ, x_scale, w_scale, Y, zero_start_index_M)); + at::Tensor output = dispatch_fp8_grouped_kernel( + XQ, WQ, x_scale, w_scale, Y, zero_start_index_M); // View as proper shape. - // When zero_start_index_M is provided, we can view as [G, M, N] - if (zero_start_index_M.has_value()) { - output = output.view({-1, XQ[0].size(0), WQ[0].size(0)}); - // Otherwise we view as {total_M, N}. 
- } else { - output = output.view({-1, WQ[0].size(0)}); - } - return output; + return output.view({-1, M, N}); } #else @@ -614,19 +704,27 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList WQ, // FP8 at::TensorList x_scale, at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional> output = std::nullopt) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } -at::Tensor f8f8bf16_rowwise_grouped_dynamic( +at::Tensor f8f8bf16_rowwise_grouped_stacked( at::TensorList XQ, // FP8 at::TensorList WQ, // FP8 at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional output = std::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 221597b43f..b6100e0ee0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -92,16 +92,19 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt); -at::Tensor f8f8bf16_rowwise_grouped_dynamic( + std::optional> output = std::nullopt); +at::Tensor f8f8bf16_rowwise_grouped_stacked( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional zero_start_index_M = std::nullopt, - std::optional kernel_name = std::nullopt); -std::vector get_f8f8bf16_rowwise_grouped_kernels(); + std::optional output = std::nullopt); +at::Tensor f8f8bf16_rowwise_grouped_dynamic( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M); at::Tensor f8f8bf16_blockwise( at::Tensor XQ, at::Tensor WQ, @@ -190,12 +193,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor"); m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic); -#endif -#ifdef USE_ROCM - m.def("get_f8f8bf16_rowwise_grouped_kernels() -> str[]"); - m.impl( - "get_f8f8bf16_rowwise_grouped_kernels", - get_f8f8bf16_rowwise_grouped_kernels); #endif m.def( "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]"); @@ -210,9 +207,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8f8bf16_rowwise_batched(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor"); m.def( - "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]"); + "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None) -> Tensor[]"); + m.def( + "f8f8bf16_rowwise_grouped_stacked(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor(a!)? output=None) -> Tensor"); m.def( - "f8f8bf16_rowwise_grouped_dynamic(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor? zero_start_index_M=None, str? 
kernel_name=None) -> Tensor"); + "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M) -> Tensor"); m.def( "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor"); m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor"); @@ -254,6 +253,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f8f8bf16_rowwise_out", f8f8bf16_rowwise_out); m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped); + m.impl("f8f8bf16_rowwise_grouped_stacked", f8f8bf16_rowwise_grouped_stacked); m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); m.impl("quantize_fp8_per_row", quantize_fp8_per_row); @@ -281,6 +281,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise); m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped); + m.impl("f8f8bf16_rowwise_grouped_stacked", f8f8bf16_rowwise_grouped_stacked); m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); m.impl("quantize_fp8_per_row", quantize_fp8_per_row); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index d5d8b378cb..4ca974e5b5 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -737,7 +737,7 @@ def fp8_loopover_bmm( K=st.sampled_from([512, 3584]), use_cudagraph=st.booleans(), use_padding_zeros=st.booleans(), - use_dynamic=st.booleans(), + return_stacked=st.booleans(), ) def test_fp8_grouped_gemm( self, @@ -747,7 +747,7 @@ def test_fp8_grouped_gemm( K: int, use_cudagraph: bool, use_padding_zeros: bool, - use_dynamic: bool, + return_stacked: bool, ) -> None: ms = ( torch.randint( @@ -759,13 +759,14 @@ def test_fp8_grouped_gemm( * 64 ) # When using padding or the dynamic kernel, Ns and Ks should be fixed. - if use_padding_zeros or use_dynamic: + if use_padding_zeros or return_stacked: ns = [N] * G ks = [K] * G # Otherwise, any value is supported. else: - ns = torch.randint(1, (N // 64) + 1, (G,), dtype=torch.int) * 64 - ks = torch.randint(1, (K // 64) + 1, (G,), dtype=torch.int) * 64 + # AMD requires N and K >= 512. 
+ ns = torch.randint(512 // 64, (N // 64) + 1, (G,), dtype=torch.int) * 64 + ks = torch.randint(512 // 64, (K // 64) + 1, (G,), dtype=torch.int) * 64 x_group = [] w_group = [] @@ -800,10 +801,10 @@ def test_fp8_grouped_gemm( if use_padding_zeros: x_group = torch.unbind(torch.stack(x_group, dim=0).contiguous()) w_group = torch.unbind(torch.stack(w_group, dim=0).contiguous()) - xq_group = torch.unbind(torch.stack(xq_group, dim=0).contiguous()) - wq_group = torch.unbind(torch.stack(wq_group, dim=0).contiguous()) - x_scale_group = torch.unbind(torch.stack(x_scale_group, dim=0).contiguous()) - w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous()) + xq_group = torch.stack(xq_group, dim=0).contiguous() + wq_group = torch.stack(wq_group, dim=0).contiguous() + x_scale_group = torch.stack(x_scale_group, dim=0).contiguous() + w_scale_group = torch.stack(w_scale_group, dim=0).contiguous() # FP8 grouped gemm kernel fp8_args = ( @@ -812,15 +813,19 @@ def test_fp8_grouped_gemm( wq_group, x_scale_group, w_scale_group, - zero_start_index_M if use_padding_zeros else None, + zero_start_index_M, ] - if use_dynamic + if use_padding_zeros else [xq_group, wq_group, x_scale_group, w_scale_group] ) fp8_op = ( torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic - if use_dynamic - else torch.ops.fbgemm.f8f8bf16_rowwise_grouped + if use_padding_zeros + else ( + torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked + if return_stacked + else torch.ops.fbgemm.f8f8bf16_rowwise_grouped + ) ) if use_cudagraph: # warmup @@ -843,12 +848,12 @@ def test_fp8_grouped_gemm( # BF16 grouped gemm kernel bf16_args = ( [x_group, w_group, zero_start_index_M if use_padding_zeros else None] - if use_dynamic + if return_stacked else [x_group, w_group] ) bf16_op = ( torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic - if use_dynamic + if return_stacked else torch.ops.fbgemm.bf16bf16bf16_grouped ) if use_cudagraph: @@ -878,7 +883,7 @@ def test_fp8_grouped_gemm( # Assert FP8 outputs for i in range(len(y_group_ref)): torch.testing.assert_close( - y_fp8_group[i], y_group_ref[i], atol=8.0e-2, rtol=8.0e-2 + y_fp8_group[i], y_group_ref[i], atol=8.0e-2, rtol=2.0e-1 ) # Assert BF16 outputs
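
For reference, a minimal sketch of how the two list-input entry points registered above (`f8f8bf16_rowwise_grouped` and the new `f8f8bf16_rowwise_grouped_stacked`) are expected to be called from Python. This assumes the `fbgemm_gpu.experimental.gen_ai` extension is installed so the `torch.ops.fbgemm` operators are registered; the shapes, scale values, and FP8 dtype below are illustrative only (the fnuz FP8 variant applies on AMD, and the test above keeps N and K >= 512 for the ROCm path).

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to register torch.ops.fbgemm ops

# Illustrative group of already rowwise-quantized operands (G groups sharing M/N/K).
G, M, N, K = 4, 128, 512, 512
fp8 = torch.float8_e4m3fn  # assumption: use the e4m3fnuz dtype on AMD
XQ = [torch.randn(M, K, device="cuda").to(fp8) for _ in range(G)]
WQ = [torch.randn(N, K, device="cuda").to(fp8) for _ in range(G)]
x_scale = [torch.rand(M, device="cuda", dtype=torch.float32) for _ in range(G)]
w_scale = [torch.rand(N, device="cuda", dtype=torch.float32) for _ in range(G)]

# Existing behavior: a list of bf16 [M, N] outputs, one per group.
y_list = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(XQ, WQ, x_scale, w_scale)

# New stacked variant: a single contiguous bf16 tensor with the group outputs
# concatenated along M (shape [total_M, N]); N must be the same for every group.
y_stacked = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(XQ, WQ, x_scale, w_scale)
```

Per the registered schema, both variants also accept an optional preallocated `output=` argument (a `Tensor[]` for the list form, a single `[G, M, N]` bf16 tensor for the stacked form), which the dispatcher validates before launching the kernel.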
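
Similarly, a hedged sketch of the reworked `f8f8bf16_rowwise_grouped_dynamic` entry point, which now takes stacked `at::Tensor` inputs of shape [G, M, K] / [G, N, K] together with a required `zero_start_index_M`, and returns a bf16 tensor viewed as [G, M, N]. The integer dtype chosen for `zero_start_index_M` below is an assumption, and the scale layout simply mirrors the per-group pointer arithmetic in the kernel-argument setup above.

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to register torch.ops.fbgemm ops

G, M, N, K = 4, 128, 512, 512
fp8 = torch.float8_e4m3fn  # assumption: use the e4m3fnuz dtype on AMD
XQ = torch.randn(G, M, K, device="cuda").to(fp8)
WQ = torch.randn(G, N, K, device="cuda").to(fp8)
x_scale = torch.rand(G, M, device="cuda", dtype=torch.float32)
w_scale = torch.rand(G, N, device="cuda", dtype=torch.float32)

# Index at which zero padding begins in each group's rows; here every row is valid.
zero_start_index_M = torch.full((G,), M, device="cuda", dtype=torch.int64)  # dtype assumed

y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    XQ, WQ, x_scale, w_scale, zero_start_index_M
)
assert y.shape == (G, M, N)
```

Passing whole tensors rather than per-group slices is what lets callers skip the slicing step before the op, which is the overhead this variant is meant to avoid.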