From 8bf503fc5bf4e0d2aa2be7d09bd0d5ac46231782 Mon Sep 17 00:00:00 2001 From: lvzheng Date: Thu, 21 May 2026 11:36:13 +0800 Subject: [PATCH] [ttx/npu] Optimize top_k_sampling by replacing argsort with iterative extraction Replace full argsort in _topk_stage1_kernel and _topk_merge_kernel with iterative max/min extraction, only outputting k elements per chunk. Remove _compact_sorted_blocks as it is no longer needed. --- mojo_opset/backends/ttx/kernels/npu/sample.py | 148 +++++++++--------- 1 file changed, 76 insertions(+), 72 deletions(-) diff --git a/mojo_opset/backends/ttx/kernels/npu/sample.py b/mojo_opset/backends/ttx/kernels/npu/sample.py index 698e78304..aadde8159 100644 --- a/mojo_opset/backends/ttx/kernels/npu/sample.py +++ b/mojo_opset/backends/ttx/kernels/npu/sample.py @@ -421,22 +421,31 @@ def _topk_stage1_kernel( off_col = chunk_offset + tl.arange(0, CHUNK_SIZE) row_start = cur_batch * VOCAB_SIZE safe_off_x = tl.where(off_col < VOCAB_SIZE, row_start + off_col, row_start) - batch_y_ptr = y_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * CHUNK_SIZE - batch_y_index_ptr = y_index_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * CHUNK_SIZE + batch_y_ptr = y_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * k + batch_y_index_ptr = y_index_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * k mask_x = off_col < VOCAB_SIZE pad_value = filter_value if DESCENDING else -filter_value x = tl.load(x_ptr + safe_off_x, mask=mask_x, other=pad_value) x = tl.where(mask_x, x, pad_value) - x_index = tl.where(mask_x, off_col, 0).to(tl.int32) - sort_keys = _pack_sort_keys(x, x_index) - sorted_keys, _ = argsort(sort_keys, tl.zeros_like(sort_keys), 0, descending=DESCENDING) - sorted_x = _unpack_sort_values(sorted_keys) - sorted_index = _unpack_sort_indices(sorted_keys) + x_index = tl.where(mask_x, off_col, 0x7FFFFFFF).to(tl.int32) + + # Iterative max extraction: loop k times to extract top-k elements + for k_idx in range(k): + if DESCENDING: + select_val = tl.max(x, axis=0) + is_max = (x == 
select_val) + select_orig_idx = tl.min(tl.where(is_max, x_index, 0x7FFFFFFF), axis=0) + else: + select_val = tl.min(x, axis=0) + is_min = (x == select_val) + select_orig_idx = tl.min(tl.where(is_min, x_index, 0x7FFFFFFF), axis=0) + + tl.store(batch_y_ptr + k_idx, select_val) + tl.store(batch_y_index_ptr + k_idx, select_orig_idx) - cols = tl.arange(0, CHUNK_SIZE) - tl.store(batch_y_ptr + cols, sorted_x) - tl.store(batch_y_index_ptr + cols, sorted_index.to(tl.int32)) + # Mask out selected element for next iteration + x = tl.where(x_index == select_orig_idx, pad_value, x) @libentry() @@ -451,37 +460,51 @@ def _topk_merge_kernel( INPUT_ELEMS: tl.constexpr, INPUT_ROW_STRIDE: tl.constexpr, OUTPUT_ROW_STRIDE: tl.constexpr, + TOTAL_TASKS, NEXT_GROUPS: tl.constexpr, GROUP_CHUNKS: tl.constexpr, BLOCK_SIZE: tl.constexpr, DESCENDING: tl.constexpr, ): pid = tl.program_id(0) - cur_batch = pid // NEXT_GROUPS - cur_group = pid % NEXT_GROUPS - - group_start = cur_group * GROUP_CHUNKS * k - off_col = tl.arange(0, BLOCK_SIZE) - valid = (off_col < GROUP_CHUNKS * k) & ((group_start + off_col) < INPUT_ELEMS) - - in_row_start = cur_batch * INPUT_ROW_STRIDE - safe_off_x = tl.where(valid, in_row_start + group_start + off_col, in_row_start) - - pad_value = filter_value if DESCENDING else -filter_value - chunk_x = tl.load(x_ptr + safe_off_x, mask=valid, other=pad_value) - chunk_x = tl.where(valid, chunk_x, pad_value) - chunk_index = tl.load(x_index_ptr + safe_off_x, mask=valid, other=0).to(tl.int32) - chunk_index = tl.where(valid, chunk_index, 0) - - sort_keys = _pack_sort_keys(chunk_x, chunk_index) - sorted_keys, _ = argsort(sort_keys, tl.zeros_like(sort_keys), 0, descending=DESCENDING) - sorted_logits = _unpack_sort_values(sorted_keys) - sorted_index = _unpack_sort_indices(sorted_keys) - - out_row_start = cur_batch * OUTPUT_ROW_STRIDE + cur_group * BLOCK_SIZE - tl.store(y_ptr + out_row_start + off_col, sorted_logits) - tl.store(y_index_ptr + out_row_start + off_col, 
sorted_index.to(tl.int32)) - + grid_size = tl.num_programs(0) + + for task_id in range(pid, TOTAL_TASKS, grid_size): + cur_batch = task_id // NEXT_GROUPS + cur_group = task_id % NEXT_GROUPS + + group_start = cur_group * GROUP_CHUNKS * k + off_col = tl.arange(0, BLOCK_SIZE) + valid = (off_col < GROUP_CHUNKS * k) & ((group_start + off_col) < INPUT_ELEMS) + + in_row_start = cur_batch * INPUT_ROW_STRIDE + safe_off_x = tl.where(valid, in_row_start + group_start + off_col, in_row_start) + + pad_value = filter_value if DESCENDING else -filter_value + chunk_x = tl.load(x_ptr + safe_off_x, mask=valid, other=pad_value) + chunk_x = tl.where(valid, chunk_x, pad_value) + chunk_index = tl.load(x_index_ptr + safe_off_x, mask=valid, other=0x7FFFFFFF).to(tl.int32) + chunk_index = tl.where(valid, chunk_index, 0x7FFFFFFF) + + out_row_start = cur_batch * OUTPUT_ROW_STRIDE + cur_group * k + + # Iterative max extraction: loop k times to extract top-k elements + for k_idx in range(k): + if DESCENDING: + select_val = tl.max(chunk_x, axis=0) + is_max = (chunk_x == select_val) + select_orig_idx = tl.min(tl.where(is_max, chunk_index, 0x7FFFFFFF), axis=0) + else: + select_val = tl.min(chunk_x, axis=0) + is_min = (chunk_x == select_val) + select_orig_idx = tl.min(tl.where(is_min, chunk_index, 0x7FFFFFFF), axis=0) + + tl.store(y_ptr + out_row_start + k_idx, select_val) + tl.store(y_index_ptr + out_row_start + k_idx, select_orig_idx) + + # Mask out selected element for next iteration + chunk_x = tl.where(chunk_index == select_orig_idx, pad_value, chunk_x) + def top_k_sampling_impl( logits: torch.FloatTensor, @@ -508,19 +531,17 @@ def top_k_sampling_impl( chunk_size = triton.next_power_of_2(top_k) chunk_num = triton.cdiv(vocab_size, chunk_size) - stage1_elem_cnt = chunk_num * top_k - row_stride = stage1_elem_cnt - stage1_sorted_row_stride = chunk_num * chunk_size - + stage1_row_stride = chunk_num * top_k + pad_val = filter_value if descending else -filter_value - stage1_sorted = 
torch.full((batch_size, stage1_sorted_row_stride), pad_val, device=device, dtype=logits.dtype).contiguous() - stage1_sorted_index = torch.zeros((batch_size, stage1_sorted_row_stride), device=device, dtype=torch.int32).contiguous() - + stage1_out = torch.empty((batch_size, stage1_row_stride), device=device, dtype=logits.dtype) + stage1_out_index = torch.empty((batch_size, stage1_row_stride), device=device, dtype=torch.int32) + stage1_total_tasks = batch_size * chunk_num _topk_stage1_kernel[(min(stage1_total_tasks, 65535),)]( - stage1_sorted, - stage1_sorted_index, + stage1_out, + stage1_out_index, logits_2d, filter_value, top_k, @@ -528,23 +549,14 @@ def top_k_sampling_impl( vocab_size, chunk_num, chunk_size, - stage1_sorted_row_stride, + stage1_row_stride, descending, ) - stage1_out, stage1_out_index = _compact_sorted_blocks( - stage1_sorted, - stage1_sorted_index, - batch_size, - chunk_num, - chunk_size, - top_k, - ) - candidate_vals = stage1_out candidate_idx = stage1_out_index current_groups = chunk_num - current_row_stride = row_stride + current_row_stride = stage1_row_stride max_merge_candidates = 128 merge_group_chunks = max(1, max_merge_candidates // top_k) @@ -558,36 +570,27 @@ def top_k_sampling_impl( next_groups = triton.cdiv(current_groups, group_chunks) next_row_stride = next_groups * top_k block_size = triton.next_power_of_2(group_chunks * top_k) - next_sorted_row_stride = next_groups * block_size - next_sorted_vals = torch.full((batch_size, next_sorted_row_stride), pad_val, device=device, dtype=logits.dtype).contiguous() - next_sorted_idx = torch.zeros((batch_size, next_sorted_row_stride), device=device, dtype=torch.int32).contiguous() + next_vals = torch.empty((batch_size, next_row_stride), device=device, dtype=logits.dtype) + next_idx = torch.empty((batch_size, next_row_stride), device=device, dtype=torch.int32) merge_total_tasks = batch_size * next_groups - _topk_merge_kernel[(merge_total_tasks,)]( - next_sorted_vals, - next_sorted_idx, + 
_topk_merge_kernel[(min(merge_total_tasks, 65535),)]( + next_vals, + next_idx, candidate_vals, candidate_idx, filter_value, top_k, current_groups * top_k, current_row_stride, - next_sorted_row_stride, + next_row_stride, + merge_total_tasks, next_groups, group_chunks, block_size, descending, ) - next_vals, next_idx = _compact_sorted_blocks( - next_sorted_vals, - next_sorted_idx, - batch_size, - next_groups, - block_size, - top_k, - ) - candidate_vals = next_vals candidate_idx = next_idx current_groups = next_groups @@ -595,11 +598,12 @@ def top_k_sampling_impl( final_candidate_vals = candidate_vals[:, :top_k].contiguous() final_candidate_idx = candidate_idx[:, :top_k].contiguous() + final_probs_dist = torch.nn.functional.softmax(final_candidate_vals, dim=-1) select_index = torch.multinomial(final_probs_dist, num_samples=1) stage2_out_prob = torch.gather(final_probs_dist, dim=-1, index=select_index) stage2_out_token = torch.gather(final_candidate_idx, dim=-1, index=select_index) - + return stage2_out_prob, stage2_out_token