
Commit 775b1e0

fluctlux authored and mercykid committed
[Feature] Integrate Suffix Spec Decoding (vllm-project#4045)
### What this PR does / why we need it?

This PR integrates suffix decoding (https://arxiv.org/abs/2411.04975) from vLLM (vllm-project/vllm#25784).

Suffix Decoding is a dynamic n-gram matching method that:
1. Uses suffix trees to generate speculative tokens quickly using branch frequency counts.
2. Can keep a history of prior model responses, which tends to work very well with repetitive agentic use cases.
3. Can be dynamically updated with newly generated tokens, with FIFO eviction of older requests.

### Does this PR introduce _any_ user-facing change?

The feature is opt-in and remains seamless for users who do not need suffix speculative decoding. Users who wish to enable it must first install arctic-inference:

`pip install arctic-inference`

After installation, suffix speculative decoding can be enabled with the following speculative config:

`--speculative_config '{"method": "suffix", "num_speculative_tokens": 5}'`

### How was this patch tested?

This PR is currently being tested on vLLM main (vllm-project/vllm@83f478b) together with PR vllm-project/vllm#25784.

In our previous testing, suffix decoding achieved a 13%-30% throughput improvement over n-gram on the sonnet dataset, tested on vllm-ascend v0.9.1 with concurrency ranging from 2 to 40.

- vLLM version: v0.11.2

---------

Signed-off-by: fluctlux <[email protected]>
Signed-off-by: Che Ruan <[email protected]>
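As an illustration only (not part of this commit), the sketch below shows how the same speculative config could be passed to vLLM's offline `LLM` API; it mirrors the `speculative_config` used by the new e2e tests in this PR. The model name and prompt are placeholder values chosen for the example.

```python
# Minimal sketch: enabling suffix speculative decoding offline.
# Assumes `pip install arctic-inference` has already been run; the model
# name and prompt are placeholders, not values taken from this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    max_model_len=1024,
    speculative_config={
        "method": "suffix",
        "num_speculative_tokens": 5,
    },
)

outputs = llm.generate(
    ["Summarize the benefits of speculative decoding in two sentences."],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```

The same JSON shown above can be passed via `--speculative_config` when launching an online server.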
1 parent eab8ece commit 775b1e0

File tree

7 files changed: +146 -3 lines changed

requirements-dev.txt
tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
vllm_ascend/patch/platform/patch_config.py
vllm_ascend/spec_decode/__init__.py
vllm_ascend/spec_decode/interface.py
vllm_ascend/spec_decode/suffix_proposer.py
vllm_ascend/worker/model_runner_v1.py

requirements-dev.txt

Lines changed: 2 additions & 1 deletion

@@ -19,4 +19,5 @@ librosa
 soundfile
 pytest_mock
 msserviceprofiler>=1.2.2
-mindstudio-probe>=8.3.0
+mindstudio-probe>=8.3.0
+arctic-inference==0.1.1

tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 85 additions & 0 deletions

@@ -146,3 +146,88 @@ def test_eagle_correctness(
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
+
+
+def test_suffix_correctness(
+        test_prompts: list[list[dict[str, Any]]],
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Check that the outputs of the original LLM and the speculative LLM
+    are the same when using suffix speculative decoding.
+    '''
+    ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "num_speculative_tokens": 8,
+                    },
+                    max_model_len=1024,
+                    enforce_eager=False) as runner:
+        spec_outputs = runner.model.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+
+
+def test_suffix_acceptance(
+        test_prompts: list[list[dict[str, Any]]],
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    '''
+    num_draft = []
+    num_accept = []
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "suffix_decoding_max_spec_factor": 2.0,
+                        "suffix_decoding_max_cached_requests": 1000,
+                        "num_speculative_tokens": 10,
+                    },
+                    max_model_len=1024,
+                    disable_log_stats=False,
+                    enforce_eager=False) as runner:
+        for i in range(10):
+            runner.model.chat(test_prompts[i], sampling_config)
+            metrics = runner.model.get_metrics()
+            for metric in metrics:
+                print(metric)
+                if metric.name == "vllm:spec_decode_num_draft_tokens":
+                    num_draft.append(metric.value)
+                if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                    num_accept.append(metric.value)
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 60% acceptance rate at the end.
+    assert last_accept_rate > 0.60

vllm_ascend/patch/platform/patch_config.py

Lines changed: 6 additions & 0 deletions

@@ -28,6 +28,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
+            elif self.method == "suffix":
+                self.model = "suffix"
             else:
                 raise ValueError("num_speculative_tokens was provided but without "
                                  "speculative model.")
@@ -70,6 +72,10 @@ def __post_init__(self):
             # draft related config as None here.
             self.draft_model_config = self.target_model_config
             self.draft_parallel_config = self.target_parallel_config
+        elif self.method == "suffix":
+            self.draft_model_config = self.target_model_config
+            self.draft_parallel_config = self.target_parallel_config
+            self._validate_suffix_decoding()
         else:
             self.prompt_lookup_max = 0
             self.prompt_lookup_min = 0

vllm_ascend/spec_decode/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -19,6 +19,7 @@
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
+from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
 from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
 
 
@@ -35,6 +36,8 @@ def get_spec_decode_method(method,
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
+    elif method == 'suffix':
+        return SuffixDecodingProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")

vllm_ascend/spec_decode/interface.py

Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum):
     EAGLE = 1
     EAGLE3 = 2
     MTP = 4
+    SUFFIX = 5
 
 
 class Proposer:
@@ -51,4 +52,4 @@ def generate_token_ids(self,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
         """Called by execute_model in model_runner"""
-        raise NotImplementedError
+        raise NotImplementedError

vllm_ascend/spec_decode/suffix_proposer.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import torch
+from vllm.config import CUDAGraphMode
+from vllm.v1.spec_decode.suffix_decoding import \
+    SuffixDecodingProposer as VllmSuffixDecodingProposer
+
+from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+
+
+class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
+
+    def __init__(self, vllm_config, device, runner):
+        super().__init__(vllm_config)
+        self.name = SpecDcodeType.SUFFIX
+        self.device = device
+        self.runner = runner
+
+    def load_model(self, *args, **kwargs):
+        # No model to load.
+        pass
+
+    @torch.inference_mode()
+    def dummy_run(self,
+                  num_tokens,
+                  with_prefill=None,
+                  skip_attn=None,
+                  num_reqs=None,
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
+        pass
+
+    def generate_token_ids(self,
+                           valid_sampled_token_ids,
+                           sampling_metadata=None,
+                           scheduler_output=None,
+                           spec_decode_metadata=None,
+                           positions=None,
+                           num_scheduled_tokens=None,
+                           hidden_states=None,
+                           attn_metadata=None,
+                           aux_hidden_states=None) -> list[list[int]]:
+        draft_token_ids = self.propose(self.runner.input_batch,
+                                       valid_sampled_token_ids)
+        return draft_token_ids

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 1 deletion

@@ -96,6 +96,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -630,7 +631,8 @@ def _set_up_drafter(self):
         # Set up speculative decoding.
         self.spec_attn_mask = None
         self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
-                                     TorchairMtpProposer]] = None
+                                     TorchairMtpProposer,
+                                     SuffixDecodingProposer]] = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
         if self.speculative_config:
