Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 76 additions & 72 deletions mojo_opset/backends/ttx/kernels/npu/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,22 +421,31 @@ def _topk_stage1_kernel(
off_col = chunk_offset + tl.arange(0, CHUNK_SIZE)
row_start = cur_batch * VOCAB_SIZE
safe_off_x = tl.where(off_col < VOCAB_SIZE, row_start + off_col, row_start)
batch_y_ptr = y_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * CHUNK_SIZE
batch_y_index_ptr = y_index_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * CHUNK_SIZE
batch_y_ptr = y_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * k
batch_y_index_ptr = y_index_ptr + cur_batch * ROW_STRIDE + cur_chunk_idx * k

mask_x = off_col < VOCAB_SIZE
pad_value = filter_value if DESCENDING else -filter_value
x = tl.load(x_ptr + safe_off_x, mask=mask_x, other=pad_value)
x = tl.where(mask_x, x, pad_value)
x_index = tl.where(mask_x, off_col, 0).to(tl.int32)
sort_keys = _pack_sort_keys(x, x_index)
sorted_keys, _ = argsort(sort_keys, tl.zeros_like(sort_keys), 0, descending=DESCENDING)
sorted_x = _unpack_sort_values(sorted_keys)
sorted_index = _unpack_sort_indices(sorted_keys)
x_index = tl.where(mask_x, off_col, 0x7FFFFFFF).to(tl.int32)

# Iterative max extraction: loop k times to extract top-k elements
for k_idx in range(k):
if DESCENDING:
select_val = tl.max(x, axis=0)
is_max = (x == select_val)
select_orig_idx = tl.min(tl.where(is_max, x_index, 0x7FFFFFFF), axis=0)
else:
select_val = tl.min(x, axis=0)
is_min = (x == select_val)
select_orig_idx = tl.min(tl.where(is_min, x_index, 0x7FFFFFFF), axis=0)

tl.store(batch_y_ptr + k_idx, select_val)
tl.store(batch_y_index_ptr + k_idx, select_orig_idx)

cols = tl.arange(0, CHUNK_SIZE)
tl.store(batch_y_ptr + cols, sorted_x)
tl.store(batch_y_index_ptr + cols, sorted_index.to(tl.int32))
# Mask out selected element for next iteration
x = tl.where(x_index == select_orig_idx, pad_value, x)


@libentry()
Expand All @@ -451,37 +460,51 @@ def _topk_merge_kernel(
INPUT_ELEMS: tl.constexpr,
INPUT_ROW_STRIDE: tl.constexpr,
OUTPUT_ROW_STRIDE: tl.constexpr,
TOTAL_TASKS,
NEXT_GROUPS: tl.constexpr,
GROUP_CHUNKS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
DESCENDING: tl.constexpr,
):
pid = tl.program_id(0)
cur_batch = pid // NEXT_GROUPS
cur_group = pid % NEXT_GROUPS

group_start = cur_group * GROUP_CHUNKS * k
off_col = tl.arange(0, BLOCK_SIZE)
valid = (off_col < GROUP_CHUNKS * k) & ((group_start + off_col) < INPUT_ELEMS)

in_row_start = cur_batch * INPUT_ROW_STRIDE
safe_off_x = tl.where(valid, in_row_start + group_start + off_col, in_row_start)

pad_value = filter_value if DESCENDING else -filter_value
chunk_x = tl.load(x_ptr + safe_off_x, mask=valid, other=pad_value)
chunk_x = tl.where(valid, chunk_x, pad_value)
chunk_index = tl.load(x_index_ptr + safe_off_x, mask=valid, other=0).to(tl.int32)
chunk_index = tl.where(valid, chunk_index, 0)

sort_keys = _pack_sort_keys(chunk_x, chunk_index)
sorted_keys, _ = argsort(sort_keys, tl.zeros_like(sort_keys), 0, descending=DESCENDING)
sorted_logits = _unpack_sort_values(sorted_keys)
sorted_index = _unpack_sort_indices(sorted_keys)

out_row_start = cur_batch * OUTPUT_ROW_STRIDE + cur_group * BLOCK_SIZE
tl.store(y_ptr + out_row_start + off_col, sorted_logits)
tl.store(y_index_ptr + out_row_start + off_col, sorted_index.to(tl.int32))

grid_size = tl.num_programs(0)

for task_id in range(pid, TOTAL_TASKS, grid_size):
cur_batch = task_id // NEXT_GROUPS
cur_group = task_id % NEXT_GROUPS

group_start = cur_group * GROUP_CHUNKS * k
off_col = tl.arange(0, BLOCK_SIZE)
valid = (off_col < GROUP_CHUNKS * k) & ((group_start + off_col) < INPUT_ELEMS)

in_row_start = cur_batch * INPUT_ROW_STRIDE
safe_off_x = tl.where(valid, in_row_start + group_start + off_col, in_row_start)

pad_value = filter_value if DESCENDING else -filter_value
chunk_x = tl.load(x_ptr + safe_off_x, mask=valid, other=pad_value)
chunk_x = tl.where(valid, chunk_x, pad_value)
chunk_index = tl.load(x_index_ptr + safe_off_x, mask=valid, other=0x7FFFFFFF).to(tl.int32)
chunk_index = tl.where(valid, chunk_index, 0x7FFFFFFF)

out_row_start = cur_batch * OUTPUT_ROW_STRIDE + cur_group * k

# Iterative max extraction: loop k times to extract top-k elements
for k_idx in range(k):
if DESCENDING:
select_val = tl.max(chunk_x, axis=0)
is_max = (chunk_x == select_val)
select_orig_idx = tl.min(tl.where(is_max, chunk_index, 0x7FFFFFFF), axis=0)
else:
select_val = tl.min(chunk_x, axis=0)
is_min = (chunk_x == select_val)
select_orig_idx = tl.min(tl.where(is_min, chunk_index, 0x7FFFFFFF), axis=0)

tl.store(y_ptr + out_row_start + k_idx, select_val)
tl.store(y_index_ptr + out_row_start + k_idx, select_orig_idx)

# Mask out selected element for next iteration
chunk_x = tl.where(chunk_index == select_orig_idx, pad_value, chunk_x)


def top_k_sampling_impl(
logits: torch.FloatTensor,
Expand All @@ -508,43 +531,32 @@ def top_k_sampling_impl(
chunk_size = triton.next_power_of_2(top_k)
chunk_num = triton.cdiv(vocab_size, chunk_size)

stage1_elem_cnt = chunk_num * top_k
row_stride = stage1_elem_cnt
stage1_sorted_row_stride = chunk_num * chunk_size

stage1_row_stride = chunk_num * top_k

pad_val = filter_value if descending else -filter_value

stage1_sorted = torch.full((batch_size, stage1_sorted_row_stride), pad_val, device=device, dtype=logits.dtype).contiguous()
stage1_sorted_index = torch.zeros((batch_size, stage1_sorted_row_stride), device=device, dtype=torch.int32).contiguous()
stage1_out = torch.empty((batch_size, stage1_row_stride), device=device, dtype=logits.dtype)
stage1_out_index = torch.empty((batch_size, stage1_row_stride), device=device, dtype=torch.int32)

stage1_total_tasks = batch_size * chunk_num
_topk_stage1_kernel[(min(stage1_total_tasks, 65535),)](
stage1_sorted,
stage1_sorted_index,
stage1_out,
stage1_out_index,
logits_2d,
filter_value,
top_k,
stage1_total_tasks,
vocab_size,
chunk_num,
chunk_size,
stage1_sorted_row_stride,
stage1_row_stride,
descending,
)

stage1_out, stage1_out_index = _compact_sorted_blocks(
stage1_sorted,
stage1_sorted_index,
batch_size,
chunk_num,
chunk_size,
top_k,
)

candidate_vals = stage1_out
candidate_idx = stage1_out_index
current_groups = chunk_num
current_row_stride = row_stride
current_row_stride = stage1_row_stride

max_merge_candidates = 128
merge_group_chunks = max(1, max_merge_candidates // top_k)
Expand All @@ -558,48 +570,40 @@ def top_k_sampling_impl(
next_groups = triton.cdiv(current_groups, group_chunks)
next_row_stride = next_groups * top_k
block_size = triton.next_power_of_2(group_chunks * top_k)
next_sorted_row_stride = next_groups * block_size

next_sorted_vals = torch.full((batch_size, next_sorted_row_stride), pad_val, device=device, dtype=logits.dtype).contiguous()
next_sorted_idx = torch.zeros((batch_size, next_sorted_row_stride), device=device, dtype=torch.int32).contiguous()
next_vals = torch.empty((batch_size, next_row_stride), device=device, dtype=logits.dtype)
next_idx = torch.empty((batch_size, next_row_stride), device=device, dtype=torch.int32)
merge_total_tasks = batch_size * next_groups
_topk_merge_kernel[(merge_total_tasks,)](
next_sorted_vals,
next_sorted_idx,
_topk_merge_kernel[(min(merge_total_tasks, 65535),)](
next_vals,
next_idx,
candidate_vals,
candidate_idx,
filter_value,
top_k,
current_groups * top_k,
current_row_stride,
next_sorted_row_stride,
next_row_stride,
merge_total_tasks,
next_groups,
group_chunks,
block_size,
descending,
)

next_vals, next_idx = _compact_sorted_blocks(
next_sorted_vals,
next_sorted_idx,
batch_size,
next_groups,
block_size,
top_k,
)

candidate_vals = next_vals
candidate_idx = next_idx
current_groups = next_groups
current_row_stride = next_row_stride

final_candidate_vals = candidate_vals[:, :top_k].contiguous()
final_candidate_idx = candidate_idx[:, :top_k].contiguous()

final_probs_dist = torch.nn.functional.softmax(final_candidate_vals, dim=-1)
select_index = torch.multinomial(final_probs_dist, num_samples=1)
stage2_out_prob = torch.gather(final_probs_dist, dim=-1, index=select_index)
stage2_out_token = torch.gather(final_candidate_idx, dim=-1, index=select_index)

return stage2_out_prob, stage2_out_token


Expand Down
Loading