mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
fix logprobs (#5335)
@@ -418,13 +418,16 @@ class EngineClient:
         # logprobs
         logprobs = data.get("logprobs")
         top_logprobs = None
+        is_chat = False

-        if isinstance(logprobs, bool) and logprobs:
-            if not self.enable_logprob:
-                err_msg = "Logprobs is disabled, please enable it in startup config."
-                api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
-            top_logprobs = data.get("top_logprobs")
+        if isinstance(logprobs, bool):
+            if logprobs:
+                is_chat = True
+                if not self.enable_logprob:
+                    err_msg = "Logprobs is disabled, please enable it in startup config."
+                    api_server_logger.error(err_msg)
+                    raise ParameterError("logprobs", err_msg)
+                top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
@@ -478,38 +481,40 @@ class EngineClient:
                 raise ValueError("prompt_logprobs", err_msg)

         # enable_logprob
-        if top_logprobs:
+        if top_logprobs is not None:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)

             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
-                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                err_msg = (
+                    f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
+                )
                 api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)

+            if top_logprobs > max_logprobs:
+                err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                api_server_logger.error(err_msg)
+                raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
+
             if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if top_logprobs < 0 or top_logprobs > 20:
-                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                if top_logprobs < 0 or top_logprobs > max_logprobs:
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             else:
                 if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
-                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

                 if top_logprobs < -1:
-                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
-
-                if top_logprobs > max_logprobs:
-                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
-                    api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

     def check_health(self, time_interval_threashold=30):
         """
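Note: the two EngineClient hunks above split logprobs validation into a chat path (boolean logprobs plus a separate top_logprobs field) and a completion path (integer logprobs), and the new is_chat flag makes every error message name the field the caller actually sent. The top_logprobs > max_logprobs check is also hoisted ahead of the FD_USE_GET_SAVE_OUTPUT_V1 branch, which is what the updated unit test later in this diff asserts. Below is a minimal, self-contained sketch of that flow; ParameterError and resolve_top_logprobs here are illustrative stand-ins, not the real FastDeploy classes or methods.

from typing import Optional


class ParameterError(ValueError):
    """Illustrative stand-in for FastDeploy's ParameterError."""


def resolve_top_logprobs(data: dict, enable_logprob: bool, max_logprobs: int) -> Optional[int]:
    # Mirrors the reworked validation above: bool logprobs => chat request,
    # int logprobs => completion request; is_chat picks the field name for errors.
    logprobs = data.get("logprobs")
    top_logprobs = None
    is_chat = False

    if isinstance(logprobs, bool):
        if logprobs:
            is_chat = True
            if not enable_logprob:
                raise ParameterError("logprobs", "Logprobs is disabled, please enable it in startup config.")
            top_logprobs = data.get("top_logprobs")
    elif isinstance(logprobs, int):
        top_logprobs = logprobs

    if top_logprobs is not None:
        name = "top_logprobs" if is_chat else "logprobs"
        if not isinstance(top_logprobs, int):
            raise ParameterError(name, f"Invalid type for {name}: expected int but got {type(top_logprobs).__name__}.")
        if top_logprobs > max_logprobs:
            raise ValueError(f"Number of {name} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs}).")
    return top_logprobs


print(resolve_top_logprobs({"logprobs": True, "top_logprobs": 5}, True, 20))  # chat request -> 5
print(resolve_top_logprobs({"logprobs": 5}, True, 20))                        # completion request -> 5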
@@ -351,6 +351,10 @@ class LLM:

         if current_sampling_params.logprobs is not None:
             num_logprobs = current_sampling_params.logprobs
+            if not self.llm_engine.cfg.model_config.enable_logprob:
+                raise ValueError(
+                    "logprobs is only supported if `enable_logprob` is set to true in startup config."
+                )
             if num_logprobs == -1 and ori_vocab_size > max_logprobs:
                 raise ValueError(
                     f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
@@ -360,6 +364,10 @@ class LLM:
                     f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                 )
         if current_sampling_params.prompt_logprobs is not None:
+            if not self.llm_engine.cfg.model_config.enable_logprob:
+                raise ValueError(
+                    "prompt_logprobs is only supported if `enable_logprob` is set to true in startup config."
+                )
             if self.llm_engine.cfg.cache_config.enable_prefix_caching:
                 raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
         if kwargs.get("stream"):
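Note: the two LLM hunks above add the same fail-fast behaviour to the offline entry point: requesting logprobs or prompt_logprobs on an engine started without enable_logprob now raises immediately, and prompt_logprobs is rejected when prefix caching is enabled. A hedged sketch of how a caller would hit the first check follows; the import path, model path, and constructor arguments are assumptions based on FastDeploy's offline API and are not taken from this diff.

from fastdeploy import LLM, SamplingParams

# Engine started without logprob support (illustrative arguments).
llm = LLM(model="./path/to/model", enable_logprob=False)

try:
    outputs = llm.generate(
        prompts=["Hello"],
        sampling_params=SamplingParams(max_tokens=16, logprobs=5),
    )
except ValueError as err:
    # With this change the request fails up front instead of deep inside the engine:
    # "logprobs is only supported if `enable_logprob` is set to true in startup config."
    print(err)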
@@ -403,19 +411,18 @@ class LLM:
             llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
             return None

-        # exclude sampled token at index 0
-        available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+        available_topk = len(logprobs_lists.logprob_token_ids[0])
         effective_topk_logprobs = min(topk_logprobs, available_topk)

-        if effective_topk_logprobs <= 0:
+        if effective_topk_logprobs < 0:
             llm_logger.warning(
                 f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
                 f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
             )
             return None

-        # sliced 1 ~ (1 + effective_topk_logprobs)
-        sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+        # sliced 0 ~ effective_topk_logprobs+1
+        sliced_logprobs_lists = logprobs_lists.slice_columns(0, effective_topk_logprobs + 1)
         result = []
         for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):

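Note: _build_sample_logprobs no longer drops column 0 of the logprob matrix. The old slice_columns(1, 1 + k) excluded the sampled token; the new slice_columns(0, k + 1) keeps it, so the sampled token is reported as rank 1 ahead of the top-k candidates (this is why the expected dict in the test further down gains token 100 and shifts ranks to 1..3). A pure-Python sketch of the column slicing; slice_columns here is a hypothetical stand-in for the LogprobsLists method.

def slice_columns(rows, start, end):
    # Hypothetical stand-in for LogprobsLists.slice_columns: keep columns [start, end).
    return [row[start:end] for row in rows]


# Per position: sampled token in column 0, then the top-k candidate tokens.
logprob_token_ids = [[100, 101, 102]]
topk_logprobs = 2

# Old behaviour: exclude the sampled token, return candidates only.
old_available = len(logprob_token_ids[0]) - 1
old = slice_columns(logprob_token_ids, 1, 1 + min(topk_logprobs, old_available))
# New behaviour: keep column 0, so the sampled token leads the result.
new_available = len(logprob_token_ids[0])
new = slice_columns(logprob_token_ids, 0, min(topk_logprobs, new_available) + 1)

print(old)  # [[101, 102]]
print(new)  # [[100, 101, 102]]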
@@ -559,7 +566,7 @@ class LLM:
         result = self.llm_engine.data_processor.process_response(result)

         # filter logprobs
-        if result.outputs.top_logprobs and topk_logprobs:
+        if result.outputs.top_logprobs is not None and topk_logprobs is not None:
             if topk_logprobs == -1:
                 topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
             result.outputs.logprobs = self._build_sample_logprobs(
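Note: the guard above moves from truthiness to explicit is not None checks because falsy-but-valid values were silently skipping the filtering branch: an empty top_logprobs container or topk_logprobs == 0 both evaluate to False under the old `and` test. A minimal illustration of the pitfall:

top_logprobs = []   # present in the result, just empty
topk_logprobs = 0   # caller explicitly asked for zero extra candidates

if top_logprobs and topk_logprobs:
    print("old guard runs")        # never reached: both values are falsy

if top_logprobs is not None and topk_logprobs is not None:
    print("new guard runs")        # reached: both values were actually provided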
@@ -613,7 +613,7 @@ class ChatCompletionRequest(BaseModel):
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = 0
+    top_logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False

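Note: switching the ChatCompletionRequest default for top_logprobs from 0 to None lets the server distinguish "field omitted" from "explicitly zero", which the is not None checks above depend on. A tiny sketch with a minimal Pydantic model as a stand-in for the real request class:

from typing import Optional

from pydantic import BaseModel


class Req(BaseModel):
    # Minimal stand-in for ChatCompletionRequest.
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None  # None = not requested, 0 = explicitly zero


print(Req(logprobs=True).top_logprobs)                   # None -> no top_logprobs validation
print(Req(logprobs=True, top_logprobs=0).top_logprobs)   # 0 -> validated as an explicit value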
@@ -20,7 +20,8 @@ from typing import NamedTuple, Optional
 import paddle


-class Logprob(NamedTuple):
+@dataclass
+class Logprob:
     """
     A named tuple containing information about a token's log probability.
     """
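Note: Logprob changes from a NamedTuple to a dataclass because clamp_prompt_logprobs rewrites the logprob field in place; NamedTuple instances are immutable, so the old type raised AttributeError, which is exactly what the removed tests below used to assert. A minimal comparison; both class names are local to this sketch.

from dataclasses import dataclass
from typing import NamedTuple


class FrozenLogprob(NamedTuple):  # old shape: fields cannot be reassigned
    logprob: float
    rank: int
    decoded_token: str


@dataclass
class MutableLogprob:             # new shape: fields can be reassigned in place
    logprob: float
    rank: int
    decoded_token: str


frozen = FrozenLogprob(float("-inf"), 1, "hello")
try:
    frozen.logprob = -9999.0
except AttributeError as err:
    print(err)                    # NamedTuple fields have no setter

mutable = MutableLogprob(float("-inf"), 1, "hello")
mutable.logprob = -9999.0         # clamping in place now works
print(mutable.logprob)            # -9999.0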
@@ -56,8 +56,9 @@ class TestBuildSampleLogprobs(unittest.TestCase):

         expected = [
             {
-                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
-                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+                100: Logprob(logprob=-0.1, rank=1, decoded_token="token_100"),
+                101: Logprob(logprob=-0.5, rank=2, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=3, decoded_token="token_102"),
             }
         ]

@@ -79,7 +80,7 @@ class TestBuildSampleLogprobs(unittest.TestCase):
         logprobs_lists = MagicMock(spec=LogprobsLists)
         logprobs_lists.logprob_token_ids = [[100]]
         result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
-        self.assertIsNone(result)
+        self.assertEqual(result, [])

     def test_decode_token(self):
         """
@@ -41,7 +41,7 @@ class TestChatCompletionRequest(unittest.TestCase):
         req = ChatCompletionRequest(messages=[1])
         self.assertEqual(req.model, "default")
         self.assertFalse(req.logprobs)
-        self.assertEqual(req.top_logprobs, 0)
+        self.assertIsNone(req.top_logprobs)
         self.assertEqual(req.n, 1)
         self.assertEqual(req.stop, [])

@@ -489,8 +489,9 @@ class TestEngineClientValidParameters(unittest.TestCase):
         data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
         with self.assertRaises(ValueError) as context:
             self.engine_client.valid_parameters(data)
-        self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
-        self.assertIn("current value is 25", str(context.exception))
+        self.assertIn(
+            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
+        )

         # Test valid value
         data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}
@@ -10,9 +10,10 @@ from fastdeploy.worker.output import Logprob, LogprobsTensors


 class DummyModelConfig:
-    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
         self.max_logprobs = max_logprobs
         self.ori_vocab_size = ori_vocab_size
+        self.enable_logprob = enable_logprob


 class DummyCacheConfig:
@@ -45,22 +45,23 @@ class TestClampPromptLogprobs(unittest.TestCase):
         self.assertEqual(result[0][1].logprob, -2.5)
         self.assertEqual(result[0][2].logprob, -1.0)

-    def test_negative_inf_logprobs_raises_error(self):
-        """Test that logprobs containing -inf raises AttributeError"""
+    def test_negative_inf_logprobs_gets_clamped(self):
+        """Test that logprobs containing -inf get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
         }
         prompt_logprobs = [logprob_dict]

-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError) as context:
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)

-        self.assertIn("can't set attribute", str(context.exception))
+        # The -inf value should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -1.0)  # unchanged

-    def test_multiple_negative_inf_raises_error(self):
-        """Test that multiple -inf logprobs values raise AttributeError"""
+    def test_multiple_negative_inf_gets_clamped(self):
+        """Test that multiple -inf logprobs values get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
@@ -68,9 +69,13 @@ class TestClampPromptLogprobs(unittest.TestCase):
         }
         prompt_logprobs = [logprob_dict]

-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError):
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
+
+        # All -inf values should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -9999.0)
+        self.assertEqual(result[0][3].logprob, -0.5)  # unchanged

     def test_none_dict_in_list(self):
         """Test case when list contains None"""
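Note: the rewritten tests above pin down the clamping contract: -inf entries become -9999.0, finite values pass through, and the same list object is returned after in-place modification. The following is an illustrative re-implementation of what clamp_prompt_logprobs is expected to do under these tests, not the FastDeploy source.

import math
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    # Mirrors the dataclass form of Logprob introduced earlier in this diff.
    logprob: float
    rank: int
    decoded_token: str


def clamp_prompt_logprobs(prompt_logprobs: List[Optional[Dict[int, Logprob]]], floor: float = -9999.0):
    # Illustrative re-implementation: clamp -inf logprobs in place, return the same object.
    for token_dict in prompt_logprobs:
        if token_dict is None:
            continue
        for entry in token_dict.values():
            if math.isinf(entry.logprob) and entry.logprob < 0:
                entry.logprob = floor
    return prompt_logprobs


probs = [{1: Logprob(float("-inf"), 1, "hello"), 2: Logprob(-1.0, 2, "world")}]
result = clamp_prompt_logprobs(probs)
print(result is probs)           # True: same object, modified in place
print(result[0][1].logprob)      # -9999.0
print(result[0][2].logprob)      # -1.0, unchanged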
@@ -116,7 +121,7 @@ class TestClampPromptLogprobs(unittest.TestCase):
         self.assertEqual(result[0][4].logprob, -1.5)

     def test_return_same_object(self):
-        """Test that function returns the same object (in-place modification attempt)"""
+        """Test that function returns the same object (in-place modification)"""
         logprob_dict = {
             1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
         }
@@ -124,7 +129,7 @@ class TestClampPromptLogprobs(unittest.TestCase):

         result = clamp_prompt_logprobs(prompt_logprobs)

-        # Should return the same object (function attempts in-place modification)
+        # Should return the same object (function performs in-place modification)
         self.assertIs(result, prompt_logprobs)
         self.assertIs(result[0], prompt_logprobs[0])
