[ttx/npu] Optimize top_k_sampling by replacing argsort with iterative extraction#318
[ttx/npu] Optimize top_k_sampling by replacing argsort with iterative extraction#318lyujheng wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the Top-K sampling implementation by replacing the sorting-based logic in the stage 1 and merge kernels with an iterative max/min extraction loop. It also introduces grid-stride loops in the merge kernel and optimizes memory management by removing the separate compaction step. Review feedback identifies a critical issue where using 0 as a padding value for indices leads to collisions with valid data at index 0, potentially causing incorrect masking. It is recommended to use a large constant like 0x7FFFFFFF for padding and to use scalar broadcasting in Triton operations for better efficiency.
… extraction Replace full argsort in _topk_stage1_kernel and _topk_merge_kernel with iterative max/min extraction, only outputting k elements per chunk.Remove _compact_sorted_blocks as it is no longer needed.
bba7445 to
8bf503f
Compare
|
Hi @wwens7, could you help review this PR when you have a chance? This PR optimizes the top_k_sampling kernels — replacing argsort with iterative max extraction, achieving 19.9x–45.4x speedup on ascend 910C NPU. All accuracy tests passed. Thanks! |
Description
Optimize
top_k_sampling_implby replacing full argsort in_topk_stage1_kerneland_topk_merge_kernelwith iterative max/min extraction. Only output k elements per chunk,eliminating
_compact_sorted_blocksand reducing output buffer size.Changes
_topk_stage1_kerneland_topk_merge_kernel_topk_merge_kernelfor better core utilization_compact_sorted_blocks(no longer needed)torch.fullwith full chunk_size totorch.emptywith k elements onlyPerformance
Benchmark on NPU (device latency) using triton-ascend 3.2.0:
Accuracy
All accuracy tests passed using triton-ascend 3.2.0. Tested across various top_k values covering
small k, medium k, large k, chunk_size boundary, and extreme cases: