Skip to content

Commit 27388dd

Browse files
committed
fix_logprobs
1 parent 2e16808 commit 27388dd

File tree

6 files changed

+58
-39
lines changed

6 files changed

+58
-39
lines changed

fastdeploy/entrypoints/engine_client.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,10 @@ def valid_parameters(self, data):
418418
# logprobs
419419
logprobs = data.get("logprobs")
420420
top_logprobs = None
421+
is_chat = False
421422

422423
if isinstance(logprobs, bool) and logprobs:
424+
is_chat = True
423425
if not self.enable_logprob:
424426
err_msg = "Logprobs is disabled, please enable it in startup config."
425427
api_server_logger.error(err_msg)
@@ -478,38 +480,40 @@ def valid_parameters(self, data):
478480
raise ValueError("prompt_logprobs", err_msg)
479481

480482
# enable_logprob
481-
if top_logprobs:
483+
if top_logprobs is not None:
482484
if not self.enable_logprob:
483485
err_msg = "Logprobs is disabled, please enable it in startup config."
484486
api_server_logger.error(err_msg)
485-
raise ParameterError("logprobs", err_msg)
487+
raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
486488

487489
if not isinstance(top_logprobs, int):
488490
err_type = type(top_logprobs).__name__
489-
err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
491+
err_msg = (
492+
f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
493+
)
490494
api_server_logger.error(err_msg)
491-
raise ParameterError("top_logprobs", err_msg)
495+
raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
496+
497+
if top_logprobs > max_logprobs:
498+
err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
499+
api_server_logger.error(err_msg)
500+
raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
492501

493502
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
494-
if top_logprobs < 0 or top_logprobs > 20:
495-
err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
503+
if top_logprobs < 0 or top_logprobs > max_logprobs:
504+
err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
496505
api_server_logger.error(err_msg)
497-
raise ValueError("top_logprobs", err_msg)
506+
raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
498507
else:
499508
if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
500-
err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
509+
err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
501510
api_server_logger.error(err_msg)
502-
raise ValueError("top_logprobs", err_msg)
511+
raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
503512

504513
if top_logprobs < -1:
505-
err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
506-
api_server_logger.error(err_msg)
507-
raise ValueError("top_logprobs", err_msg)
508-
509-
if top_logprobs > max_logprobs:
510-
err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
514+
err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
511515
api_server_logger.error(err_msg)
512-
raise ValueError("top_logprobs", err_msg)
516+
raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
513517

