mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* add prompt logprobs * trigger ci * fix unitest * Update fastdeploy/config.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/entrypoints/llm.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/engine/sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix max_logprobs --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
223 lines
8.8 KiB
Python
223 lines
8.8 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import os
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from fastdeploy.engine.sampling_params import SamplingParams
|
|
|
|
|
|
class TestSamplingParamsVerification(unittest.TestCase):
    """Tests for ``SamplingParams._verify_args`` and its invocation from ``__post_init__``.

    ``SamplingParams.__post_init__`` runs ``_verify_args`` (see
    ``test_post_init_calls_verify_args``), so invalid keyword arguments already raise
    at construction time. Inside ``assertRaises`` blocks the constructor call is
    therefore the statement expected to raise; ``_verify_args`` on its own is exercised
    by ``test_verify_args_rechecks_mutated_fields``.
    """

    # Environment variable that, when set to "1", lifts the 20-entry cap on ``logprobs``.
    _V1_ENV = "FD_USE_GET_SAVE_OUTPUT_V1"

    def _assert_valid(self, **kwargs):
        """Build ``SamplingParams(**kwargs)`` and re-run ``_verify_args``; both must succeed.

        Returns:
            The constructed ``SamplingParams`` instance.
        """
        params = SamplingParams(**kwargs)
        params._verify_args()  # re-verifying an accepted instance must not raise
        return params

    def _assert_invalid(self, **kwargs):
        """Assert that ``SamplingParams(**kwargs)`` raises ``ValueError``.

        Returns:
            The raised exception, so callers can inspect its message.
        """
        with self.assertRaises(ValueError) as cm:
            SamplingParams(**kwargs)
        return cm.exception

    def test_logprobs_valid_values(self):
        """logprobs of None, -1 and 0 are accepted; 20 is accepted with V1 disabled."""
        for value in (None, -1, 0):
            with self.subTest(logprobs=value):
                self._assert_valid(logprobs=value)

        with patch.dict(os.environ, {self._V1_ENV: "0"}):
            self._assert_valid(logprobs=20)

    def test_logprobs_invalid_less_than_minus_one(self):
        """logprobs below -1 raises ValueError naming the offending value."""
        error_msg = str(self._assert_invalid(logprobs=-2))
        self.assertIn("logprobs must be greater than -1", error_msg)
        self.assertIn("got -2", error_msg)

    def test_logprobs_greater_than_20_with_v1_disabled(self):
        """logprobs above 20 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0"."""
        with patch.dict(os.environ, {self._V1_ENV: "0"}):
            exc = self._assert_invalid(logprobs=21)

        self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(exc))

    def test_logprobs_greater_than_20_with_v1_enabled(self):
        """logprobs above 20 is accepted when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
        with patch.dict(os.environ, {self._V1_ENV: "1"}):
            for value in (21, 100):
                with self.subTest(logprobs=value):
                    self._assert_valid(logprobs=value)

    def test_prompt_logprobs_valid_values(self):
        """prompt_logprobs of None, -1, 0 and a positive value are accepted."""
        for value in (None, -1, 0, 10):
            with self.subTest(prompt_logprobs=value):
                self._assert_valid(prompt_logprobs=value)

    def test_prompt_logprobs_invalid_less_than_minus_one(self):
        """prompt_logprobs below -1 raises ValueError naming the offending value."""
        error_msg = str(self._assert_invalid(prompt_logprobs=-2))
        self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
        self.assertIn("got -2", error_msg)

    def test_combined_logprobs_and_prompt_logprobs(self):
        """A bad value in either field is rejected even if the other field is valid."""
        self._assert_valid(logprobs=5, prompt_logprobs=3)

        invalid_combinations = (
            {"logprobs": -2, "prompt_logprobs": 5},
            {"logprobs": 5, "prompt_logprobs": -2},
        )
        for kwargs in invalid_combinations:
            with self.subTest(**kwargs):
                self._assert_invalid(**kwargs)

    def test_logprobs_boundary_values(self):
        """With V1 disabled, 20 is the largest accepted logprobs value and 21 is rejected."""
        with patch.dict(os.environ, {self._V1_ENV: "0"}):
            self._assert_valid(logprobs=20)
            self._assert_invalid(logprobs=21)

    def test_prompt_logprobs_boundary_values(self):
        """-1 is the smallest accepted prompt_logprobs value and -2 is rejected."""
        self._assert_valid(prompt_logprobs=-1)
        self._assert_invalid(prompt_logprobs=-2)

    def test_environment_variable_handling(self):
        """The logprobs cap follows FD_USE_GET_SAVE_OUTPUT_V1 ("0", "1", and unset)."""
        with patch.dict(os.environ, {self._V1_ENV: "0"}):
            self._assert_invalid(logprobs=21)

        with patch.dict(os.environ, {self._V1_ENV: "1"}):
            self._assert_valid(logprobs=21)

        # Unset: patch.dict snapshots os.environ and restores it on exit, so removing the
        # variable here cannot leak into other tests even when an assertion fails.
        with patch.dict(os.environ):
            os.environ.pop(self._V1_ENV, None)
            self._assert_invalid(logprobs=21)

    def test_error_message_formatting(self):
        """Error messages include the rule and the rejected value."""
        logprobs_msg = str(self._assert_invalid(logprobs=-5))
        self.assertIn("logprobs must be greater than -1", logprobs_msg)
        self.assertIn("got -5", logprobs_msg)

        prompt_msg = str(self._assert_invalid(prompt_logprobs=-10))
        self.assertIn("prompt_logprobs must be greater than or equal to -1", prompt_msg)
        self.assertIn("got -10", prompt_msg)

    def test_post_init_calls_verify_args(self):
        """Construction validates arguments, so invalid values fail in the constructor."""
        params = SamplingParams(logprobs=5, prompt_logprobs=3)
        self.assertEqual(params.logprobs, 5)
        self.assertEqual(params.prompt_logprobs, 3)

        with self.assertRaises(ValueError):
            SamplingParams(logprobs=-2)

        with self.assertRaises(ValueError):
            SamplingParams(prompt_logprobs=-2)

    def test_verify_args_rechecks_mutated_fields(self):
        """_verify_args itself rejects invalid values assigned after construction."""
        for field_name in ("logprobs", "prompt_logprobs"):
            with self.subTest(field=field_name):
                params = SamplingParams()
                setattr(params, field_name, -2)
                with self.assertRaises(ValueError):
                    params._verify_args()

    def test_logprobs_with_other_parameters(self):
        """logprobs validation coexists with other sampling parameters."""
        self._assert_valid(logprobs=5, temperature=0.8)
        self._assert_valid(logprobs=5, top_p=0.9)
        self._assert_valid(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite when this file is executed directly; verbosity=1 is unittest's default.
    unittest.main(verbosity=1)
|