3 changes: 3 additions & 0 deletions mypy.ini
@@ -15,6 +15,9 @@ ignore_missing_imports = True
[mypy-lm_eval.*]
ignore_missing_imports = True

[mypy-triton.*]
ignore_missing_imports = True

[mypy-msprobe.*]
ignore_missing_imports = True
allow_untyped_imports = True
331 changes: 289 additions & 42 deletions vllm_ascend/sample/rejection_sampler.py
@@ -4,12 +4,149 @@
import torch
import torch.nn as nn
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
apply_sampling_constraints,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

if HAS_TRITON:

@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
):
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return

start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True

if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens,
bonus_token_id,
)
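
    # Editor's illustration (not part of the diff): the greedy kernel above
    # writes the target argmax at each position up to and including the first
    # mismatch with the draft token, and appends the bonus token only when
    # every draft token matched. A pure-Python sketch with made-up token ids:
    def greedy_verify_sketch(draft_ids, target_argmax_ids, bonus_id):
        out = [-1] * (len(draft_ids) + 1)  # -1 == PLACEHOLDER_TOKEN_ID
        for pos, (d, t) in enumerate(zip(draft_ids, target_argmax_ids)):
            out[pos] = t
            if d != t:  # first mismatch: stop accepting here
                return out
        out[len(draft_ids)] = bonus_id  # all accepted: keep the bonus token
        return out

    # greedy_verify_sketch([5, 7, 2], [5, 7, 9], bonus_id=4) -> [5, 7, 9, -1]
    # greedy_verify_sketch([5, 7, 2], [5, 7, 2], bonus_id=4) -> [5, 7, 2, 4]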

@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
SUB_BLOCK: tl.constexpr = 1500,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return

start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1

for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
uniform_prob = uniform_prob * tmp_uniform_prob

if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)

pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
last_accepted_token_pos = pos
rejected = False
else:
rejected = True

if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)

if rejected:
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
global_recovered_id = -1
global_max_p = -1.0
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
tmp_target_prob = tl.load(
target_probs_ptr +
(start_idx + last_accepted_token_pos + 1) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
recovered_id = tl.argmax(tmp_target_prob, axis=-1)
max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
last_accepted_token_pos + 1, global_recovered_id)
else:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
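
    # Editor's illustration (not part of the diff): a pure-Python sketch of the
    # block-verify acceptance rule used by the kernel above. The prefix up to
    # the *last* position where the running, clamped product of target/draft
    # probability ratios still meets the running product of uniform samples is
    # accepted, so a position that fails the test can still be kept if a later
    # position passes. All probabilities below are made up.
    def block_verify_sketch(target_p, draft_p, uniform):
        pi, u, last_accepted = 1.0, 1.0, -1
        for pos, (t, d, r) in enumerate(zip(target_p, draft_p, uniform)):
            u *= r
            pi = min(pi * t / d, 1.0)
            if d > 0 and pi >= u:
                last_accepted = pos
        return last_accepted  # index of the last accepted draft token, or -1

    # Per-token rejection sampling would reject at position 1 here, but block
    # verify keeps all three draft tokens because position 2 passes the
    # cumulative test:
    # block_verify_sketch([0.9, 0.2, 0.6], [0.8, 0.5, 0.7], [0.5, 0.9, 0.4]) -> 2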


PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +271,9 @@ def rejection_sample(
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)

    # When num_speculative_tokens >= 3, use block verify.
using_block_verify = max_spec_len >= 3

# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
@@ -149,25 +289,36 @@
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
if HAS_TRITON:
rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
max_spec_len,
)
else:
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
if sampling_metadata.all_greedy:
return output_token_ids

@@ -178,37 +329,68 @@ def rejection_sample(
num_draft_tokens,
sampling_metadata.generators,
device,
)

# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
).to(torch.float32)

if not using_block_verify:
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)

# Rejection sampling for random sampling requests.
rejection_random_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
)
else:
# MagicMTP: Improving acceptance rate with Block Verify.
if HAS_TRITON:
rejection_random_sample_block_verify_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
multibuffer=True,
)
else:
            rejection_random_sample_block_verify_pytorch(
                output_token_ids,
                cu_num_draft_tokens,
                draft_token_ids,
                draft_probs,
                target_probs,
                bonus_token_ids,
                uniform_probs,
                is_greedy,
                max_spec_len,
                vocab_size,
                IS_NGRAM=draft_probs is None,
            )
return output_token_ids


@@ -504,4 +686,69 @@ def sample_recovered_tokens_pytorch(
target_probs[token_idx, draft_token_id] = orig_prob


def rejection_random_sample_block_verify_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
draft_probs, # [num_tokens, vocab_size] or None
target_probs, # [num_tokens, vocab_size]
bonus_token_ids, # [batch_size]
uniform_probs, # [num_tokens]
is_greedy, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM=False,
):
batch_size = output_token_ids.shape[0]

for req_idx in range(batch_size):
if is_greedy[req_idx]:
continue

if req_idx == 0:
start_idx = 0
else:
start_idx = cu_num_draft_tokens[req_idx - 1].item()
end_idx = cu_num_draft_tokens[req_idx].item()
num_draft_tokens = end_idx - start_idx

rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
for pos in range(num_draft_tokens):
draft_token_id = draft_token_ids[start_idx + pos].item()

target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()

if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()

pi = min(pi * target_prob / draft_prob, 1.0)

if draft_prob > 0 and pi >= uniform_prob:
last_accepted_token_pos = pos
rejected = False
else:
rejected = True

if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
draft_token_id = draft_token_ids[start_idx + pos].item()
output_token_ids[req_idx, pos] = draft_token_id

if rejected:
recovered_token_id = torch.argmax(
target_probs[start_idx + last_accepted_token_pos + 1]).item()
output_token_ids[req_idx,
last_accepted_token_pos + 1] = recovered_token_id
else:
bonus_token_id = bonus_token_ids[req_idx].item()
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id


rs.expand_batch_to_tokens = expand_batch_to_tokens