514518
def check_health(self, time_interval_threashold=30):
515519
"""

fastdeploy/entrypoints/llm.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ def _add_request(
351351

352352
if current_sampling_params.logprobs is not None:
353353
num_logprobs = current_sampling_params.logprobs
354+
if not self.llm_engine.cfg.model_config.enable_logprob:
355+
raise ValueError(
356+
"logprobs is only supported if `enable_logprob` is set to true in startup config."
357+
)
354358
if num_logprobs == -1 and ori_vocab_size > max_logprobs:
355359
raise ValueError(
356360
f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
@@ -360,6 +364,10 @@ def _add_request(
360364
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
361365
)
362366
if current_sampling_params.prompt_logprobs is not None:
367+
if not self.llm_engine.cfg.model_config.enable_logprob:
368+
raise ValueError(
369+
"prompt_logprobs is only supported if `enable_logprob` is set to true in startup config."
370+
)
363371
if self.llm_engine.cfg.cache_config.enable_prefix_caching:
364372
raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
365373
if kwargs.get("stream"):
@@ -403,19 +411,18 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i
403411
llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
404412
return None
405413

406-
# exclude sampled token at index 0
407-
available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
414+
available_topk = len(logprobs_lists.logprob_token_ids[0])
408415
effective_topk_logprobs = min(topk_logprobs, available_topk)
409416

410-
if effective_topk_logprobs <= 0:
417+
if effective_topk_logprobs < 0:
411418
llm_logger.warning(
412419
f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
413420
f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
414421
)
415422
return None
416423

417-
# sliced 1 ~ (1 + effective_topk_logprobs)
418-
sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
424+
# sliced 0 ~ effective_topk_logprobs+1
425+
sliced_logprobs_lists = logprobs_lists.slice_columns(0, effective_topk_logprobs + 1)
419426
result = []
420427
for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
421428

@@ -559,7 +566,7 @@ def _run_engine(
559566
result = self.llm_engine.data_processor.process_response(result)
560567

561568
# filter logprobs
562-
if result.outputs.top_logprobs and topk_logprobs:
569+
if result.outputs.top_logprobs is not None and topk_logprobs is not None:
563570
if topk_logprobs == -1:
564571
topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
565572
result.outputs.logprobs = self._build_sample_logprobs(

fastdeploy/worker/output.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import paddle
2121

2222

23-
class Logprob(NamedTuple):
23+
@dataclass
24+
class Logprob:
2425
"""
2526
A named tuple containing information about a token's log probability.
2627
"""

tests/entrypoints/test_engine_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,9 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
489489
data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
490490
with self.assertRaises(ValueError) as context:
491491
self.engine_client.valid_parameters(data)
492-
self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
493-
self.assertIn("current value is 25", str(context.exception))
492+
self.assertIn(
493+
"Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
494+
)
494495

495496
# Test valid value
496497
data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}

tests/entrypoints/test_vllm_run_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111

1212
class DummyModelConfig:
13-
def __init__(self, max_logprobs=10, ori_vocab_size=50):
13+
def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
1414
self.max_logprobs = max_logprobs
1515
self.ori_vocab_size = ori_vocab_size
16+
self.enable_logprob = enable_logprob
1617

1718

1819
class DummyCacheConfig:

tests/utils/test_clamp_prompt_logprobs.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,37 @@ def test_normal_logprobs(self):
4545
self.assertEqual(result[0][1].logprob, -2.5)
4646
self.assertEqual(result[0][2].logprob, -1.0)
4747

48-
def test_negative_inf_logprobs_raises_error(self):
49-
"""Test that logprobs containing -inf raises AttributeError"""
48+
def test_negative_inf_logprobs_gets_clamped(self):
49+
"""Test that logprobs containing -inf get clamped to -9999.0"""
5050
logprob_dict = {
5151
1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
5252
2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
5353
}
5454
prompt_logprobs = [logprob_dict]
5555

56-
# Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
57-
with self.assertRaises(AttributeError) as context:
58-
clamp_prompt_logprobs(prompt_logprobs)
56+
# Since Logprob is now a dataclass, its fields can be modified
57+
result = clamp_prompt_logprobs(prompt_logprobs)
5958

60-
self.assertIn("can't set attribute", str(context.exception))
59+
# The -inf value should be clamped to -9999.0
60+
self.assertEqual(result[0][1].logprob, -9999.0)
61+
self.assertEqual(result[0][2].logprob, -1.0) # unchanged
6162

62-
def test_multiple_negative_inf_raises_error(self):
63-
"""Test that multiple -inf logprobs values raise AttributeError"""
63+
def test_multiple_negative_inf_gets_clamped(self):
64+
"""Test that multiple -inf logprobs values get clamped to -9999.0"""
6465
logprob_dict = {
6566
1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
6667
2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
6768
3: Logprob(logprob=-0.5, rank=3, decoded_token="test"),
6869
}
6970
prompt_logprobs = [logprob_dict]
7071

71-
# Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
72-
with self.assertRaises(AttributeError):
73-
clamp_prompt_logprobs(prompt_logprobs)
72+
# Since Logprob is now a dataclass, its fields can be modified
73+
result = clamp_prompt_logprobs(prompt_logprobs)
74+
75+
# All -inf values should be clamped to -9999.0
76+
self.assertEqual(result[0][1].logprob, -9999.0)
77+
self.assertEqual(result[0][2].logprob, -9999.0)
78+
self.assertEqual(result[0][3].logprob, -0.5) # unchanged
7479

7580
def test_none_dict_in_list(self):
7681
"""Test case when list contains None"""
@@ -116,15 +121,15 @@ def test_mixed_values_without_inf(self):
116121
self.assertEqual(result[0][4].logprob, -1.5)
117122

118123
def test_return_same_object(self):
119-
"""Test that function returns the same object (in-place modification attempt)"""
124+
"""Test that function returns the same object (in-place modification)"""
120125
logprob_dict = {
121126
1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
122127
}
123128
prompt_logprobs = [logprob_dict]
124129

125130
result = clamp_prompt_logprobs(prompt_logprobs)
126131

127-
# Should return the same object (function attempts in-place modification)
132+
# Should return the same object (function performs in-place modification)
128133
self.assertIs(result, prompt_logprobs)
129134
self.assertIs(result[0], prompt_logprobs[0])
130135

0 commit comments

Comments (0)