From 4154d440b40ce392631071b0303a5219bab3a7c4 Mon Sep 17 00:00:00 2001 From: YangLong114514 Date: Fri, 22 May 2026 17:39:05 +0800 Subject: [PATCH 1/2] [KMCompiler] Optimize gelu with rowwise nomask kernels --- mojo_opset/backends/ttx/kernels/npu/gelu.py | 213 +++++++++++++++----- 1 file changed, 157 insertions(+), 56 deletions(-) diff --git a/mojo_opset/backends/ttx/kernels/npu/gelu.py b/mojo_opset/backends/ttx/kernels/npu/gelu.py index a33cf0d1..d3594142 100644 --- a/mojo_opset/backends/ttx/kernels/npu/gelu.py +++ b/mojo_opset/backends/ttx/kernels/npu/gelu.py @@ -1,16 +1,13 @@ import torch import triton import triton.language as tl - +import triton.language.extra.cann.libdevice as libdevice from .utils import libentry -from mojo_opset.backends.ttx.kernels.npu.utils import VEC_ALIGN_BYTES -from mojo_opset.backends.ttx.kernels.utils import align - """ This file contains the implementation of GELU (Gaussian Error Linear Unit) for NPU. -GELU formula: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +GELU formula: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) Based on Liger Kernel implementation: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/geglu.py @@ -19,30 +16,29 @@ """ -COL_BLOCKING_THRESHOLD = 4096 +MAX_BLOCK_SIZE_N = 1024 +GELU_TANH_MAX_BLOCK_SIZE_M = 8 + + +GELU_TANH_BLOCK_SIZE_M_CONFIGS = [ + triton.Config({"BLOCK_SIZE_M": 1}), + triton.Config({"BLOCK_SIZE_M": 2}), + triton.Config({"BLOCK_SIZE_M": 4}), + triton.Config({"BLOCK_SIZE_M": 8}), +] @triton.jit def gelu_tanh_approx(x): - """GELU activation using tanh approximation""" - sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / π) + """GELU activation using tanh approximation.""" + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) x_cubed = x * x * x tanh_arg = sqrt_2_over_pi * (x + 0.044715 * x_cubed) - return 0.5 * x * (1 + tl.tanh(tanh_arg)) + return 0.5 * x * (1 + libdevice.tanh(tanh_arg)) @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 1}), - triton.Config({"BLOCK_SIZE_M": 2}), - triton.Config({"BLOCK_SIZE_M": 4}), - triton.Config({"BLOCK_SIZE_M": 8}), - triton.Config({"BLOCK_SIZE_M": 12}), - triton.Config({"BLOCK_SIZE_M": 16}), - triton.Config({"BLOCK_SIZE_M": 20}), - triton.Config({"BLOCK_SIZE_M": 24}), - triton.Config({"BLOCK_SIZE_M": 32}), - ], + configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS, key=["n_rows", "n_cols"], ) @libentry() @@ -85,17 +81,81 @@ def _gelu_fwd_kernel( @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 1}), - triton.Config({"BLOCK_SIZE_M": 2}), - triton.Config({"BLOCK_SIZE_M": 4}), - triton.Config({"BLOCK_SIZE_M": 8}), - triton.Config({"BLOCK_SIZE_M": 12}), - triton.Config({"BLOCK_SIZE_M": 16}), - triton.Config({"BLOCK_SIZE_M": 20}), - triton.Config({"BLOCK_SIZE_M": 24}), - triton.Config({"BLOCK_SIZE_M": 32}), - ], + configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS, + key=["n_rows", "n_cols"], +) +@libentry() +@triton.jit +def _gelu_fwd_nomask_kernel( + x, + y, + stride_row, + n_rows, + n_cols, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + grid_size = tl.num_programs(axis=0) + + num_row_tasks = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + + for row_task_id in range(pid, num_row_tasks, grid_size): + block_start_row = row_task_id * BLOCK_SIZE_M + rows_off = block_start_row + tl.arange(0, BLOCK_SIZE_M) + + for col_offset in range(0, n_cols, BLOCK_SIZE_N): + cols_off = col_offset + tl.arange(0, BLOCK_SIZE_N) + + x_ptrs = x + rows_off[:, None] * stride_row + cols_off[None, :] + y_ptrs = y + rows_off[:, None] * stride_row + cols_off[None, :] + + x_chunk = tl.load(x_ptrs) + x_f32 = x_chunk.to(tl.float32) + y_f32 = gelu_tanh_approx(x_f32) + y_chunk = y_f32.to(x_chunk.dtype) + + tl.store(y_ptrs, y_chunk) + + +@triton.autotune( + configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS, + key=["n_rows", "n_cols"], +) +@libentry() +@triton.jit +def _gelu_fwd_nomask_single_kernel( + x, + y, + stride_row, + n_rows, + n_cols, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + grid_size = tl.num_programs(axis=0) + + num_row_tasks = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + cols_off = tl.arange(0, BLOCK_SIZE_N) + + for row_task_id in range(pid, num_row_tasks, grid_size): + block_start_row = row_task_id * BLOCK_SIZE_M + rows_off = block_start_row + tl.arange(0, BLOCK_SIZE_M) + + x_ptrs = x + rows_off[:, None] * stride_row + cols_off[None, :] + y_ptrs = y + rows_off[:, None] * stride_row + cols_off[None, :] + + x_chunk = tl.load(x_ptrs) + x_f32 = x_chunk.to(tl.float32) + y_f32 = gelu_tanh_approx(x_f32) + y_chunk = y_f32.to(x_chunk.dtype) + + tl.store(y_ptrs, y_chunk) + + +@triton.autotune( + configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS, key=["n_rows", "n_cols"], restore_value=["dy", "dx"], ) @@ -137,11 +197,16 @@ def _gelu_bwd_kernel( sqrt_2_over_pi = 0.7978845608028654 x_cubed = x_f32 * x_f32 * x_f32 tanh_arg = sqrt_2_over_pi * (x_f32 + 0.044715 * x_cubed) - tanh_result = tl.tanh(tanh_arg) + tanh_result = libdevice.tanh(tanh_arg) term1 = 0.5 * (1 + tanh_result) tanh_sq = tanh_result * tanh_result - term2 = 0.5 * x_f32 * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32)) + term2 = ( + 0.5 + * x_f32 + * (1 - tanh_sq) + * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32)) + ) dgelu_dx = term1 + term2 dx_chunk = dy_chunk * dgelu_dx.to(dy_chunk.dtype) @@ -149,6 +214,33 @@ def _gelu_bwd_kernel( tl.store(dx_ptrs, dx_chunk, mask=block_mask) +def _rowwise_block_size_n(n_cols): + return min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE_N) + + +def _num_vectorcores(): + return triton.runtime.driver.active.utils.get_device_properties("npu")[ + "num_vectorcore" + ] + + +def _rowwise_grid(n_rows, block_size_m): + num_row_tasks = (n_rows + block_size_m - 1) // block_size_m + return (min(_num_vectorcores(), num_row_tasks),) + + +def _rowwise_autotune_grid(n_rows): + return lambda META: _rowwise_grid(n_rows, META["BLOCK_SIZE_M"]) + + +def _can_use_nomask_kernel(n_rows, n_cols, block_size_n): + return n_cols % block_size_n == 0 and n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0 + + +def _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n): + return n_cols == block_size_n and n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0 + + def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor: """ Forward pass for GELU. @@ -167,22 +259,36 @@ def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor: y = torch.empty_like(x_2d) - if n_cols > COL_BLOCKING_THRESHOLD: - BLOCK_SIZE_N = 2048 + block_size_n = _rowwise_block_size_n(n_cols) + grid = _rowwise_autotune_grid(n_rows) + + if _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n): + _gelu_fwd_nomask_single_kernel[grid]( + x_2d, + y, + x_2d.stride(0), + n_rows, + n_cols, + BLOCK_SIZE_N=block_size_n, + ) + elif _can_use_nomask_kernel(n_rows, n_cols, block_size_n): + _gelu_fwd_nomask_kernel[grid]( + x_2d, + y, + x_2d.stride(0), + n_rows, + n_cols, + BLOCK_SIZE_N=block_size_n, + ) else: - BLOCK_SIZE_N = align(x, n_cols, VEC_ALIGN_BYTES) - - num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"] - grid = (num_programs,) - - _gelu_fwd_kernel[grid]( - x_2d, - y, - x_2d.stride(0), - n_rows, - n_cols, - BLOCK_SIZE_N=BLOCK_SIZE_N, - ) + _gelu_fwd_kernel[grid]( + x_2d, + y, + x_2d.stride(0), + n_rows, + n_cols, + BLOCK_SIZE_N=block_size_n, + ) return y.reshape(*ori_shape) @@ -210,13 +316,8 @@ def gelu_bwd_impl( dx = torch.empty_like(x_2d) - if n_cols > COL_BLOCKING_THRESHOLD: - BLOCK_SIZE_N = 2048 - else: - BLOCK_SIZE_N = align(dy, n_cols, VEC_ALIGN_BYTES) - - num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"] - grid = (num_programs,) + block_size_n = _rowwise_block_size_n(n_cols) + grid = _rowwise_autotune_grid(n_rows) _gelu_bwd_kernel[grid]( dy_2d, @@ -225,7 +326,7 @@ def gelu_bwd_impl( dy_2d.stride(0), n_rows, n_cols, - BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N=block_size_n, ) return dx.reshape(*ori_shape) From fce218d86634b0a06c7ef0678e12ebeca6333eea Mon Sep 17 00:00:00 2001 From: YangLong114514 Date: Wed, 17 Jun 2026 16:19:51 +0800 Subject: [PATCH 2/2] [KMCompiler] Address GeLU review feedback --- mojo_opset/backends/ttx/kernels/npu/gelu.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mojo_opset/backends/ttx/kernels/npu/gelu.py b/mojo_opset/backends/ttx/kernels/npu/gelu.py index d3594142..ae1bdbc9 100644 --- a/mojo_opset/backends/ttx/kernels/npu/gelu.py +++ b/mojo_opset/backends/ttx/kernels/npu/gelu.py @@ -2,7 +2,7 @@ import triton import triton.language as tl import triton.language.extra.cann.libdevice as libdevice -from .utils import libentry +from .utils import get_num_cores, libentry """ This file contains the implementation of GELU (Gaussian Error Linear Unit) for NPU. @@ -17,7 +17,6 @@ MAX_BLOCK_SIZE_N = 1024 -GELU_TANH_MAX_BLOCK_SIZE_M = 8 GELU_TANH_BLOCK_SIZE_M_CONFIGS = [ @@ -27,6 +26,10 @@ triton.Config({"BLOCK_SIZE_M": 8}), ] +GELU_TANH_MAX_BLOCK_SIZE_M = max( + config.kwargs["BLOCK_SIZE_M"] for config in GELU_TANH_BLOCK_SIZE_M_CONFIGS +) + @triton.jit def gelu_tanh_approx(x): @@ -218,15 +221,9 @@ def _rowwise_block_size_n(n_cols): return min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE_N) -def _num_vectorcores(): - return triton.runtime.driver.active.utils.get_device_properties("npu")[ - "num_vectorcore" - ] - - def _rowwise_grid(n_rows, block_size_m): num_row_tasks = (n_rows + block_size_m - 1) // block_size_m - return (min(_num_vectorcores(), num_row_tasks),) + return (max(1, min(get_num_cores("vector"), num_row_tasks)),) def _rowwise_autotune_grid(n_rows):