"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import unittest
from unittest.mock import patch

from fastdeploy.engine.sampling_params import SamplingParams
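
# Validation matrix this suite encodes (inferred from the assertions below,
# not from the fastdeploy implementation itself):
#
#   FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default):
#       logprobs must be None or an int in [0, 20]; prompt_logprobs is
#       rejected outright.
#   FD_USE_GET_SAVE_OUTPUT_V1 = "1":
#       logprobs and prompt_logprobs may each be None, -1, or any
#       non-negative int (values above 20 are accepted).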


class TestSamplingParamsVerification(unittest.TestCase):
    """Test case for SamplingParams _verify_args method"""

    def test_logprobs_valid_values(self):
        """Test valid logprobs values"""
        # Test None value (should pass in both modes)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=None)
            params._verify_args()  # Should not raise

        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=None)
            params._verify_args()  # Should not raise

        # Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=-1)
            params._verify_args()  # Should not raise

        # Test 0 value (should pass in both modes based on actual behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=0)
            params._verify_args()  # Should not raise

        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=0)
            params._verify_args()  # Should not raise

        # Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=20)
            params._verify_args()  # Should not raise

    def test_logprobs_invalid_less_than_minus_one(self):
        """Test that logprobs less than -1 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-2)
                params._verify_args()

            self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
            self.assertIn("got -2", str(cm.exception))

    def test_logprobs_invalid_less_than_zero(self):
        """Test that logprobs less than 0 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0" """
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-1)
                params._verify_args()

            self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_disabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=21)
                params._verify_args()

            self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_enabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            # Should not raise when v1 is enabled
            params = SamplingParams(logprobs=21)
            params._verify_args()  # Should not raise

            # Test even larger values when v1 is enabled
            params = SamplingParams(logprobs=100)
            params._verify_args()  # Should not raise

    def test_prompt_logprobs_valid_values(self):
        """Test valid prompt_logprobs values"""
        # Test None value (should pass in both modes based on actual behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(prompt_logprobs=None)
            params._verify_args()  # Should not raise

        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=None)
            params._verify_args()  # Should not raise

        # Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=-1)
            params._verify_args()  # Should not raise

        # Test 0 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=0)
            params._verify_args()  # Should not raise

        # Test positive values (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=10)
            params._verify_args()  # Should not raise

    def test_prompt_logprobs_invalid_less_than_minus_one(self):
        """Test that prompt_logprobs less than -1 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=-2)
                params._verify_args()
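
            # Note: the asserted text below appears verbatim (grammar included) in
            # the error the implementation raises, so keep the two in sync rather
            # than fixing the wording here.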
self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
|
|
self.assertIn("got -2", str(cm.exception))
|
|
|
|

    def test_combined_logprobs_and_prompt_logprobs(self):
        """Test both logprobs and prompt_logprobs together"""
        # Test valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=5, prompt_logprobs=3)
            params._verify_args()  # Should not raise

        # Test invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=-2, prompt_logprobs=5)
                params._verify_args()

        # Test valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=5, prompt_logprobs=-2)
                params._verify_args()

        # Test prompt_logprobs not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=5, prompt_logprobs=3)
                params._verify_args()
            self.assertIn(
                "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
            )

    def test_logprobs_boundary_values(self):
        """Test boundary values for logprobs"""
        # Test at the limit (20) with v1 disabled
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=20)
            params._verify_args()  # Should pass

        # Test just above the limit with v1 disabled
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

    def test_prompt_logprobs_boundary_values(self):
        """Test boundary values for prompt_logprobs"""
        # Test boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=-1)
            params._verify_args()  # Should pass

        # Test boundary value just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(prompt_logprobs=-2)
                params._verify_args()

    def test_environment_variable_handling(self):
        """Test different environment variable values"""
        # Test FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

        # Test FD_USE_GET_SAVE_OUTPUT_V1 = "1" (relaxed behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=21)
            params._verify_args()  # Should pass

        # Test FD_USE_GET_SAVE_OUTPUT_V1 not set (defaults to "0")
        if "FD_USE_GET_SAVE_OUTPUT_V1" in os.environ:
            original_value = os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
            del os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
        else:
            original_value = None
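
        # Manually save/delete/restore the variable (rather than using patch.dict)
        # so the truly-unset case is exercised; the finally block below keeps the
        # ambient environment intact for other tests.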
        try:
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()
        finally:
            if original_value is not None:
                os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value

        # Test prompt_logprobs behavior with different environment variables
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=5)
                params._verify_args()
            self.assertIn(
                "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
            )

        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(prompt_logprobs=5)
            params._verify_args()  # Should pass

    def test_error_message_formatting(self):
        """Test that error messages are properly formatted"""
        # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-5)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("logprobs must be a non-negative value or -1", error_msg)
            self.assertIn("got -5", error_msg)

        # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-1)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)

        # Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=-10)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
            self.assertIn("got -10", error_msg)

        # Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=5)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)

    def test_post_init_calls_verify_args(self):
        """Test that __post_init__ calls _verify_args"""
        # This should call _verify_args internally when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=5, prompt_logprobs=3)

            # The params should be successfully created without errors
            self.assertEqual(params.logprobs, 5)
            self.assertEqual(params.prompt_logprobs, 3)

        # Test that invalid values are caught during initialization
        with self.assertRaises(ValueError):
            SamplingParams(logprobs=-2)

        with self.assertRaises(ValueError):
            SamplingParams(prompt_logprobs=-2)

        # Test that prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                SamplingParams(prompt_logprobs=3)

            # Test that logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
            with self.assertRaises(ValueError):
                SamplingParams(logprobs=-1)

    def test_logprobs_with_other_parameters(self):
        """Test logprobs validation with other sampling parameters"""
        # Test with temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=5, temperature=0.8)
            params._verify_args()  # Should pass

        # Test with top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=5, top_p=0.9)
            params._verify_args()  # Should pass

        # Test with all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(
                logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
            )
            params._verify_args()  # Should pass

        # Test that prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(
                    logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
                )
                params._verify_args()
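

# A minimal reference sketch of the rules the suite above encodes. This is an
# inference from the assertions alone, not the actual fastdeploy implementation;
# the message strings are copied from the expectations in the tests.
def _expected_verify(logprobs=None, prompt_logprobs=None, v1_enabled=False):
    """Raise ValueError under exactly the conditions the tests above expect."""
    if logprobs is not None:
        if v1_enabled:
            # V1 path: -1 and any non-negative value are accepted.
            if logprobs < -1:
                raise ValueError(f"logprobs must be a non-negative value or -1, got {logprobs}.")
        elif not 0 <= logprobs <= 20:
            # Legacy path: OpenAI-style top_logprobs bound.
            raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
    if prompt_logprobs is not None:
        if not v1_enabled:
            raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled")
        if prompt_logprobs < -1:
            raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {prompt_logprobs}.")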


if __name__ == "__main__":
    unittest.main()