[ttx/npu] Optimize top_k_sampling by replacing argsort with iterative extraction #318

Open
lyujheng wants to merge 1 commit into XPU-Forces:master from lyujheng:lvzheng/ttx-optimize-top-k-sampling

@lyujheng

Description

Optimize top_k_sampling_impl by replacing the full argsort in _topk_stage1_kernel and
_topk_merge_kernel with iterative max/min extraction. Each chunk now outputs only k elements,
which eliminates _compact_sorted_blocks and reduces the output buffer size.
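
As a rough illustration of the idea (not the kernel code from this PR), the two-stage scheme can be sketched in plain Python. The helper names `topk_by_extraction` and `chunked_topk` are hypothetical, and the sketch assumes finite logits, since it masks extracted slots with negative infinity:

```python
NEG_INF = float("-inf")


def topk_by_extraction(values, indices, k):
    """Return the k largest (value, index) pairs by repeated max extraction.

    No sort is performed: each pass finds the current maximum, records it,
    and masks that slot so the next pass finds the runner-up.
    """
    work = list(values)  # scratch copy; extracted slots become -inf
    out = []
    for _ in range(min(k, len(work))):
        best = max(range(len(work)), key=work.__getitem__)  # first argmax
        out.append((work[best], indices[best]))
        work[best] = NEG_INF
    return out


def chunked_topk(logits, k, chunk_size):
    """Two-stage top-k: per-chunk extraction, then a merge over the candidates."""
    cand_vals, cand_idx = [], []
    # Stage 1: every chunk emits at most k candidates instead of a fully
    # sorted chunk, so no separate compaction step is needed afterwards.
    for start in range(0, len(logits), chunk_size):
        chunk = logits[start:start + chunk_size]
        global_ids = range(start, start + len(chunk))
        for value, idx in topk_by_extraction(chunk, global_ids, k):
            cand_vals.append(value)
            cand_idx.append(idx)
    # Merge stage: the same extraction loop over the surviving candidates.
    return topk_by_extraction(cand_vals, cand_idx, k)


logits = [0.5, -1.0, 3.0, 2.0, 7.0, 0.0, 6.5, 1.5, -2.0, 4.0]
print(chunked_topk(logits, k=3, chunk_size=4))  # [(7.0, 4), (6.5, 6), (4.0, 9)]
```

Extraction costs on the order of k passes over each chunk, so it is most attractive when k is much smaller than the chunk size, which matches the shapes benchmarked here.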

Changes

  • Replace argsort with iterative max/min extraction in _topk_stage1_kernel and _topk_merge_kernel
  • Add grid-stride loop in _topk_merge_kernel for better core utilization
  • Remove _compact_sorted_blocks (no longer needed)
  • Shrink the output buffer: allocate only k elements per chunk with torch.empty instead of a full chunk_size buffer with torch.full
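
The grid-stride loop mentioned in the list can be pictured with a small pure-Python simulation. This is only a sketch of the scheduling pattern, not the merge kernel itself; `rows_for_program` and `schedule` are hypothetical names. In a Triton kernel the same pattern is usually written as a loop starting at `tl.program_id(0)` and stepping by `tl.num_programs(0)`.

```python
def rows_for_program(pid, num_programs, num_rows):
    """Rows handled by one program instance in a grid-stride loop.

    Program `pid` starts at row `pid` and advances by the grid size, so a
    fixed-size grid covers any number of rows.
    """
    return list(range(pid, num_rows, num_programs))


def schedule(num_programs, num_rows):
    """Row assignment for every program in the grid."""
    return [rows_for_program(pid, num_programs, num_rows)
            for pid in range(num_programs)]


# 10 rows spread over a grid of 4 programs
print(schedule(4, 10))  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Launching a grid sized to the number of available cores, rather than one program per row, keeps every core busy without depending on the row count.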

Performance

Benchmark on NPU (device latency) using triton-ascend 3.2.0:

| Shape (batch, vocab) | top_k | min_tokens_to_keep | Before (µs) | After (µs) | Speedup |
|---|---|---|---|---|---|
| (120, 151936) | 20 | 1 | 2,564,932.88 | 56,548.74 | 45.4x |
| (15, 155136) | 50 | 1 | 661,352.18 | 26,072.81 | 25.4x |
| (15, 155136) | 64 | 1 | 661,615.28 | 31,454.01 | 21.0x |
| (18, 155136) | 100 | 1 | 1,194,172.62 | 59,976.63 | 19.9x |

Accuracy

All accuracy tests passed using triton-ascend 3.2.0. Tested across various top_k values covering
small k, medium k, large k, chunk_size boundary, and extreme cases:

| Shape (batch, vocab) | top_k | min_tokens_to_keep | Result |
|---|---|---|---|
| (15, 155136) | 10 | 1 | PASSED |
| (15, 155136) | 20 | 1 | PASSED |
| (120, 151936) | 20 | 1 | PASSED |
| (15, 155136) | 50 | 1 | PASSED |
| (120, 151936) | 50 | 1 | PASSED |
| (15, 155136) | 64 | 1 | PASSED |
| (15, 155136) | 100 | 1 | PASSED |
| (18, 155136) | 100 | 1 | PASSED |
| (15, 155136) | 128 | 1 | PASSED |
| (15, 155136) | 256 | 1 | PASSED |
| (15, 155136) | 500 | 1 | PASSED |
| (15, 155136) | 1000 | 1 | PASSED |

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the Top-K sampling implementation by replacing the sorting-based logic in the stage 1 and merge kernels with an iterative max/min extraction loop. It also introduces grid-stride loops in the merge kernel and optimizes memory management by removing the separate compaction step. Review feedback identifies a critical issue where using 0 as a padding value for indices leads to collisions with valid data at index 0, potentially causing incorrect masking. It is recommended to use a large constant like 0x7FFFFFFF for padding and to use scalar broadcasting in Triton operations for better efficiency.
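
The padding concern can be illustrated with a small pure-Python sketch. It assumes, hypothetically, that a downstream step treats every in-range index in the candidate buffer as a selected token; the actual masking logic in sample.py is not shown on this page, and `selected_ids` and `PAD_INDEX` are illustrative names.

```python
PAD_INDEX = 0x7FFFFFFF  # largest signed 32-bit value, never a valid vocab id here


def selected_ids(slots, vocab_size):
    """Token ids that a downstream mask would treat as selected.

    Any slot holding an in-range index counts as a real candidate, so the
    padding value must fall outside [0, vocab_size) to be ignored.
    """
    return sorted({i for i in slots if 0 <= i < vocab_size})


real = [7, 3]                            # genuine top-k indices; token 0 is not among them
padded_with_zero = real + [0, 0]         # unused slots padded with 0
padded_with_sentinel = real + [PAD_INDEX, PAD_INDEX]

print(selected_ids(padded_with_zero, vocab_size=10))      # [0, 3, 7]
print(selected_ids(padded_with_sentinel, vocab_size=10))  # [3, 7]
```

With 0 as padding, token 0 is indistinguishable from an empty slot and ends up selected; an out-of-range sentinel is dropped by the bounds check.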

Comment thread mojo_opset/backends/ttx/kernels/npu/sample.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/sample.py Outdated
… extraction

Replace full argsort in _topk_stage1_kernel and _topk_merge_kernel with iterative max/min extraction,
only outputting k elements per chunk. Remove _compact_sorted_blocks as it is no longer needed.
@lyujheng lyujheng force-pushed the lvzheng/ttx-optimize-top-k-sampling branch from bba7445 to 8bf503f Compare May 21, 2026 04:55
@lyujheng

lyujheng commented Jun 1, 2026

Hi @wwens7, could you help review this PR when you have a chance?

This PR optimizes the top_k_sampling kernels by replacing argsort with iterative max/min extraction, achieving a 19.9x–45.4x speedup on the Ascend 910C NPU. All accuracy tests passed.

Thanks!
