Files
FastDeploy/tests/ce/server/test_compare_top_logprobs.py
YUNSHEN XIE 3a6058e445 Add stable ci (#3460)
* add stable ci

* fix

* update

* fix

* rename tests dir;fix stable ci bug

* add timeout limit

* update
2025-08-20 08:57:17 +08:00

118 lines
3.7 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from core import TEMPLATE, URL, build_request_payload, send_request
def get_response(data):
    """
    Send one request to the service under test and return its decoded JSON body.

    The payload is produced by ``build_request_payload(TEMPLATE, data)`` and
    delivered to ``URL`` through ``send_request``.

    Args:
        data (dict): Request fields handed to ``build_request_payload``
            together with ``TEMPLATE``.

    Returns:
        The result of calling ``.json()`` on the object returned by
        ``send_request`` (a dict for the responses consumed in this file).
    """
    return send_request(URL, build_request_payload(TEMPLATE, data)).json()
def assert_top_logprobs_prefix_match(small_top, large_top, token_index):
    """
    Assert that ``small_top`` is a prefix of ``large_top``.

    Entries are compared pairwise, in order, on their ``token``, ``logprob``
    and ``bytes`` fields.

    Args:
        small_top (list): Top-logprob entries (dicts) from the smaller request.
        large_top (list): Top-logprob entries (dicts) from the larger request.
        token_index (int): Zero-based index of the generated token being
            checked; reported one-based in failure messages.

    Raises:
        AssertionError: If ``small_top`` has more entries than ``large_top``
            or any compared field differs.
    """
    # zip() below stops at the shorter sequence, so without this guard a
    # large_top that is missing entries would pass without being noticed.
    assert len(small_top) <= len(large_top), (
        "Top logprobs length mismatch at token {}: {} entries cannot be a prefix of {} entries".format(
            token_index + 1, len(small_top), len(large_top)
        )
    )
    for pos, (small_entry, large_entry) in enumerate(zip(small_top, large_top), start=1):
        for field in ("token", "logprob", "bytes"):
            small_val = small_entry[field]
            large_val = large_entry[field]
            assert small_val == large_val, "{} mismatch at token {} pos {}: {} != {}".format(
                field.capitalize(), token_index + 1, pos, repr(small_val), repr(large_val)
            )
def compare_top_logprobs(base_data, top_logprobs_values=None):
    """
    Request the same completion with several ``top_logprobs`` settings and
    check that the smaller results are prefixes of the larger ones.

    One request is sent per value. The responses are then compared pairwise
    in ascending order of ``top_logprobs`` (each value against the next larger
    one), token by token, over the tokens both responses contain.

    Args:
        base_data (dict): Base request fields; it is copied, not modified.
        top_logprobs_values (list): ``top_logprobs`` values to request and
            compare. Defaults to ``[5, 10]`` when omitted or ``None``.

    Raises:
        AssertionError: If any mismatch is found between the top logprobs of
            two compared responses.
    """
    # None sentinel instead of a mutable list literal as the default argument.
    if top_logprobs_values is None:
        top_logprobs_values = [5, 10]

    responses = {}
    for val in top_logprobs_values:
        data = base_data.copy()
        data.update(
            {
                "top_logprobs": val,
                "logprobs": True,
                "stream": False,
                # NOTE(review): temperature/top_p of 0 presumably make decoding
                # deterministic so the responses are comparable - confirm.
                "temperature": 0,
                "top_p": 0,
                "max_tokens": 10,
            }
        )
        responses[val] = get_response(data)

    # Assertion for prefix consistency: compare every adjacent pair of values
    # in ascending order so that no requested value is left unchecked.
    ordered_values = sorted(responses)
    for small, large in zip(ordered_values, ordered_values[1:]):
        small_contents = responses[small]["choices"][0]["logprobs"]["content"]
        large_contents = responses[large]["choices"][0]["logprobs"]["content"]
        # zip() limits the comparison to the tokens present in both responses.
        for i, (small_item, large_item) in enumerate(zip(small_contents, large_contents)):
            assert_top_logprobs_prefix_match(small_item["top_logprobs"], large_item["top_logprobs"], i)
def test_compare_top_logprobs():
    """
    Run compare_top_logprobs on a fixed chat request with top_logprobs
    values 5 and 10.

    Raises:
        AssertionError: If compare_top_logprobs finds a mismatch between the
            top logprobs of the two responses.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "牛顿的三大运动定律是什么?"},
    ]
    request = {"model": "default", "messages": messages}
    compare_top_logprobs(request, top_logprobs_values=[5, 10])
if __name__ == "__main__":
    # Script entry point: run the top-logprobs comparison test directly.
    # An AssertionError propagates if the compared responses disagree.
    test_compare_top_logprobs()