Files
FastDeploy/tests/engine/test_sampling_params.py
qwes5s5 117980dd4e [LogProbs]Enable prompt logprobs output and modify data transmission method for the online interface. (#5089)
* add prompt logprobs

* Merge prompt_logprobs_tensors and prompt_logprobs

* fix param check

* trigger ci

* fix unitest

* fix logprobs bug
2025-12-02 13:49:51 +08:00

323 lines
15 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import unittest
from unittest.mock import patch
from fastdeploy.engine.sampling_params import SamplingParams
class TestSamplingParamsVerification(unittest.TestCase):
    """Tests for the logprobs / prompt_logprobs checks of SamplingParams.

    Every case pins the FD_USE_GET_SAVE_OUTPUT_V1 environment variable through
    ``patch.dict`` so the process environment is restored when the case ends.
    The expectations encoded below are:

    * flag "0": logprobs must lie in [0, 20]; prompt_logprobs is rejected.
    * flag "1": logprobs and prompt_logprobs must be >= -1; values above 20
      are accepted for logprobs.
    """

    # Name of the environment variable that switches the validation mode.
    _ENV_FLAG = "FD_USE_GET_SAVE_OUTPUT_V1"

    @classmethod
    def _v1_env(cls, value):
        """Return a context manager that sets the V1 flag to *value* ("0" or "1")."""
        return patch.dict(os.environ, {cls._ENV_FLAG: value})

    @staticmethod
    def _build_and_verify(**kwargs):
        """Construct SamplingParams(**kwargs), call _verify_args() and return it."""
        params = SamplingParams(**kwargs)
        params._verify_args()
        return params

    def _assert_rejected(self, **kwargs):
        """Assert that building/verifying raises ValueError; return its message."""
        with self.assertRaises(ValueError) as cm:
            self._build_and_verify(**kwargs)
        return str(cm.exception)

    def test_logprobs_valid_values(self):
        """Test valid logprobs values"""
        accepted = [
            ("0", None),  # None passes in both modes
            ("1", None),
            ("1", -1),  # -1 passes when the V1 flag is "1"
            ("0", 0),  # 0 passes in both modes
            ("1", 0),
            ("0", 20),  # 20 passes when the V1 flag is "0"
        ]
        for flag, value in accepted:
            with self.subTest(flag=flag, logprobs=value), self._v1_env(flag):
                self._build_and_verify(logprobs=value)  # Should not raise

    def test_logprobs_invalid_less_than_minus_one(self):
        """Test logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
        with self._v1_env("1"):
            message = self._assert_rejected(logprobs=-2)

        self.assertIn("logprobs must be a non-negative value or -1", message)
        self.assertIn("got -2", message)

    def test_logprobs_invalid_less_than_zero(self):
        """Test logprobs less than 0 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0" """
        with self._v1_env("0"):
            message = self._assert_rejected(logprobs=-1)

        self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", message)

    def test_logprobs_greater_than_20_with_v1_disabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
        with self._v1_env("0"):
            message = self._assert_rejected(logprobs=21)

        # Exact match: the whole message, including the trailing period.
        self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", message)

    def test_logprobs_greater_than_20_with_v1_enabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
        with self._v1_env("1"):
            # No upper bound is enforced when v1 is enabled
            self._build_and_verify(logprobs=21)  # Should not raise

            # Test even larger values when v1 is enabled
            self._build_and_verify(logprobs=100)  # Should not raise

    def test_prompt_logprobs_valid_values(self):
        """Test valid prompt_logprobs values"""
        accepted = [
            ("0", None),  # None passes in both modes
            ("1", None),
            ("1", -1),  # -1, 0 and positive values pass when the V1 flag is "1"
            ("1", 0),
            ("1", 10),
        ]
        for flag, value in accepted:
            with self.subTest(flag=flag, prompt_logprobs=value), self._v1_env(flag):
                self._build_and_verify(prompt_logprobs=value)  # Should not raise

    def test_prompt_logprobs_invalid_less_than_minus_one(self):
        """Test prompt_logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
        with self._v1_env("1"):
            message = self._assert_rejected(prompt_logprobs=-2)

        self.assertIn("prompt_logprobs a must be non-negative value or -1", message)
        self.assertIn("got -2", message)

    def test_combined_logprobs_and_prompt_logprobs(self):
        """Test both logprobs and prompt_logprobs together"""
        # Valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._build_and_verify(logprobs=5, prompt_logprobs=3)  # Should not raise

        # Invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._assert_rejected(logprobs=-2, prompt_logprobs=5)

        # Valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._assert_rejected(logprobs=5, prompt_logprobs=-2)

        # prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with self._v1_env("0"):
            message = self._assert_rejected(logprobs=5, prompt_logprobs=3)

        self.assertIn(
            "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", message
        )

    def test_logprobs_boundary_values(self):
        """Test boundary values for logprobs"""
        # Upper limit itself with v1 disabled
        with self._v1_env("0"):
            self._build_and_verify(logprobs=20)  # Should pass

        # Just above the limit with v1 disabled
        with self._v1_env("0"):
            self._assert_rejected(logprobs=21)

    def test_prompt_logprobs_boundary_values(self):
        """Test boundary values for prompt_logprobs"""
        # Boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with self._v1_env("1"):
            self._build_and_verify(prompt_logprobs=-1)  # Should pass

        # Just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
        with self._v1_env("1"):
            self._assert_rejected(prompt_logprobs=-2)

    def test_environment_variable_handling(self):
        """Test different environment variable values"""
        # FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default behavior)
        with self._v1_env("0"):
            self._assert_rejected(logprobs=21)

        # FD_USE_GET_SAVE_OUTPUT_V1 = "1" (relaxed behavior)
        with self._v1_env("1"):
            self._build_and_verify(logprobs=21)  # Should pass

        # FD_USE_GET_SAVE_OUTPUT_V1 not set (expected to behave like "0").
        # patch.dict with no values snapshots os.environ and restores it on
        # exit, so the variable can simply be dropped inside the block.
        with patch.dict(os.environ):
            os.environ.pop(self._ENV_FLAG, None)
            self._assert_rejected(logprobs=21)

        # prompt_logprobs behavior with different environment variables
        with self._v1_env("0"):
            message = self._assert_rejected(prompt_logprobs=5)

        self.assertIn(
            "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", message
        )

        with self._v1_env("1"):
            self._build_and_verify(prompt_logprobs=5)  # Should pass

    def test_error_message_formatting(self):
        """Test that error messages are properly formatted"""
        # logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            error_msg = self._assert_rejected(logprobs=-5)

        self.assertIn("logprobs must be a non-negative value or -1", error_msg)
        self.assertIn("got -5", error_msg)

        # logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with self._v1_env("0"):
            error_msg = self._assert_rejected(logprobs=-1)

        self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)

        # prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            error_msg = self._assert_rejected(prompt_logprobs=-10)

        self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
        self.assertIn("got -10", error_msg)

        # prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with self._v1_env("0"):
            error_msg = self._assert_rejected(prompt_logprobs=5)

        self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)

    def test_post_init_calls_verify_args(self):
        """Test that __post_init__ calls _verify_args"""
        # Construction alone (no explicit _verify_args call) must validate.
        with self._v1_env("1"):
            params = SamplingParams(logprobs=5, prompt_logprobs=3)

            # The params should be successfully created without errors
            self.assertEqual(params.logprobs, 5)
            self.assertEqual(params.prompt_logprobs, 3)

            # Invalid values are caught during initialization
            with self.assertRaises(ValueError):
                SamplingParams(logprobs=-2)

            with self.assertRaises(ValueError):
                SamplingParams(prompt_logprobs=-2)

        with self._v1_env("0"):
            # prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
            with self.assertRaises(ValueError):
                SamplingParams(prompt_logprobs=3)

            # logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
            with self.assertRaises(ValueError):
                SamplingParams(logprobs=-1)

    def test_logprobs_with_other_parameters(self):
        """Test logprobs validation with other sampling parameters"""
        # Full parameter set shared by the "1" (accepted) and "0" (rejected) cases.
        all_params = dict(
            logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
        )

        # With temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._build_and_verify(logprobs=5, temperature=0.8)  # Should pass

        # With top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._build_and_verify(logprobs=5, top_p=0.9)  # Should pass

        # With all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
        with self._v1_env("1"):
            self._build_and_verify(**all_params)  # Should pass

        # prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with self._v1_env("0"):
            self._assert_rejected(**all_params)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()