Commit a52aea0
fix logprobs (#5335)
1 parent 96ff402 commit a52aea0

File tree

9 files changed: +71 -50 lines changed

fastdeploy/entrypoints/engine_client.py

Lines changed: 27 additions & 22 deletions
@@ -418,13 +418,16 @@ def valid_parameters(self, data):
         # logprobs
         logprobs = data.get("logprobs")
         top_logprobs = None
+        is_chat = False
 
-        if isinstance(logprobs, bool) and logprobs:
-            if not self.enable_logprob:
-                err_msg = "Logprobs is disabled, please enable it in startup config."
-                api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
-            top_logprobs = data.get("top_logprobs")
+        if isinstance(logprobs, bool):
+            if logprobs:
+                is_chat = True
+                if not self.enable_logprob:
+                    err_msg = "Logprobs is disabled, please enable it in startup config."
+                    api_server_logger.error(err_msg)
+                    raise ParameterError("logprobs", err_msg)
+                top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:

@@ -478,38 +481,40 @@ def valid_parameters(self, data):
             raise ValueError("prompt_logprobs", err_msg)
 
         # enable_logprob
-        if top_logprobs:
+        if top_logprobs is not None:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
 
             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
-                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                err_msg = (
+                    f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
+                )
+                api_server_logger.error(err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
+
+            if top_logprobs > max_logprobs:
+                err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                 api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+                raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
             if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if top_logprobs < 0 or top_logprobs > 20:
-                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                if top_logprobs < 0 or top_logprobs > max_logprobs:
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             else:
                 if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
-                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
                 if top_logprobs < -1:
-                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
-                    api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
-
-                if top_logprobs > max_logprobs:
-                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
     def check_health(self, time_interval_threashold=30):
         """
fastdeploy/entrypoints/llm.py

Lines changed: 13 additions & 6 deletions
@@ -351,6 +351,10 @@ def _add_request(
 
         if current_sampling_params.logprobs is not None:
             num_logprobs = current_sampling_params.logprobs
+            if not self.llm_engine.cfg.model_config.enable_logprob:
+                raise ValueError(
+                    "logprobs is only supported if `enable_logprob` is set to true in startup config."
+                )
             if num_logprobs == -1 and ori_vocab_size > max_logprobs:
                 raise ValueError(
                     f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."

@@ -360,6 +364,10 @@ def _add_request(
                     f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                 )
         if current_sampling_params.prompt_logprobs is not None:
+            if not self.llm_engine.cfg.model_config.enable_logprob:
+                raise ValueError(
+                    "prompt_logprobs is only supported if `enable_logprob` is set to true in startup config."
+                )
             if self.llm_engine.cfg.cache_config.enable_prefix_caching:
                 raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
             if kwargs.get("stream"):

@@ -403,19 +411,18 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int
             llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
             return None
 
-        # exclude sampled token at index 0
-        available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+        available_topk = len(logprobs_lists.logprob_token_ids[0])
         effective_topk_logprobs = min(topk_logprobs, available_topk)
 
-        if effective_topk_logprobs <= 0:
+        if effective_topk_logprobs < 0:
             llm_logger.warning(
                 f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
                 f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
             )
             return None
 
-        # sliced 1 ~ (1 + effective_topk_logprobs)
-        sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+        # sliced 0 ~ effective_topk_logprobs+1
+        sliced_logprobs_lists = logprobs_lists.slice_columns(0, effective_topk_logprobs + 1)
         result = []
         for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):

@@ -559,7 +566,7 @@ def _run_engine(
             result = self.llm_engine.data_processor.process_response(result)
 
             # filter logprobs
-            if result.outputs.top_logprobs and topk_logprobs:
+            if result.outputs.top_logprobs is not None and topk_logprobs is not None:
                 if topk_logprobs == -1:
                     topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                 result.outputs.logprobs = self._build_sample_logprobs(
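
The `_build_sample_logprobs` change is the core of the fix: per the deleted comment, column 0 of each row in `LogprobsLists` holds the sampled token itself, and the old slice dropped it, so responses never reported the sampled token's own logprob. The new slice keeps it, putting the sampled token at rank 1. A plain-list sketch of the two slice semantics (plain rows stand in for the real `slice_columns` method):

# Plain-list illustration of the slicing change; assumed row layout per the
# deleted comment: column 0 is the sampled token, columns 1..k the top-k candidates.
logprob_token_ids = [[100, 101, 102]]
logprobs = [[-0.1, -0.5, -1.0]]

topk_logprobs = 2
available_topk = len(logprob_token_ids[0])  # the old code used len(...) - 1
effective_topk = min(topk_logprobs, available_topk)

# old: row[1 : 1 + effective_topk]  -> [[101, 102]]      (sampled token dropped)
# new: row[0 : effective_topk + 1]  -> [[100, 101, 102]] (sampled token at rank 1)
sliced_ids = [row[0 : effective_topk + 1] for row in logprob_token_ids]
print(sliced_ids)

This matches the updated expectations in test_build_sample_logprobs_basic below, where token 100 now appears at rank 1.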

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 1 addition & 1 deletion
@@ -613,7 +613,7 @@ class ChatCompletionRequest(BaseModel):
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = 0
+    top_logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
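
Changing the default from `0` to `None` is what lets the `is not None` checks above distinguish "field omitted" from "client explicitly asked for 0", which the old truthy checks conflated. A minimal pydantic model mirroring just the changed field (the real ChatCompletionRequest has many more fields):

from typing import Optional
from pydantic import BaseModel

# Minimal stand-in for the real request model, showing only the changed default.
class ChatRequestSketch(BaseModel):
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None

print(ChatRequestSketch().top_logprobs)                # None -> logprobs validation is skipped
print(ChatRequestSketch(top_logprobs=0).top_logprobs)  # 0 -> validated as an explicit request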

fastdeploy/worker/output.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 import paddle
 
 
-class Logprob(NamedTuple):
+@dataclass
+class Logprob:
     """
     A named tuple containing information about a token's log probability.
     """

tests/entrypoints/openai/test_build_sample_logprobs.py

Lines changed: 4 additions & 3 deletions
@@ -56,8 +56,9 @@ def test_build_sample_logprobs_basic(self):
 
         expected = [
             {
-                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
-                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+                100: Logprob(logprob=-0.1, rank=1, decoded_token="token_100"),
+                101: Logprob(logprob=-0.5, rank=2, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=3, decoded_token="token_102"),
             }
         ]

@@ -79,7 +80,7 @@ def test_build_sample_logprobs_invalid_topk(self):
         logprobs_lists = MagicMock(spec=LogprobsLists)
         logprobs_lists.logprob_token_ids = [[100]]
         result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
-        self.assertIsNone(result)
+        self.assertEqual(result, [])
 
     def test_decode_token(self):
         """

tests/entrypoints/openai/test_chatcompletion_request.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def test_default_values(self):
         req = ChatCompletionRequest(messages=[1])
         self.assertEqual(req.model, "default")
         self.assertFalse(req.logprobs)
-        self.assertEqual(req.top_logprobs, 0)
+        self.assertIsNone(req.top_logprobs)
         self.assertEqual(req.n, 1)
         self.assertEqual(req.stop, [])

tests/entrypoints/test_engine_client.py

Lines changed: 3 additions & 2 deletions
@@ -489,8 +489,9 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
         data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
         with self.assertRaises(ValueError) as context:
             self.engine_client.valid_parameters(data)
-        self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
-        self.assertIn("current value is 25", str(context.exception))
+        self.assertIn(
+            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
+        )
 
         # Test valid value
         data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}

tests/entrypoints/test_vllm_run_engine.py

Lines changed: 2 additions & 1 deletion
@@ -10,9 +10,10 @@
 
 
 class DummyModelConfig:
-    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
         self.max_logprobs = max_logprobs
         self.ori_vocab_size = ori_vocab_size
+        self.enable_logprob = enable_logprob
 
 
 class DummyCacheConfig:
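
The new `enable_logprob` field on the dummy config mirrors the gate added to `LLM._add_request` above: requests asking for `logprobs` or `prompt_logprobs` now fail fast when the engine was started without logprob support. A condensed sketch of that gate exercised with this file's dummy config (`check_logprobs_allowed` is a hypothetical name; the real check is inlined in `_add_request`):

# Stand-in for the inlined check shown in the llm.py diff above.
class DummyModelConfig:
    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
        self.max_logprobs = max_logprobs
        self.ori_vocab_size = ori_vocab_size
        self.enable_logprob = enable_logprob

def check_logprobs_allowed(model_config, logprobs):
    if logprobs is not None and not model_config.enable_logprob:
        raise ValueError(
            "logprobs is only supported if `enable_logprob` is set to true in startup config."
        )

check_logprobs_allowed(DummyModelConfig(), logprobs=5)  # passes
try:
    check_logprobs_allowed(DummyModelConfig(enable_logprob=False), logprobs=5)
except ValueError as e:
    print(e)  # raised before any max_logprobs arithmetic runs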

tests/utils/test_clamp_prompt_logprobs.py

Lines changed: 18 additions & 13 deletions
@@ -45,32 +45,37 @@ def test_normal_logprobs(self):
         self.assertEqual(result[0][1].logprob, -2.5)
         self.assertEqual(result[0][2].logprob, -1.0)
 
-    def test_negative_inf_logprobs_raises_error(self):
-        """Test that logprobs containing -inf raises AttributeError"""
+    def test_negative_inf_logprobs_gets_clamped(self):
+        """Test that logprobs containing -inf get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
         }
         prompt_logprobs = [logprob_dict]
 
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError) as context:
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
 
-        self.assertIn("can't set attribute", str(context.exception))
+        # The -inf value should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -1.0)  # unchanged
 
-    def test_multiple_negative_inf_raises_error(self):
-        """Test that multiple -inf logprobs values raise AttributeError"""
+    def test_multiple_negative_inf_gets_clamped(self):
+        """Test that multiple -inf logprobs values get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
             3: Logprob(logprob=-0.5, rank=3, decoded_token="test"),
         }
         prompt_logprobs = [logprob_dict]
 
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError):
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
+
+        # All -inf values should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -9999.0)
+        self.assertEqual(result[0][3].logprob, -0.5)  # unchanged
 
     def test_none_dict_in_list(self):
         """Test case when list contains None"""

@@ -116,15 +121,15 @@ def test_mixed_values_without_inf(self):
         self.assertEqual(result[0][4].logprob, -1.5)
 
     def test_return_same_object(self):
-        """Test that function returns the same object (in-place modification attempt)"""
+        """Test that function returns the same object (in-place modification)"""
         logprob_dict = {
             1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
         }
         prompt_logprobs = [logprob_dict]
 
         result = clamp_prompt_logprobs(prompt_logprobs)
 
-        # Should return the same object (function attempts in-place modification)
+        # Should return the same object (function performs in-place modification)
         self.assertIs(result, prompt_logprobs)
         self.assertIs(result[0], prompt_logprobs[0])
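
For reference, the contract these tests pin down is small; here is a sketch consistent with the assertions above, not the actual clamp_prompt_logprobs implementation from fastdeploy, which may differ in detail:

# Hypothetical reimplementation of the clamping contract, grounded only in the
# tests above: -inf becomes -9999.0, None entries are skipped, and the input
# list is mutated in place and returned.
def clamp_prompt_logprobs(prompt_logprobs):
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:  # see test_none_dict_in_list
            continue
        for entry in logprob_dict.values():
            if entry.logprob == float("-inf"):
                entry.logprob = -9999.0  # possible now that Logprob is a dataclass
    return prompt_logprobs  # same object, per test_return_same_object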
