mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00

* add stable ci * fix * update * fix * rename tests dir;fix stable ci bug * add timeout limit * update
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
#!/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
# @author DDDivano
|
||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||
|
||
|
||
from core import TEMPLATE, URL, build_request_payload, get_probs_list, send_request
|
||
|
||
|
||
def test_repetition_early_stop():
    """
    Verify that repetition early stop takes effect.

    The prompt asks the model to repeat a phrase, which is designed to yield
    repeated, high-probability tokens. With threshold=0.93 and window_size=6,
    the test checks that generation was cut off exactly when the first run of
    window_size consecutive token probabilities above the threshold completed:
    the last window_size probabilities must all exceed the threshold, and no
    such window may have completed any earlier.

    NOTE: the server under test must be launched with the matching repetition
    early-stop settings (in particular window_size=6); otherwise this test is
    not meaningful.
    """
    data = {
        "stream": False,
        "messages": [
            {"role": "user", "content": "输出'我爱吃果冻' 10次"},
        ],
        # Generous budget so that, presumably, the output is truncated by early
        # stop rather than by max_tokens.
        "max_tokens": 10000,
        "temperature": 0.8,
        "top_p": 0,
    }

    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]

    print("🧪 repetition early stop 输出内容:\n", content)
    probs_list = get_probs_list(response)

    # These must mirror the early-stop settings the server was started with.
    threshold = 0.93
    window_size = 6

    assert len(probs_list) >= window_size, "列表长度不足 window_size"

    # Condition 1: the last window_size probabilities all exceed the threshold,
    # i.e. the final window is the one that triggered the stop.
    tail = probs_list[-window_size:]
    assert all(v > threshold for v in tail), "末尾 window_size 个数不全大于阈值"

    # Condition 2: the prefix before the tail must not contain a run of
    # >= window_size consecutive above-threshold values (that window would
    # have triggered the stop earlier).
    head = probs_list[:-window_size]
    longest = _longest_run_above(head, threshold)
    assert longest < window_size, f"在末尾之前出现了连续 {longest} 个大于阈值的数"

    # Condition 3: the value immediately before the tail must not exceed the
    # threshold. Otherwise the above-threshold run straddles the head/tail
    # boundary, so a full window was already complete one or more tokens
    # earlier -- a late stop that condition 2 (which scans only the head)
    # cannot detect. len() is used instead of truthiness in case
    # get_probs_list returns an array-like rather than a list.
    if len(head) > 0:
        assert head[-1] <= threshold, "末尾窗口之前紧邻的值也大于阈值, early stop 触发过晚"

    print("repetition early stop 功能验证通过")


def _longest_run_above(values, threshold):
    """Return the length of the longest run of consecutive items in *values*
    strictly greater than *threshold* (0 when there is none)."""
    longest = current = 0
    for v in values:
        # Extend the current run on an above-threshold value, reset otherwise.
        current = current + 1 if v > threshold else 0
        longest = max(longest, current)
    return longest
|