Skip to content

[KMCompiler][ttx] Optimize Layernorm performance#366

Open
YangLong114514 wants to merge 4 commits into
XPU-Forces:masterfrom
YangLong114514:KMCompiler-LayerNorm
Open

[KMCompiler][ttx] Optimize Layernorm performance#366
YangLong114514 wants to merge 4 commits into
XPU-Forces:masterfrom
YangLong114514:KMCompiler-LayerNorm

Conversation

@YangLong114514

Copy link
Copy Markdown

Description

Optimize NPU LayerNorm forward performance by reducing input reads, eliminating unnecessary masks and statistic writes, and tuning kernel scheduling for different shapes and dtypes.

Changes

  • Merge mean and variance statistics into one input pass using Var(x) = E[x^2] - E[x]^2
  • Dynamically trim the grid based on the actual number of row tasks
  • Add no-mask paths for tile-aligned shapes
  • Add a single-pass kernel when the hidden dimension fits in one tile
  • Use centered variance and reuse x_centered in the single-pass path
  • Skip Mean/RSTD allocation and writes during inference
  • Tune BLOCK_SIZE_M based on n_rows, n_cols, and dtype
  • Use conservative fp16/bf16 tiles to avoid UB overflow

Performance

Measured with torch.float32 inputs.

Shape Before [us] After [us] Speedup
(8, 8192) 11.9520 4.7712 2.51x
(16, 4096) 8.2624 3.9824 2.07x
(32, 64) 4.0880 2.2608 1.81x
(32, 128) 4.1568 2.3184 1.79x
(32, 2048) 6.4656 3.8000 1.70x
(128, 128) 5.1024 2.7152 1.88x
(128, 2048) 6.7168 5.9072 1.14x
(256, 128) 5.4800 3.2048 1.71x
(256, 512) 6.3760 4.7824 1.33x
(256, 12800) 35.9152 24.5744 1.46x
(1024, 1024) 14.3488 10.5536 1.36x

Overall speedup: 1.14x–2.51x.

Accuracy

Accuracy tests in mojo_opset/tests/accuracy/operators/test_normalization.py passed.

lyujheng and others added 3 commits June 15, 2026 15:52
Refactor kernel into two stages: Stage 1 computes mean and variance
in one pass; Stage 2 performs normalization and writes output.
@YangLong114514 YangLong114514 changed the title [KMCompiler] Optimize Layernorm performance [KMCompiler][ttx] Optimize Layernorm performance Jun 15, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the LayerNorm forward kernel in Triton to support optimized masking paths via compile-time constants, introduces a single-pass kernel for smaller column sizes, and adds dynamic heuristics for block size selection. The feedback suggests optimizing performance by using a cached get_num_cores helper to avoid CPU launch overhead from direct driver queries, and decoupling drop_rows_mask from drop_cols_mask to enable optimized paths for partially aligned shapes. Additionally, it is recommended to simplify the code by deduplicating weight and bias loading, and to monitor potential numerical instability arising from the variance formula used.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/npu/layernorm.py Outdated
Comment on lines +134 to +136
mean = mean_acc / n_cols
var = var_acc / n_cols - mean * mean
var = tl.maximum(var, 0.0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using the formula Var(x) = E[x^2] - E[x]^2 can lead to catastrophic cancellation and numerical instability when the input values are large and the variance is small (especially in FP16/BF16, even with FP32 accumulation). While this enables a 2-pass implementation instead of 3-pass, please be aware of this limitation. If accuracy issues or NaNs are observed in production models, consider using a more stable approach like Welford's algorithm or reverting to the stable two-pass variance calculation.

Comment thread mojo_opset/backends/ttx/kernels/npu/layernorm.py
Comment thread mojo_opset/backends/ttx/kernels/npu/layernorm.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/layernorm.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/layernorm.py Outdated
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants