[stop_seq] fix out-bound value for stop sequence (#3216)

* fix out-bound value for stop sequence

* catch error if there are out-of-bounds value

* check in offline mode

* add ut tests
This commit is contained in:
JYChen
2025-08-07 15:40:21 +08:00
committed by GitHub
parent 5885285e57
commit 9423c577fe
3 changed files with 68 additions and 0 deletions

View File

@@ -530,6 +530,26 @@ class LLMEngine:
llm_logger.error(error_msg) llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400) raise EngineError(error_msg, error_code=400)
if request.get("stop_seqs_len") is not None:
stop_seqs_len = request.get("stop_seqs_len")
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
"Please reduce the number of stop or set a lager max_stop_seqs_num by `FD_MAX_STOP_SEQS_NUM`"
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (
f"Length of stop_seqs({single_stop_seq_len}) exceeds the limit stop_seqs_max_len({stop_seqs_max_len})."
"Please reduce the length of stop sequences or set a larger stop_seqs_max_len by `FD_STOP_SEQS_MAX_LEN`"
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
if self.guided_decoding_checker is not None: if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(request) request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None: if err_msg is not None:

View File

@@ -19,6 +19,7 @@ import uuid
import numpy as np import numpy as np
from fastdeploy import envs
from fastdeploy.engine.config import ModelConfig from fastdeploy.engine.config import ModelConfig
from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -154,6 +155,26 @@ class EngineClient:
api_server_logger.error(error_msg) api_server_logger.error(error_msg)
raise EngineError(error_msg, error_code=400) raise EngineError(error_msg, error_code=400)
if "stop_seqs_len" in task:
stop_seqs_len = task["stop_seqs_len"]
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
"Please reduce the number of stop or set a lager max_stop_seqs_num by `FD_MAX_STOP_SEQS_NUM`"
)
api_server_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (
f"Length of stop_seqs({single_stop_seq_len}) exceeds the limit stop_seqs_max_len({stop_seqs_max_len})."
"Please reduce the length of stop sequences or set a larger stop_seqs_max_len by `FD_STOP_SEQS_MAX_LEN`"
)
api_server_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
task["preprocess_end_time"] = time.time() task["preprocess_end_time"] = time.time()
preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"] preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
api_server_logger.info( api_server_logger.info(

View File

@@ -95,3 +95,30 @@ def test_mixed_valid_invalid_fields():
resp = send_request(URL, payload).json() resp = send_request(URL, payload).json()
assert "error" not in resp, "非法字段不应导致请求失败" assert "error" not in resp, "非法字段不应导致请求失败"
def test_stop_seq_exceed_num():
"""stop 字段包含超过 FD_MAX_STOP_SEQS_NUM 个元素,服务应报错"""
data = {
"stream": False,
"messages": [{"role": "user", "content": "非洲的首都是?"}],
"top_p": 0,
"stop": ["11", "22", "33", "44", "55", "66", "77"],
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("object") == "error", "stop 超出个数应触发异常"
assert "exceeds the limit max_stop_seqs_num" in resp.get("message", ""), "未返回预期的报错信息"
def test_stop_seq_exceed_length():
"""stop 中包含长度超过 FD_STOP_SEQS_MAX_LEN 的元素,服务应报错"""
data = {
"stream": False,
"messages": [{"role": "user", "content": "非洲的首都是?"}],
"top_p": 0,
"stop": ["11", "今天天气比明天好多了,请问你会出门还是和我一起玩"],
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("object") == "error", "stop 超出长度应触发异常"
assert "exceeds the limit stop_seqs_max_len" in resp.get("message", ""), "未返回预期的报错信息"