
Commit 2ebbacf

[Param] Recheck and update repetition penalty parameter (#202)
* sampling_metadata -> sampling_state
* update repetition penalties
* correct repetition penalties calculation using the vLLM and HF approaches
* construct mask_prompt and transfer it to the repetition penalty
* fix comments
* fix penalty calculation
* add mask_prompt in the correct place
* fix dim for scatter
* fix device
1 parent 4cb330f commit 2ebbacf

File tree

6 files changed (+137, -76 lines)

serve/mlc_serve/engine/engine_common.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,7 @@
 Common utilites for engine classes.
 """
 
+import torch
 import time
 from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
 from collections import deque
@@ -240,6 +241,18 @@ def prepare_output(
     return delta, out_logprob_info
 
 
+def set_mask_prompt_to(state: RequestState):
+    # Prompt tokens
+    tokens = torch.tensor(state.prompt_token_ids, dtype=torch.long)
+    vocab_size = state.sampling_params.vocab_size
+    bin_counts = torch.zeros((vocab_size + 1,),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:vocab_size]
+    state.sampling_params.mask_prompt = bin_counts > 0
+
+
 def get_requests_to_process(
     current_states: list[RequestState],
     cache_manager: KVCacheManager,
@@ -264,6 +277,9 @@ def get_requests_to_process(
     if is_prompt_batch:
         for state in current_states:
             if is_evicted_parallel_sampling_request(state):
+                # TODO(vvchernov): we still need mask if apply_penalty = True
+                # if state.sampling_params.repetition_penalty != 1.0:
+                #     set_mask_prompt_to(state)
                 requests.append(
                     PrefillRequest(
                         request_id=state.request_id,
@@ -311,6 +327,9 @@ def get_requests_to_process(
             else:
                 token_ids = state.prompt_token_ids
 
+            # TODO(vvchernov): we still need mask if apply_penalty = True
+            # if state.sampling_params.repetition_penalty != 1.0:
+            set_mask_prompt_to(state)
             requests.append(
                 PrefillRequest(
                     request_id=state.request_id,
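
Note on the new helper above: set_mask_prompt_to builds a per-request boolean mask over the vocabulary marking which token ids occur in the prompt, using a scatter_add_ histogram. A minimal standalone sketch of the same trick (the helper name build_prompt_mask and the toy inputs are illustrative, not part of this repo):

import torch

def build_prompt_mask(prompt_token_ids, vocab_size):
    # Histogram of prompt token ids; the extra bin at index vocab_size is a
    # spill slot for any out-of-range/padding ids and is sliced off below.
    tokens = torch.tensor(prompt_token_ids, dtype=torch.long)
    bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long)
    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
    return bin_counts[:vocab_size] > 0  # True where the token appears in the prompt

mask = build_prompt_mask([1, 5, 5, 7], vocab_size=10)
print(mask.nonzero().flatten().tolist())  # [1, 5, 7]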

serve/mlc_serve/engine/sampling_params.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Dict, Optional, Any
+import torch
 
 _SAMPLING_EPS = 1e-5
 LOGPROB_TOP_K_MAX = 5
@@ -75,6 +76,7 @@ class SamplingParams:
     vocab_size = 32000
     json_schema: Optional[Dict[str, Any]] = None
     logits_processor: Optional[Any] = None
+    mask_prompt: Optional[torch.Tensor] = None
 
     def __post_init__(self):
         if self.logit_bias:

serve/mlc_serve/model/model_common.py

Lines changed: 6 additions & 6 deletions
@@ -83,7 +83,7 @@ def sample_from_logits(
     logits: Union[tvm.nd.NDArray, torch.Tensor],
     sequence_ids: List[SequenceId],
     requests: Sequence[RequestType],
-    sampling_metadata: SamplingState,
+    sampling_state: SamplingState,
     vocab_size: int,
     copy_stream: torch.cuda.Stream,
     torch_dtype: torch.dtype,
@@ -110,13 +110,13 @@ def sample_from_logits(
                 sequence_id, cs_input_ids, logits[i]
             )
 
-    logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    logits = adjust_logits(logits, sampling_state, vocab_size)
     outputs: List[TextGenerationResult] = []
 
     try:
         sampling_output: Optional[SamplingOutput] = sample(
             logits,
-            sampling_metadata,
+            sampling_state,
         )
 
         for i, (new_token, logprob_info) in enumerate(
@@ -142,13 +142,13 @@ def sample_from_logits(
         for i in range(batch_size):
             sequence_id = sequence_ids[i]
             logits_per_token = logits[i]
-            sampling_param = sampling_metadata.sampling_params[i]
+            sampling_param = sampling_state.sampling_params[i]
            past_decode_tokens_per_request = past_decode_tokens[i]
            # NOTE: Rerun the preparation for simplicity.
            # Assume this code path is taken rarely and the recomputation overhead is
            # marginal.
            with torch.cuda.stream(copy_stream):
-                new_sampling_metadata = SamplingState.from_sampling_params(
+                new_sampling_state = SamplingState.from_sampling_params(
                     [sampling_param],
                     [past_decode_tokens_per_request],
                     torch_dtype,
@@ -158,7 +158,7 @@ def sample_from_logits(
             torch.cuda.current_stream().wait_stream(copy_stream)
             maybe_sampling_output: Optional[SamplingOutput] = sample(
                 torch.unsqueeze(logits_per_token, 0),
-                new_sampling_metadata,
+                new_sampling_state,
                 check_safety=True,
             )
 

serve/mlc_serve/model/sampler.py

Lines changed: 64 additions & 24 deletions
@@ -53,6 +53,9 @@ class SamplingTensors:
     mask_top_logprob: torch.Tensor
         Mask for requests with top_logprob.
         shape: (LOGPROB_TOP_K_MAX) + 1, batch_size,)
+    mask_prompt: torch.Tensor
+        Mask for request with repetition penalty (prompt part)
+        shape: (batch_size, vocab_size)
     temperatures: torch.Tensor
         Tensor for temperature values
         shape: (batch_size, )
@@ -85,6 +88,7 @@ class SamplingTensors:
     mask_random: torch.Tensor
     mask_greedy: torch.Tensor
     mask_top_logprob: torch.Tensor
+    mask_prompt: torch.Tensor
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -102,6 +106,7 @@ def from_lists(
         dev,
         list_mask_random: List[bool],
         list_mask_top_logprob: List[List[bool]],
+        list_mask_prompt: List[torch.Tensor],
         list_temperatures: List[float],
         list_top_ps: List[float],
         list_top_ks: List[int],
@@ -124,6 +129,7 @@ def from_lists(
         )
         # `mask_top_logprob` will be on cpu
         mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
+        mask_prompt = torch.stack(list_mask_prompt)
         temp = torch.tensor(
             list_temperatures,
             dtype=dtype,
@@ -185,6 +191,7 @@ def from_lists(
             mask_random,
             mask_greedy,
             mask_top_logprob,
+            mask_prompt,
             temp.to(device=dev, non_blocking=True),
             top_ps.to(device=dev, non_blocking=True),
             top_ks.to(device=dev, non_blocking=True),
@@ -250,6 +257,7 @@ def from_sampling_params(
         vocab_size: int,
     ):
         list_mask_random = []
+        list_mask_prompt = []
         list_temperatures = []
         list_top_ps = []
         list_top_ks = []
@@ -307,6 +315,7 @@ def from_sampling_params(
             list_frequency_penalties.append(param.frequency_penalty)
             list_presence_penalties.append(param.presence_penalty)
             list_repetition_penalties.append(param.repetition_penalty)
+            list_mask_prompt.append(param.mask_prompt)
 
             if param.logit_bias_index:
                 assert param.logit_bias_value
@@ -348,6 +357,7 @@ def from_sampling_params(
             dev,
             list_mask_random,
             list_mask_top_logprob,
+            list_mask_prompt,
             list_temperatures,
             list_top_ps,
             list_top_ks,
@@ -372,20 +382,39 @@ def from_sampling_params(
         )
 
 
-def adjust_logits(logits, sampling_metadata, vocab_size):
+def get_bin_counts_and_mask(
+    tokens: torch.Tensor,
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
+
+
+def adjust_logits(
+    logits: torch.Tensor,
+    sampling_state: SamplingState,
+    vocab_size: int):
     batch_size = logits.shape[0]
     (
         apply_top_p_top_k,
         apply_penalty,
         apply_bias,
         sampling_tensors,
     ) = (
-        sampling_metadata.apply_top_p_top_k,
-        sampling_metadata.apply_penalty,
-        sampling_metadata.apply_bias,
-        sampling_metadata.sampling_tensors,
+        sampling_state.apply_top_p_top_k,
+        sampling_state.apply_penalty,
+        sampling_state.apply_bias,
+        sampling_state.sampling_tensors,
     )
     (
+        prompt_mask,
         temp_t,
         top_ps_t,
         top_ks_t,
@@ -396,6 +425,7 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
         logit_bias_indices_t,
         logit_bias_values_t,
     ) = (
+        sampling_tensors.mask_prompt,
         sampling_tensors.temperatures,
         sampling_tensors.top_ps,
         sampling_tensors.top_ks,
@@ -411,20 +441,30 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
     # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...)
     # in the right order.
     if apply_penalty:
+        bin_counts, output_mask = get_bin_counts_and_mask(
+            past_output_tokens_t,
+            vocab_size,
+            batch_size,
+        )
+
+        # It was checked that vLLM and HF approaches for repetition penalty are the same
+        # For calculation of it their combination is used (see references below)
+        # Calculate repetition penalty use vLLM approach
+        # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
+        # and RepetitionPenaltyLogitsProcessor approach from HF TGI API
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22
+        # where score is logits
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92
         repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
+        prompt_mask = prompt_mask.to(repetition_penalties_t.device)
+        repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0
         logits = torch.where(
             logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t
         )
-        bin_counts = torch.zeros(
-            (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device
-        )
-        bin_counts.scatter_add_(
-            1, past_output_tokens_t, torch.ones_like(past_output_tokens_t)
-        )
-        bin_counts = bin_counts[:, :vocab_size]
-        mask = bin_counts > 0
+
+        # Calculate frequency and presence penalties
         logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts
-        logits -= presence_penalties_t.unsqueeze_(dim=1) * mask
+        logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask
 
     # Adjust temperature
     logits.div_(temp_t.unsqueeze(dim=1))
@@ -447,7 +487,7 @@ class SamplingOutput:
 
 def sample(
     logits: torch.Tensor,
-    sampling_metadata: SamplingState,
+    sampling_state: SamplingState,
     check_safety: bool = False,
 ) -> SamplingOutput:
     def _is_safe_to_sample(prob_like):
@@ -457,7 +497,7 @@ def _is_safe_to_sample(prob_like):
         )
 
     res_greedy, res_random = None, None
-    sampling_tensors = sampling_metadata.sampling_tensors
+    sampling_tensors = sampling_state.sampling_tensors
 
     batch_size = logits.shape[0]
     mask_greedy_t, mask_random_t = (
@@ -466,13 +506,13 @@ def _is_safe_to_sample(prob_like):
     )
 
     next_tokens = np.empty((batch_size,), dtype=np.int64)
-    if sampling_metadata.has_greedy:
+    if sampling_state.has_greedy:
         res_greedy = torch.argmax(logits[mask_greedy_t], -1)
         np_mask_greedy = mask_greedy_t.cpu().numpy()
         next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()
 
     probs_random = None
-    if sampling_metadata.has_random:
+    if sampling_state.has_random:
         probs_random = torch.softmax(logits[mask_random_t], dim=-1)
         if check_safety and not _is_safe_to_sample(probs_random):
             return None
@@ -481,9 +521,9 @@ def _is_safe_to_sample(prob_like):
         next_tokens[np_mask_random] = res_random.cpu().numpy()
 
     logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size
-    if sampling_metadata.has_logprob:
+    if sampling_state.has_logprob:
         # If everything is random sampling, save one extra softmax
-        if not sampling_metadata.has_greedy:
+        if not sampling_state.has_greedy:
             assert probs_random is not None
             logprobs = torch.log(probs_random)
         else:
@@ -494,13 +534,13 @@ def _is_safe_to_sample(prob_like):
         all_top_logprobs, all_top_tokens = torch.topk(
             extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
-        mask = sampling_metadata.sampling_tensors.mask_top_logprob
+        mask = sampling_state.sampling_tensors.mask_top_logprob
         top_tokens = all_top_tokens[mask]
        top_logprobs = all_top_logprobs[mask]
-        for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
+        for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
            next_token = next_tokens[batch_idx]
-            assert sampling_metadata.sampling_params[batch_idx].logprobs
-            top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
+            assert sampling_state.sampling_params[batch_idx].logprobs
+            top_k = sampling_state.sampling_params[batch_idx].top_logprobs
            logprob_infos[batch_idx] = RawLogprobsInfo(
                current_token_id=next_token,
                current_logprob=logprobs[batch_idx][next_token],
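
For reference, the penalty branch above boils down to the following standalone sketch: the repetition penalty is restricted to tokens that appear in the prompt or in the generated output (positive logits are divided by the penalty, negative logits multiplied, matching the vLLM/HF rule referenced in the comments), while frequency and presence penalties use only the output-token counts. Argument names and shapes here are illustrative, not the module's API:

import torch

def apply_penalties(logits, prompt_mask, output_tokens, vocab_size,
                    repetition_penalty, frequency_penalty, presence_penalty):
    # logits: (batch, vocab); output_tokens: (batch, num_generated) token ids.
    batch_size = logits.shape[0]
    bin_counts = torch.zeros((batch_size, vocab_size + 1), dtype=torch.long)
    bin_counts.scatter_add_(1, output_tokens, torch.ones_like(output_tokens))
    bin_counts = bin_counts[:, :vocab_size]
    output_mask = bin_counts > 0

    # Repetition penalty: only tokens seen in the prompt or output are touched.
    rep = repetition_penalty[:, None].repeat(1, vocab_size)
    rep[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency penalty scales with counts; presence penalty is a flat hit.
    logits = logits - frequency_penalty[:, None] * bin_counts
    logits = logits - presence_penalty[:, None] * output_mask
    return logits

adjusted = apply_penalties(
    logits=torch.tensor([[2.0, -1.0, 0.5]]),
    prompt_mask=torch.tensor([[True, False, False]]),
    output_tokens=torch.tensor([[2]]),
    vocab_size=3,
    repetition_penalty=torch.tensor([1.2]),
    frequency_penalty=torch.tensor([0.1]),
    presence_penalty=torch.tensor([0.1]),
)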

serve/mlc_serve/model/tvm_model.py

Lines changed: 4 additions & 4 deletions
@@ -259,7 +259,7 @@ def generate_multi_query(
         # Prepare sampling tensors in another stream to overlap
         # CPU<->GPU data transfer with GPU computation in forward pass.
         with torch.cuda.stream(self._copy_stream):
-            sampling_metadata = SamplingState.from_sampling_params(
+            sampling_state = SamplingState.from_sampling_params(
                 sampling_params,
                 past_decode_tokens,
                 self.torch_dtype,
@@ -318,7 +318,7 @@ def generate_multi_query(
             last_query_logits,
             sequence_ids,
             requests,
-            sampling_metadata,
+            sampling_state,
             self.vocab_size,
             self._copy_stream,
             self.torch_dtype,
@@ -381,7 +381,7 @@ def generate(
         # Prepare sampling tensors in another stream to overlap
         # CPU<->GPU data transfer with GPU computation in forward pass.
         with torch.cuda.stream(self._copy_stream):
-            sampling_metadata = SamplingState.from_sampling_params(
+            sampling_state = SamplingState.from_sampling_params(
                 sampling_params,
                 past_decode_tokens,
                 self.torch_dtype,
@@ -502,7 +502,7 @@ def generate(
             logits,
             sequence_ids,
             requests,
-            sampling_metadata,
+            sampling_state,
             self.vocab_size,
             self._copy_stream,
             self.torch_dtype,
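
The tvm_model.py changes only rename the variable, but the surrounding comment ("prepare sampling tensors in another stream to overlap CPU<->GPU data transfer with GPU computation") is worth unpacking. A rough sketch of that copy-stream pattern, assuming a CUDA device (the random tensors and the matmul stand in for the real sampling tensors and forward pass):

import torch

assert torch.cuda.is_available()
copy_stream = torch.cuda.Stream()

# Kick off the host-to-device copy of the "sampling tensors" on a side stream...
with torch.cuda.stream(copy_stream):
    host_params = torch.rand(1024, pin_memory=True)
    device_params = host_params.to("cuda", non_blocking=True)

# ...while the default stream is busy with the forward pass.
activations = torch.rand(1024, 1024, device="cuda")
logits = activations @ activations

# Block the default stream on the copy before the sampling code consumes it.
torch.cuda.current_stream().wait_stream(copy_stream)
adjusted = logits * device_params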

0 commit comments