[LogProbs]Enable prompt logprobs output and modify data transmission method for the online interface. (#5089)

* add prompt logprobs

* Merge prompt_logprobs_tensors and prompt_logprobs

* fix param check

* trigger ci

* fix unit tests

* fix logprobs bug
This commit is contained in:
qwes5s5
2025-12-02 13:49:51 +08:00
committed by GitHub
parent af39819fcd
commit 117980dd4e
27 changed files with 4947 additions and 233 deletions

View File

@@ -26,17 +26,28 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_logprobs_valid_values(self):
"""Test valid logprobs values"""
# Test None value (should pass)
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test None value (should pass in both modes)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test -1 value (should pass)
params = SamplingParams(logprobs=-1)
params._verify_args() # Should not raise
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test 0 value (should pass)
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
# Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=-1)
params._verify_args() # Should not raise
# Test 0 value (should pass in both modes based on actual behavior)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
# Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
@@ -44,13 +55,23 @@ class TestSamplingParamsVerification(unittest.TestCase):
params._verify_args() # Should not raise
def test_logprobs_invalid_less_than_minus_one(self):
    """Test logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
    # V1 mode treats -1 as a valid sentinel, so only values below -1 are rejected.
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-2)
            params._verify_args()
        self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
        self.assertIn("got -2", str(cm.exception))
def test_logprobs_invalid_less_than_zero(self):
    """Test logprobs less than 0 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0"."""
    # Legacy mode: -1 is NOT a valid sentinel; logprobs must be in [0, 20].
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-1)
            params._verify_args()
        self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_disabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
@@ -59,7 +80,7 @@ class TestSamplingParamsVerification(unittest.TestCase):
params = SamplingParams(logprobs=21)
params._verify_args()
self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))
self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_enabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
@@ -74,46 +95,67 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_prompt_logprobs_valid_values(self):
    """Test valid prompt_logprobs values."""
    # Test None value (should pass in both modes based on actual behavior)
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        params = SamplingParams(prompt_logprobs=None)
        params._verify_args()  # Should not raise
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(prompt_logprobs=None)
        params._verify_args()  # Should not raise
    # Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(prompt_logprobs=-1)
        params._verify_args()  # Should not raise
    # Test 0 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(prompt_logprobs=0)
        params._verify_args()  # Should not raise
    # Test positive values (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(prompt_logprobs=10)
        params._verify_args()  # Should not raise
def test_prompt_logprobs_invalid_less_than_minus_one(self):
    """Test prompt_logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(prompt_logprobs=-2)
            params._verify_args()
        # NOTE(review): the asserted text "prompt_logprobs a must be non-negative value or -1"
        # is grammatically odd but must match the implementation's message byte-for-byte.
        self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
        self.assertIn("got -2", str(cm.exception))
def test_combined_logprobs_and_prompt_logprobs(self):
    """Test both logprobs and prompt_logprobs together."""
    # Test valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(logprobs=5, prompt_logprobs=3)
        params._verify_args()  # Should not raise
    # Test invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError):
            params = SamplingParams(logprobs=-2, prompt_logprobs=5)
            params._verify_args()
    # Test valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError):
            params = SamplingParams(logprobs=5, prompt_logprobs=-2)
            params._verify_args()
    # Test prompt_logprobs not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=5, prompt_logprobs=3)
            params._verify_args()
        # NOTE(review): "is not support" matches the implementation's message verbatim.
        self.assertIn(
            "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
        )
def test_logprobs_boundary_values(self):
"""Test boundary values for logprobs"""
@@ -130,14 +172,16 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_prompt_logprobs_boundary_values(self):
    """Test boundary values for prompt_logprobs."""
    # Test boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(prompt_logprobs=-1)
        params._verify_args()  # Should pass
    # Test boundary value just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError):
            params = SamplingParams(prompt_logprobs=-2)
            params._verify_args()
def test_environment_variable_handling(self):
"""Test different environment variable values"""
@@ -167,55 +211,111 @@ class TestSamplingParamsVerification(unittest.TestCase):
if original_value is not None:
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value
# Test prompt_logprobs behavior with different environment variables
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=5)
params._verify_args()
self.assertIn(
"prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=5)
params._verify_args() # Should pass
def test_error_message_formatting(self):
    """Test that error messages are properly formatted."""
    # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-5)
            params._verify_args()
        error_msg = str(cm.exception)
        self.assertIn("logprobs must be a non-negative value or -1", error_msg)
        self.assertIn("got -5", error_msg)
    # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-1)
            params._verify_args()
        error_msg = str(cm.exception)
        self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)
    # Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(prompt_logprobs=-10)
            params._verify_args()
        error_msg = str(cm.exception)
        # NOTE(review): asserted text matches the implementation's (ungrammatical) message.
        self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
        self.assertIn("got -10", error_msg)
    # Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(prompt_logprobs=5)
            params._verify_args()
        error_msg = str(cm.exception)
        self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)
def test_post_init_calls_verify_args(self):
    """Test that __post_init__ calls _verify_args."""
    # This should call _verify_args internally when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(logprobs=5, prompt_logprobs=3)
        # The params should be successfully created without errors
        self.assertEqual(params.logprobs, 5)
        self.assertEqual(params.prompt_logprobs, 3)
        # Test that invalid values are caught during initialization
        with self.assertRaises(ValueError):
            SamplingParams(logprobs=-2)
        with self.assertRaises(ValueError):
            SamplingParams(prompt_logprobs=-2)
    # Test that prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError):
            SamplingParams(prompt_logprobs=3)
        # Test that logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
        with self.assertRaises(ValueError):
            SamplingParams(logprobs=-1)
def test_logprobs_with_other_parameters(self):
    """Test logprobs validation with other sampling parameters."""
    # Test with temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(logprobs=5, temperature=0.8)
        params._verify_args()  # Should pass
    # Test with top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(logprobs=5, top_p=0.9)
        params._verify_args()  # Should pass
    # Test with all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        params = SamplingParams(
            logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
        )
        params._verify_args()  # Should pass
    # Test that prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        with self.assertRaises(ValueError):
            params = SamplingParams(
                logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
            )
            params._verify_args()
if __name__ == "__main__":