[KMCompiler][ttx] Optimize Layernorm performance #366
Refactor kernel into two stages: Stage 1 computes mean and variance in one pass; Stage 2 performs normalization and writes output.
…and skip infer stats stores
Code Review
This pull request refactors the LayerNorm forward kernel in Triton to support optimized masking paths via compile-time constants, introduces a single-pass kernel for smaller column sizes, and adds dynamic heuristics for block size selection. The feedback suggests optimizing performance by using a cached get_num_cores helper to avoid CPU launch overhead from direct driver queries, and decoupling drop_rows_mask from drop_cols_mask to enable optimized paths for partially aligned shapes. Additionally, it is recommended to simplify the code by deduplicating weight and bias loading, and to monitor potential numerical instability arising from the variance formula used.
mean = mean_acc / n_cols
var = var_acc / n_cols - mean * mean
var = tl.maximum(var, 0.0)
Using the formula Var(x) = E[x^2] - E[x]^2 can lead to catastrophic cancellation and numerical instability when the input values are large and the variance is small (especially in FP16/BF16, even with FP32 accumulation). While this enables a 2-pass implementation instead of 3-pass, please be aware of this limitation. If accuracy issues or NaNs are observed in production models, consider using a more stable approach like Welford's algorithm or reverting to the stable two-pass variance calculation.
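To make the concern concrete, here is a small pure-Python sketch comparing the E[x^2] - E[x]^2 formula with Welford's algorithm. It runs in Python's float64 rather than the kernel's FP32 accumulation, and the function names and sample data are chosen for illustration only. The effect in the kernel would be of the same kind, but it appears at much smaller magnitudes in FP32.

```python
def naive_variance(xs):
    """Population variance via E[x^2] - E[x]^2, clamped at zero."""
    n = len(xs)
    total = 0.0
    total_sq = 0.0
    for x in xs:
        total += x
        total_sq += x * x
    mean = total / n
    return max(total_sq / n - mean * mean, 0.0)


def welford_variance(xs):
    """Population variance via Welford's single-pass update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    return m2 / count


# Large mean, small spread: the true population variance is 22.5.
data = [1e9 + 4.0, 1e9 + 7.0, 1e9 + 13.0, 1e9 + 16.0]
print("naive  :", naive_variance(data))
print("welford:", welford_variance(data))
```

On this input the squares are near 1e18, where float64 can no longer resolve differences of a few units, so the subtraction loses the variance entirely. Welford's update works with deviations from the running mean and stays accurate.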
Description
Optimize NPU LayerNorm forward performance by reducing input reads, eliminating unnecessary masks and statistic writes, and tuning kernel scheduling for different shapes and dtypes.
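As a rough reference for the structure described above, the following pure-Python sketch mirrors the two stages for a single row. It is not the Triton kernel: `layernorm_row`, the `training` flag, and the `eps` default of 1e-5 are assumptions made for illustration.

```python
import math


def layernorm_row(x, weight, bias, eps=1e-5, training=False):
    """Two-stage LayerNorm over one row (illustrative reference only)."""
    n_cols = len(x)

    # Stage 1: a single pass accumulates the sum and the sum of squares.
    mean_acc = 0.0
    var_acc = 0.0
    for v in x:
        mean_acc += v
        var_acc += v * v
    mean = mean_acc / n_cols
    var = max(var_acc / n_cols - mean * mean, 0.0)
    rstd = 1.0 / math.sqrt(var + eps)

    # Stage 2: normalize, apply the affine transform, and produce the output.
    y = [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]

    # The statistics are only needed for the backward pass, so inference
    # returns no mean/rstd, analogous to skipping those stores.
    if training:
        return y, mean, rstd
    return y, None, None
```

Calling `layernorm_row([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)` returns the normalized row together with `None` for both statistics.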
Changes
- Compute variance as Var(x) = E[x^2] - E[x]^2
- x_centered in the single-pass path
- Skip Mean/RSTD allocation and writes during inference
- Select BLOCK_SIZE_M based on n_rows, n_cols, and dtype
Performance
Measured with torch.float32 inputs.
Overall speedup: 1.14x–2.51x.
Accuracy
Accuracy tests in mojo_opset/tests/accuracy/operators/test_normalization.py passed.