fix logprobs (#5335)

Author: qwes5s5
Date: 2025-12-04 10:38:51 +08:00
Committed by: GitHub
Parent: 96ff402d44
Commit: a52aea073c
9 changed files with 71 additions and 50 deletions


@@ -418,13 +418,16 @@ class EngineClient:
         # logprobs
         logprobs = data.get("logprobs")
         top_logprobs = None
+        is_chat = False
-        if isinstance(logprobs, bool) and logprobs:
-            if not self.enable_logprob:
-                err_msg = "Logprobs is disabled, please enable it in startup config."
-                api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
-            top_logprobs = data.get("top_logprobs")
+        if isinstance(logprobs, bool):
+            if logprobs:
+                is_chat = True
+                if not self.enable_logprob:
+                    err_msg = "Logprobs is disabled, please enable it in startup config."
+                    api_server_logger.error(err_msg)
+                    raise ParameterError("logprobs", err_msg)
+                top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
@@ -478,38 +481,40 @@ class EngineClient:
                 raise ValueError("prompt_logprobs", err_msg)
         # enable_logprob
-        if top_logprobs:
+        if top_logprobs is not None:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
-                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                err_msg = (
+                    f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
+                )
                 api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
+            if top_logprobs > max_logprobs:
+                err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                api_server_logger.error(err_msg)
+                raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if top_logprobs < 0 or top_logprobs > 20:
-                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                if top_logprobs < 0 or top_logprobs > max_logprobs:
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             else:
                 if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
-                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
                 if top_logprobs < -1:
-                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
-                if top_logprobs > max_logprobs:
-                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
-                    api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

     def check_health(self, time_interval_threashold=30):
         """

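The two hunks above change which parameter name is reported when top-k logprob validation fails: chat requests send logprobs=true plus an integer top_logprobs, while completion requests send an integer logprobs directly, so the error now names whichever field the client actually set. The sketch below is a simplified, standalone illustration of that behavior (hypothetical helper names, max_logprobs assumed to be 20 as in the tests); it is not the repository's implementation.

    # Simplified sketch (hypothetical helpers, not FastDeploy code).
    # isinstance(x, bool) must be checked before isinstance(x, int),
    # because bool is a subclass of int in Python.
    def extract_top_logprobs(data: dict):
        logprobs = data.get("logprobs")
        is_chat = False
        top_logprobs = None
        if isinstance(logprobs, bool):
            if logprobs:
                is_chat = True
                top_logprobs = data.get("top_logprobs")
        elif isinstance(logprobs, int):
            top_logprobs = logprobs
        return top_logprobs, ("top_logprobs" if is_chat else "logprobs")

    def validate_top_logprobs(data: dict, max_logprobs: int = 20):
        value, name = extract_top_logprobs(data)
        if value is None:
            return
        if not isinstance(value, int):
            raise ValueError(f"Invalid type for {name}: expected int but got {type(value).__name__}.")
        if value > max_logprobs:
            raise ValueError(
                f"Number of {name} requested ({value}) exceeds maximum allowed value ({max_logprobs})."
            )

    # Chat-style request -> the error names "top_logprobs":
    #   validate_top_logprobs({"logprobs": True, "top_logprobs": 25})
    # Completion-style request -> the error names "logprobs":
    #   validate_top_logprobs({"logprobs": 25})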

@@ -351,6 +351,10 @@ class LLM:
             if current_sampling_params.logprobs is not None:
                 num_logprobs = current_sampling_params.logprobs
+                if not self.llm_engine.cfg.model_config.enable_logprob:
+                    raise ValueError(
+                        "logprobs is only supported if `enable_logprob` is set to true in startup config."
+                    )
                 if num_logprobs == -1 and ori_vocab_size > max_logprobs:
                     raise ValueError(
                         f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
@@ -360,6 +364,10 @@ class LLM:
                         f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                     )
             if current_sampling_params.prompt_logprobs is not None:
+                if not self.llm_engine.cfg.model_config.enable_logprob:
+                    raise ValueError(
+                        "prompt_logprobs is only supported if `enable_logprob` is set to true in startup config."
+                    )
                 if self.llm_engine.cfg.cache_config.enable_prefix_caching:
                     raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
             if kwargs.get("stream"):
@@ -403,19 +411,18 @@ class LLM:
             llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
             return None
-        # exclude sampled token at index 0
-        available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+        available_topk = len(logprobs_lists.logprob_token_ids[0])
         effective_topk_logprobs = min(topk_logprobs, available_topk)
-        if effective_topk_logprobs <= 0:
+        if effective_topk_logprobs < 0:
             llm_logger.warning(
                 f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
                 f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
             )
             return None

-        # sliced 1 ~ (1 + effective_topk_logprobs)
-        sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+        # sliced 0 ~ effective_topk_logprobs+1
+        sliced_logprobs_lists = logprobs_lists.slice_columns(0, effective_topk_logprobs + 1)

         result = []
         for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
@@ -559,7 +566,7 @@ class LLM:
             result = self.llm_engine.data_processor.process_response(result)
             # filter logprobs
-            if result.outputs.top_logprobs and topk_logprobs:
+            if result.outputs.top_logprobs is not None and topk_logprobs is not None:
                 if topk_logprobs == -1:
                     topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                 result.outputs.logprobs = self._build_sample_logprobs(

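The _build_sample_logprobs hunks above stop dropping column 0 of LogprobsLists: the sampled token is now kept, so it appears in the returned dict with rank 1 and the top-k alternatives follow at ranks 2, 3, and so on. Below is a self-contained sketch of that slicing and ranking (assumed list-of-lists layout and a simplified Logprob; not the repository's code).

    # Sketch only: assumes each row is [sampled, top1, top2, ...] and that a
    # matching row of log-probabilities accompanies the token ids.
    from dataclasses import dataclass

    @dataclass
    class Logprob:
        logprob: float
        rank: int
        decoded_token: str

    def build_sample_logprobs(logprob_token_ids, logprobs, topk_logprobs):
        available_topk = len(logprob_token_ids[0])
        effective_topk = min(topk_logprobs, available_topk)
        result = []
        for ids, lps in zip(logprob_token_ids, logprobs):
            # keep columns 0 .. effective_topk (sampled token included, rank starts at 1)
            row = {
                tid: Logprob(logprob=lp, rank=i + 1, decoded_token=f"token_{tid}")
                for i, (tid, lp) in enumerate(zip(ids[: effective_topk + 1], lps[: effective_topk + 1]))
            }
            result.append(row)
        return result

    # Mirrors the updated expectation in the unit-test diff further down:
    # build_sample_logprobs([[100, 101, 102]], [[-0.1, -0.5, -1.0]], topk_logprobs=2)
    # -> [{100: rank 1, 101: rank 2, 102: rank 3}]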

@@ -613,7 +613,7 @@ class ChatCompletionRequest(BaseModel):
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = 0
+    top_logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False

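The protocol change above moves the top_logprobs default from 0 to None, so an omitted field is distinguishable from an explicit request for 0 alternatives, which is what the `is not None` checks in the validation hunks rely on. A minimal pydantic sketch (simplified request model, not the full ChatCompletionRequest):

    from typing import Optional
    from pydantic import BaseModel

    class MiniChatRequest(BaseModel):
        logprobs: Optional[bool] = False
        top_logprobs: Optional[int] = None   # previously defaulted to 0

    omitted = MiniChatRequest(logprobs=True)
    assert omitted.top_logprobs is None      # field not sent: top-k validation is skipped
    explicit = MiniChatRequest(logprobs=True, top_logprobs=0)
    assert explicit.top_logprobs == 0        # explicit 0 is preserved and still validated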

@@ -20,7 +20,8 @@ from typing import NamedTuple, Optional
 import paddle

-class Logprob(NamedTuple):
+@dataclass
+class Logprob:
     """
     A named tuple containing information about a token's log probability.
     """

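Logprob above switches from NamedTuple to a dataclass because NamedTuple fields cannot be reassigned, which is exactly what the clamp_prompt_logprobs tests at the end of this commit exercise. A standalone comparison (illustrative types, not the repository's definitions):

    from dataclasses import dataclass
    from typing import NamedTuple

    class TupleLogprob(NamedTuple):     # old style: fields are immutable
        logprob: float
        rank: int

    @dataclass
    class DataclassLogprob:             # new style: fields can be clamped in place
        logprob: float
        rank: int

    t = TupleLogprob(logprob=float("-inf"), rank=1)
    try:
        t.logprob = -9999.0
    except AttributeError as exc:
        print("NamedTuple:", exc)       # reassignment raises AttributeError

    d = DataclassLogprob(logprob=float("-inf"), rank=1)
    d.logprob = -9999.0                 # allowed: in-place clamping now works
    print("dataclass:", d.logprob)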

@@ -56,8 +56,9 @@ class TestBuildSampleLogprobs(unittest.TestCase):
         expected = [
             {
-                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
-                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+                100: Logprob(logprob=-0.1, rank=1, decoded_token="token_100"),
+                101: Logprob(logprob=-0.5, rank=2, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=3, decoded_token="token_102"),
             }
         ]
@@ -79,7 +80,7 @@ class TestBuildSampleLogprobs(unittest.TestCase):
         logprobs_lists = MagicMock(spec=LogprobsLists)
         logprobs_lists.logprob_token_ids = [[100]]
         result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
-        self.assertIsNone(result)
+        self.assertEqual(result, [])

     def test_decode_token(self):
         """


@@ -41,7 +41,7 @@ class TestChatCompletionRequest(unittest.TestCase):
         req = ChatCompletionRequest(messages=[1])
         self.assertEqual(req.model, "default")
         self.assertFalse(req.logprobs)
-        self.assertEqual(req.top_logprobs, 0)
+        self.assertIsNone(req.top_logprobs)
         self.assertEqual(req.n, 1)
         self.assertEqual(req.stop, [])


@@ -489,8 +489,9 @@ class TestEngineClientValidParameters(unittest.TestCase):
         data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
         with self.assertRaises(ValueError) as context:
             self.engine_client.valid_parameters(data)
-        self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
-        self.assertIn("current value is 25", str(context.exception))
+        self.assertIn(
+            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
+        )

         # Test valid value
         data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}


@@ -10,9 +10,10 @@ from fastdeploy.worker.output import Logprob, LogprobsTensors
 class DummyModelConfig:
-    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
         self.max_logprobs = max_logprobs
         self.ori_vocab_size = ori_vocab_size
+        self.enable_logprob = enable_logprob

 class DummyCacheConfig:


@@ -45,22 +45,23 @@ class TestClampPromptLogprobs(unittest.TestCase):
         self.assertEqual(result[0][1].logprob, -2.5)
         self.assertEqual(result[0][2].logprob, -1.0)

-    def test_negative_inf_logprobs_raises_error(self):
-        """Test that logprobs containing -inf raises AttributeError"""
+    def test_negative_inf_logprobs_gets_clamped(self):
+        """Test that logprobs containing -inf get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
         }
         prompt_logprobs = [logprob_dict]
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError) as context:
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)

-        self.assertIn("can't set attribute", str(context.exception))
+        # The -inf value should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -1.0)  # unchanged

-    def test_multiple_negative_inf_raises_error(self):
-        """Test that multiple -inf logprobs values raise AttributeError"""
+    def test_multiple_negative_inf_gets_clamped(self):
+        """Test that multiple -inf logprobs values get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
@@ -68,9 +69,13 @@ class TestClampPromptLogprobs(unittest.TestCase):
         }
         prompt_logprobs = [logprob_dict]
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError):
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
+        # All -inf values should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -9999.0)
+        self.assertEqual(result[0][3].logprob, -0.5)  # unchanged

     def test_none_dict_in_list(self):
         """Test case when list contains None"""
@@ -116,7 +121,7 @@ class TestClampPromptLogprobs(unittest.TestCase):
         self.assertEqual(result[0][4].logprob, -1.5)

     def test_return_same_object(self):
-        """Test that function returns the same object (in-place modification attempt)"""
+        """Test that function returns the same object (in-place modification)"""
         logprob_dict = {
             1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
         }
@@ -124,7 +129,7 @@ class TestClampPromptLogprobs(unittest.TestCase):
         result = clamp_prompt_logprobs(prompt_logprobs)

-        # Should return the same object (function attempts in-place modification)
+        # Should return the same object (function performs in-place modification)
         self.assertIs(result, prompt_logprobs)
         self.assertIs(result[0], prompt_logprobs[0])
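
The reworked tests above pin down the behavior the dataclass change enables: -inf entries are clamped to -9999.0 in place and the same list object is returned. A minimal sketch of a helper with that contract (assumed signature; not necessarily the repository's clamp_prompt_logprobs):

    # Sketch under the assumptions stated above: mutate Logprob entries in place
    # and hand back the same list, skipping None placeholders.
    def clamp_prompt_logprobs(prompt_logprobs):
        for logprob_dict in prompt_logprobs:
            if logprob_dict is None:
                continue
            for entry in logprob_dict.values():
                if entry.logprob == float("-inf"):
                    entry.logprob = -9999.0
        return prompt_logprobs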