Files
FastDeploy/test/ce/server/test_repetition_early_stop.py
Divano eaae4a580d Split cases (#3297)
* add repitation early stop cases

* add repitation early stop cases

* split repetition_early_stop from the base test
2025-08-11 09:38:35 +08:00

55 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
from core import TEMPLATE, URL, build_request_payload, get_probs_list, send_request
def test_repetition_early_stop():
    """Verify that repetition early stop truncates a repetitive generation.

    Sends a non-streaming chat request whose prompt is designed to make the
    model repeat itself, then inspects the values returned by
    ``get_probs_list`` (presumably one probability per generated token --
    confirm against ``core``).

    The checks use ``window_size = 6`` and ``threshold = 0.93``. These must
    match the values the model server was launched with; the feature cannot
    be enabled from the request, so the test is meaningless otherwise.

    The test passes when:

    1. the last ``window_size`` values are all above ``threshold``, and
    2. no run of ``window_size`` consecutive above-threshold values is
       completed before the final value, i.e. generation stopped at the
       first opportunity.
    """
    threshold = 0.93
    window_size = 6

    data = {
        "stream": False,
        "messages": [
            {"role": "user", "content": "输出'我爱吃果冻' 10次"},
        ],
        "max_tokens": 10000,
        "temperature": 0.8,
        "top_p": 0,
    }
    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]
    print("🧪 repetition early stop 输出内容:\n", content)

    probs_list = get_probs_list(response)
    assert len(probs_list) >= window_size, "列表长度不足 window_size"

    # Condition 1: the final window must consist solely of values above the
    # threshold.
    tail = probs_list[-window_size:]
    assert all(v > threshold for v in tail), "末尾 window_size 个数不全大于阈值"

    # Condition 2: no full window of above-threshold values may be completed
    # before the final value. Scanning everything except the last element
    # (rather than only the part preceding the tail) also catches a run that
    # starts before the tail and extends into it, which would mean the stop
    # should have fired at least one token earlier.
    count = 0
    for v in probs_list[:-1]:
        if v > threshold:
            count += 1
            assert count < window_size, f"在末尾之前出现了连续 {count} 个大于阈值的数"
        else:
            count = 0
    print("repetition early stop 功能验证通过")