[format] Valid para format error info (#4035)

* feat(log):add_request_and_response_log

* Align error messages with the OpenAI API
Author: xiaolei373
Date: 2025-09-12 19:05:17 +08:00
Committed by: GitHub
Parent: 88ea565aba
Commit: 9ac539471d
11 changed files with 435 additions and 90 deletions

View File

@@ -31,7 +31,12 @@ from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
from fastdeploy.utils import (
EngineError,
ParameterError,
StatefulSemaphore,
api_server_logger,
)
class EngineClient:
@@ -218,42 +223,21 @@ class EngineClient:
def valid_parameters(self, data):
"""
Validate stream options
Validation logic for the hyperparameters top_p, seed, frequency_penalty,
temperature, and presence_penalty has moved up front into
ChatCompletionRequest/CompletionRequest.
"""
if data.get("n") is not None:
if data["n"] != 1:
raise ValueError("n only support 1.")
raise ParameterError("n", "n only support 1.")
if data.get("max_tokens") is not None:
if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")
if data.get("reasoning_max_tokens") is not None:
if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
if data.get("top_p") is not None:
if data["top_p"] > 1 or data["top_p"] < 0:
raise ValueError("top_p value can only be defined [0, 1].")
if data.get("frequency_penalty") is not None:
if not -2.0 <= data["frequency_penalty"] <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2]")
if data.get("temperature") is not None:
if data["temperature"] < 0:
raise ValueError("temperature must be non-negative")
if data.get("presence_penalty") is not None:
if not -2.0 <= data["presence_penalty"] <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2]")
if data.get("seed") is not None:
if not 0 <= data["seed"] <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580]")
if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.")
raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be between max_tokens and 1")
# logprobs
logprobs = data.get("logprobs")
@@ -263,35 +247,35 @@ class EngineClient:
if not self.enable_logprob:
err_msg = "Logprobs is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("logprobs", err_msg)
top_logprobs = data.get("top_logprobs")
elif isinstance(logprobs, int):
top_logprobs = logprobs
elif logprobs:
raise ValueError("Invalid type for 'logprobs'")
raise ParameterError("logprobs", "Invalid type for 'logprobs'")
# enable_logprob
if top_logprobs:
if not self.enable_logprob:
err_msg = "Logprobs is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("logprobs", err_msg)
if not isinstance(top_logprobs, int):
err_type = type(top_logprobs).__name__
err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)
if top_logprobs < 0:
err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)
if top_logprobs > 20:
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)
def check_health(self, time_interval_threashold=30):
"""

View File

@@ -26,6 +26,7 @@ from multiprocessing import current_process
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST
@@ -40,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
ControlSchedulerRequest,
ErrorInfo,
ErrorResponse,
ModelList,
)
@@ -56,6 +58,7 @@ from fastdeploy.metrics.metrics import (
)
from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
from fastdeploy.utils import (
ExceptionHandler,
FlexibleArgumentParser,
StatefulSemaphore,
api_server_logger,
@@ -232,6 +235,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
instrument(app)
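
A hedged, standalone sketch of what this wiring buys (DemoRequest and /demo are stand-ins, not FastDeploy routes): validation failures now come back in the OpenAI-style error envelope instead of FastAPI's default {"detail": [...]} body.

# Toy app registering the same two handlers added above.
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient
from pydantic import BaseModel

from fastdeploy.utils import ExceptionHandler

demo_app = FastAPI()
demo_app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
demo_app.add_exception_handler(Exception, ExceptionHandler.handle_exception)

class DemoRequest(BaseModel):  # stand-in for ChatCompletionRequest
    messages: list

@demo_app.post("/demo")
def demo(req: DemoRequest):
    return {"ok": True}

client = TestClient(demo_app, raise_server_exceptions=False)
resp = client.post("/demo", json={})
print(resp.status_code, resp.json())
# 400 {'object': 'error', 'error': {'message': 'Field required',
#      'type': 'invalid_request_error', 'param': 'messages',
#      'code': 'missing_required_parameter'}}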
@@ -336,7 +341,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if isinstance(generator, ErrorResponse):
api_server_logger.debug(f"release: {connection_semaphore.status()}")
connection_semaphore.release()
return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, ChatCompletionResponse):
api_server_logger.debug(f"release: {connection_semaphore.status()}")
connection_semaphore.release()
@@ -365,7 +370,7 @@ async def create_completion(request: CompletionRequest):
generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse):
connection_semaphore.release()
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, CompletionResponse):
connection_semaphore.release()
return JSONResponse(content=generator.model_dump())
@@ -388,7 +393,7 @@ async def list_models() -> Response:
models = await app.state.model_handler.list_models()
if isinstance(models, ErrorResponse):
return JSONResponse(content=models.model_dump(), status_code=models.code)
return JSONResponse(content=models.model_dump())
elif isinstance(models, ModelList):
return JSONResponse(content=models.model_dump())
@@ -502,7 +507,8 @@ def control_scheduler(request: ControlSchedulerRequest):
"""
Control the scheduler behavior with the given parameters.
"""
content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code="0"))
global llm_engine
if llm_engine is None:

View File

@@ -32,9 +32,14 @@ class ErrorResponse(BaseModel):
Error response from OpenAI API.
"""
object: str = "error"
error: ErrorInfo
class ErrorInfo(BaseModel):
message: str
code: int
type: Optional[str] = None
param: Optional[str] = None
code: Optional[str] = None
class PromptTokenUsageInfo(BaseModel):
@@ -403,21 +408,21 @@ class CompletionRequest(BaseModel):
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = None
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: Optional[float] = None
seed: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[dict] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0, le=1)
user: Optional[str] = None
# doc: begin-completion-sampling-params
@@ -537,7 +542,7 @@ class ChatCompletionRequest(BaseModel):
messages: Union[List[Any], List[int]]
tools: Optional[List[ChatCompletionToolsParam]] = None
model: Optional[str] = "default"
frequency_penalty: Optional[float] = None
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
@@ -552,13 +557,13 @@ class ChatCompletionRequest(BaseModel):
)
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = None
seed: Optional[int] = None
presence_penalty: Optional[float] = Field(None, le=2, ge=-2)
seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(None, ge=0)
top_p: Optional[float] = Field(None, le=1, ge=0)
user: Optional[str] = None
metadata: Optional[dict] = None
response_format: Optional[AnyResponseFormat] = None
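
Since the bounds now live on the pydantic fields, out-of-range sampling parameters fail at request parsing, before EngineClient.valid_parameters ever runs. A small illustration (the message content is arbitrary):

# Field(ge=..., le=...) rejects bad values at parse time.
from pydantic import ValidationError
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
except ValidationError as e:
    first = e.errors()[0]
    print(first["loc"], first["msg"])  # ('top_p',) Input should be less than or equal to 1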

View File

@@ -30,6 +30,7 @@ from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorInfo,
ErrorResponse,
LogProbEntry,
LogProbs,
@@ -38,7 +39,7 @@ from fastdeploy.entrypoints.openai.protocol import (
)
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger
from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -86,14 +87,16 @@ class OpenAIServingChat:
f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
if self.models:
is_supported, request.model = self.models.is_supported_model(request.model)
if not is_supported:
err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(
error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
)
try:
if self.max_waiting_time < 0:
@@ -117,11 +120,17 @@ class OpenAIServingChat:
text_after_process = current_req_dict.get("text_after_process")
if isinstance(prompt_token_ids, np.ndarray):
prompt_token_ids = prompt_token_ids.tolist()
except ParameterError as e:
api_server_logger.error(e.message)
self.engine_client.semaphore.release()
return ErrorResponse(
error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
)
except Exception as e:
error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
self.engine_client.semaphore.release()
return ErrorResponse(code=400, message=error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR))
del current_req_dict
if request.stream:
@@ -136,21 +145,20 @@ class OpenAIServingChat:
except Exception as e:
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(code=408, message=error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
except Exception as e:
error_msg = (
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
f"max waiting time: {self.max_waiting_time}"
)
api_server_logger.error(error_msg)
return ErrorResponse(code=408, message=error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
)
def _create_streaming_error_response(self, message: str) -> str:
api_server_logger.error(message)
error_response = ErrorResponse(
code=400,
message=message,
)
error_response = ErrorResponse(error=ErrorInfo(message=message, type=ErrorType.SERVER_ERROR))
return error_response.model_dump_json()
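
For reference, a sketch of what the reworked helper emits; the message string is illustrative:

# Illustrative only: the payload now nests details under "error" rather than
# exposing top-level "message"/"code" fields.
from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from fastdeploy.utils import ErrorType

err = ErrorResponse(error=ErrorInfo(message="engine overloaded", type=ErrorType.SERVER_ERROR))
print(err.model_dump_json())
# {"object":"error","error":{"message":"engine overloaded","type":"server_error","param":null,"code":null}}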
async def chat_completion_stream_generator(

View File

@@ -30,10 +30,11 @@ from fastdeploy.entrypoints.openai.protocol import (
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorInfo,
ErrorResponse,
UsageInfo,
)
from fastdeploy.utils import api_server_logger
from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -63,13 +64,15 @@ class OpenAIServingCompletion:
f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
if self.models:
is_supported, request.model = self.models.is_supported_model(request.model)
if not is_supported:
err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(
error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
)
created_time = int(time.time())
if request.user is not None:
request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
@@ -112,7 +115,7 @@ class OpenAIServingCompletion:
except Exception as e:
error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(message=error_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
if request_prompt_ids is not None:
request_prompts = request_prompt_ids
@@ -132,7 +135,9 @@ class OpenAIServingCompletion:
f"max waiting time: {self.max_waiting_time}"
)
api_server_logger.error(error_msg)
return ErrorResponse(code=408, message=error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, code=ErrorCode.TIMEOUT, type=ErrorType.TIMEOUT_ERROR)
)
try:
try:
@@ -146,11 +151,17 @@ class OpenAIServingCompletion:
text_after_process_list.append(current_req_dict.get("text_after_process"))
prompt_batched_token_ids.append(prompt_token_ids)
del current_req_dict
except ParameterError as e:
api_server_logger.error(e.message)
self.engine_client.semaphore.release()
return ErrorResponse(
error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
)
except Exception as e:
error_msg = f"OpenAIServingCompletion format error: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
self.engine_client.semaphore.release()
return ErrorResponse(message=str(e), code=400)
return ErrorResponse(
error=ErrorInfo(message=str(e), code=ErrorCode.INVALID_VALUE, type=ErrorType.INVALID_REQUEST_ERROR)
)
if request.stream:
return self.completion_stream_generator(
@@ -178,12 +189,12 @@ class OpenAIServingCompletion:
f"OpenAIServingCompletion completion_full_generator error: {e}, {str(traceback.format_exc())}"
)
api_server_logger.error(error_msg)
return ErrorResponse(code=400, message=error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
except Exception as e:
error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(message=error_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
async def completion_full_generator(
self,

View File

@@ -18,12 +18,13 @@ from dataclasses import dataclass
from typing import List, Union
from fastdeploy.entrypoints.openai.protocol import (
ErrorInfo,
ErrorResponse,
ModelInfo,
ModelList,
ModelPermission,
)
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.utils import ErrorType, api_server_logger, get_host_ip
@dataclass
@@ -86,7 +87,7 @@ class OpenAIServingModels:
f"Only master node can accept models request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
model_infos = [
ModelInfo(
id=model.name, max_model_len=self.max_model_len, root=model.model_path, permission=[ModelPermission()]

View File

@@ -28,6 +28,8 @@ import sys
import tarfile
import time
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from importlib.metadata import PackageNotFoundError, distribution
from logging.handlers import BaseRotatingHandler
from pathlib import Path
@@ -38,10 +40,14 @@ import paddle
import requests
import yaml
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from tqdm import tqdm
from typing_extensions import TypeIs, assert_never
from fastdeploy import envs
from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from fastdeploy.logger.logger import FastDeployLogger
T = TypeVar("T")
@@ -59,6 +65,61 @@ class EngineError(Exception):
self.error_code = error_code
class ParameterError(Exception):
def __init__(self, param: str, message: str):
self.param = param
self.message = message
super().__init__(message)
class ExceptionHandler:
# Global fallback handler for uncaught exceptions
@staticmethod
async def handle_exception(request: Request, exc: Exception) -> JSONResponse:
error = ErrorResponse(error=ErrorInfo(message=str(exc), type=ErrorType.INTERNAL_ERROR))
return JSONResponse(content=error.model_dump(), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
# Handle request-parameter validation errors
@staticmethod
async def handle_request_validation_exception(_: Request, exc: RequestValidationError) -> JSONResponse:
errors = exc.errors()
if not errors:
message = str(exc)
param = None
else:
first_error = errors[0]
loc = first_error.get("loc", [])
param = loc[-1] if loc else None
message = first_error.get("msg", str(exc))
err = ErrorResponse(
error=ErrorInfo(
message=message,
type=ErrorType.INVALID_REQUEST_ERROR,
code=ErrorCode.MISSING_REQUIRED_PARAMETER if param == "messages" else ErrorCode.INVALID_VALUE,
param=param,
)
)
return JSONResponse(content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
class ErrorType(str, Enum):
INVALID_REQUEST_ERROR = "invalid_request_error"
TIMEOUT_ERROR = "timeout_error"
SERVER_ERROR = "server_error"
INTERNAL_ERROR = "internal_error"
API_CONNECTION_ERROR = "api_connection_error"
class ErrorCode(str, Enum):
INVALID_VALUE = "invalid_value"
CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
MODEL_NOT_SUPPORT = "model_not_support"
TIMEOUT = "timeout"
CONNECTION_ERROR = "connection_error"
MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
class ColoredFormatter(logging.Formatter):
"""自定义日志格式器,用于控制台输出带颜色"""

View File

@@ -20,9 +20,14 @@ def test_missing_messages_field():
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert "detail" in resp, "返回中未包含 detail 错误信息字段"
assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段缺失的报错"
assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "未检测到 'Field required' 错误提示"
assert "error" in resp, "返回中未包含 error 错误信息字段"
error = resp["error"]
assert "Field required" in error.get("message", ""), "未检测到 messages 字段缺失的报错"
assert error.get("code") == "missing_required_parameter", "code 字段不正确"
# assert "detail" in resp, "返回中未包含 detail 错误信息字段"
# assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段缺失的报错"
# assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "未检测到 'Field required' 错误提示"
def test_malformed_messages_format():
@@ -34,11 +39,15 @@ def test_malformed_messages_format():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert "detail" in resp, "非法结构未被识别"
assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段结构错误"
assert any(
"Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
), "未检测到 'Input should be a valid list' 错误提示"
assert "error" in resp, "非法结构未被识别"
err = resp["error"]
assert err.get("param") == "list[any]", f"param 字段错误: {err.get('param')}"
assert err.get("message") == "Input should be a valid list", "错误提示不符合预期"
# assert "detail" in resp, "非法结构未被识别"
# assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段结构错误"
# assert any(
# "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
# ), "未检测到 'Input should be a valid list' 错误提示"
def test_extremely_large_max_tokens():
@@ -79,8 +88,13 @@ def test_top_p_exceed_1():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("detail").get("object") == "error", "top_p > 1 应触发校验异常"
assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "未返回预期的 top_p 错误信息"
assert "error" in resp, "未返回 error 字段"
err = resp["error"]
assert err.get("param") == "top_p", f"param 字段错误: {err.get('param')}"
assert err.get("message") == "Input should be less than or equal to 1", "错误提示不符合预期"
# assert resp.get("detail").get("object") == "error", "top_p > 1 应触发校验异常"
# assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "未返回预期的 top_p 错误信息"
def test_mixed_valid_invalid_fields():
@@ -106,8 +120,8 @@ def test_stop_seq_exceed_num():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("detail").get("object") == "error", "stop 超出个数应触发异常"
assert "exceeds the limit max_stop_seqs_num" in resp.get("detail").get("message", ""), "未返回预期的报错信息"
assert resp.get("error").get("type") == "invalid_request_error", "stop 超出个数应触发异常"
assert "exceeds the limit max_stop_seqs_num" in resp.get("error").get("message", ""), "未返回预期的报错信息"
def test_stop_seq_exceed_length():
@@ -120,8 +134,8 @@ def test_stop_seq_exceed_length():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("detail").get("object") == "error", "stop 超出长度应触发异常"
assert "exceeds the limit stop_seqs_max_len" in resp.get("detail").get("message", ""), "未返回预期的报错信息"
assert resp.get("error").get("type") == "invalid_request_error", "stop 超出长度应触发异常"
assert "exceeds the limit stop_seqs_max_len" in resp.get("error").get("message", ""), "未返回预期的报错信息"
def test_multilingual_input():
@@ -154,8 +168,8 @@ def test_too_long_input():
data = {"messages": [{"role": "user", "content": "a" * 200000}], "stream": False} # 超过最大输入长度
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp["detail"].get("object") == "error", "超长输入未被识别为错误"
assert "Input text is too long" in resp["detail"].get("message", ""), "未检测到最大长度限制错误"
# assert resp["detail"].get("object") == "error", "超长输入未被识别为错误"
assert "Input text is too long" in resp["error"].get("message", ""), "未检测到最大长度限制错误"
def test_empty_input():
@@ -361,8 +375,8 @@ def test_max_tokens_negative():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("detail").get("object") == "error", "max_tokens < 0 未触发校验异常"
assert "max_tokens can be defined [1," in resp.get("detail").get("message"), "未返回预期的 max_tokens 错误信息"
# assert resp.get("detail").get("object") == "error", "max_tokens < 0 未触发校验异常"
assert "max_tokens can be defined [1," in resp.get("error").get("message"), "未返回预期的 max_tokens 错误信息"
def test_max_tokens_min():
@@ -379,7 +393,10 @@ def test_max_tokens_min():
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert resp.get("detail").get("object") == "error", "max_tokens未0时API未拦截住"
# assert resp.get("detail").get("object") == "error", "max_tokens未0时API未拦截住"
assert "max_tokens can be defined [1," in resp.get("error").get(
"message"
), "max_tokens未0时API未拦截住,未返回预期的 max_tokens 错误信息"
def test_max_tokens_non_integer():
@@ -397,5 +414,5 @@ def test_max_tokens_non_integer():
payload = build_request_payload(TEMPLATE, data)
resp = send_request(URL, payload).json()
assert (
resp.get("detail")[0].get("msg") == "Input should be a valid integer, got a number with a fractional part"
resp.get("error").get("message") == "Input should be a valid integer, got a number with a fractional part"
), "未返回预期的 max_tokens 为非整数的错误信息"

View File

@@ -0,0 +1,168 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from pydantic import ValidationError
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
)
class TestChatCompletionRequest(unittest.TestCase):
def test_required_messages(self):
with self.assertRaises(ValidationError):
ChatCompletionRequest()
def test_messages_accepts_list_of_any_and_int(self):
req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
self.assertEqual(req.messages[0]["role"], "user")
req = ChatCompletionRequest(messages=[1, 2, 3])
self.assertEqual(req.messages, [1, 2, 3])
def test_default_values(self):
req = ChatCompletionRequest(messages=[1])
self.assertEqual(req.model, "default")
self.assertFalse(req.logprobs)
self.assertEqual(req.top_logprobs, 0)
self.assertEqual(req.n, 1)
self.assertEqual(req.stop, [])
def test_boundary_values(self):
valid_cases = [
("frequency_penalty", -2),
("frequency_penalty", 2),
("presence_penalty", -2),
("presence_penalty", 2),
("temperature", 0),
("top_p", 1),
("seed", 0),
("seed", 922337203685477580),
]
for field, value in valid_cases:
with self.subTest(field=field, value=value):
req = ChatCompletionRequest(messages=[1], **{field: value})
self.assertEqual(getattr(req, field), value)
def test_invalid_boundary_values(self):
invalid_cases = [
("frequency_penalty", -3),
("frequency_penalty", 3),
("presence_penalty", -3),
("presence_penalty", 3),
("temperature", -1),
("top_p", 1.1),
("seed", -1),
("seed", 922337203685477581),
]
for field, value in invalid_cases:
with self.subTest(field=field, value=value):
with self.assertRaises(ValidationError):
ChatCompletionRequest(messages=[1], **{field: value})
def test_stop_field_accepts_str_or_list(self):
req1 = ChatCompletionRequest(messages=[1], stop="end")
self.assertEqual(req1.stop, "end")
req2 = ChatCompletionRequest(messages=[1], stop=["a", "b"])
self.assertEqual(req2.stop, ["a", "b"])
with self.assertRaises(ValidationError):
ChatCompletionRequest(messages=[1], stop=123)
def test_deprecated_max_tokens_field(self):
req = ChatCompletionRequest(messages=[1], max_tokens=10)
self.assertEqual(req.max_tokens, 10)
def test_field_names_snapshot(self):
expected_fields = set(ChatCompletionRequest.__fields__.keys())
self.assertEqual(set(ChatCompletionRequest.__fields__.keys()), expected_fields)
class TestCompletionRequest(unittest.TestCase):
def test_required_prompt(self):
with self.assertRaises(ValidationError):
CompletionRequest()
def test_prompt_accepts_various_types(self):
# str
req = CompletionRequest(prompt="hello")
self.assertEqual(req.prompt, "hello")
# list of str
req = CompletionRequest(prompt=["hello", "world"])
self.assertEqual(req.prompt, ["hello", "world"])
# list of int
req = CompletionRequest(prompt=[1, 2, 3])
self.assertEqual(req.prompt, [1, 2, 3])
# list of list of int
req = CompletionRequest(prompt=[[1, 2], [3, 4]])
self.assertEqual(req.prompt, [[1, 2], [3, 4]])
def test_default_values(self):
req = CompletionRequest(prompt="test")
self.assertEqual(req.model, "default")
self.assertEqual(req.echo, False)
self.assertEqual(req.temp_scaled_logprobs, False)
self.assertEqual(req.top_p_normalized_logprobs, False)
self.assertEqual(req.n, 1)
self.assertEqual(req.stop, [])
self.assertEqual(req.stream, False)
def test_boundary_values(self):
valid_cases = [
("frequency_penalty", -2),
("frequency_penalty", 2),
("presence_penalty", -2),
("presence_penalty", 2),
("temperature", 0),
("top_p", 0),
("top_p", 1),
("seed", 0),
("seed", 922337203685477580),
]
for field, value in valid_cases:
with self.subTest(field=field, value=value):
req = CompletionRequest(prompt="hi", **{field: value})
self.assertEqual(getattr(req, field), value)
def test_invalid_boundary_values(self):
invalid_cases = [
("frequency_penalty", -3),
("frequency_penalty", 3),
("presence_penalty", -3),
("presence_penalty", 3),
("temperature", -0.1),
("top_p", -0.1),
("top_p", 1.1),
("seed", -1),
("seed", 922337203685477581),
]
for field, value in invalid_cases:
with self.subTest(field=field, value=value):
with self.assertRaises(ValidationError):
CompletionRequest(prompt="hi", **{field: value})
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,30 @@
import unittest
from pydantic import ValidationError
from fastdeploy.entrypoints.openai.protocol import ErrorResponse
class TestErrorResponse(unittest.TestCase):
def test_valid_error_response(self):
data = {
"error": {
"message": "Invalid top_p value",
"type": "invalid_request_error",
"param": "top_p",
"code": "invalid_value",
}
}
err_resp = ErrorResponse(**data)
self.assertEqual(err_resp.error.message, "Invalid top_p value")
self.assertEqual(err_resp.error.param, "top_p")
self.assertEqual(err_resp.error.code, "invalid_value")
def test_missing_message_field(self):
data = {"error": {"type": "invalid_request_error", "param": "messages", "code": "missing_required_parameter"}}
with self.assertRaises(ValidationError):
ErrorResponse(**data)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,54 @@
import json
import unittest
from http import HTTPStatus
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastdeploy.utils import ErrorCode, ExceptionHandler, ParameterError
class TestParameterError(unittest.TestCase):
def test_parameter_error_init(self):
exc = ParameterError("param1", "error message")
self.assertEqual(exc.param, "param1")
self.assertEqual(exc.message, "error message")
self.assertEqual(str(exc), "error message")
class TestExceptionHandler(unittest.IsolatedAsyncioTestCase):
async def test_handle_exception(self):
"""普通异常应返回 500 + internal_error"""
exc = RuntimeError("Something went wrong")
resp: JSONResponse = await ExceptionHandler.handle_exception(None, exc)
body = json.loads(resp.body.decode())
self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
self.assertEqual(body["error"]["type"], "internal_error")
self.assertIn("Something went wrong", body["error"]["message"])
async def test_handle_request_validation_missing_messages(self):
"""缺少 messages 参数时,应返回 missing_required_parameter"""
exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}])
resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
data = json.loads(resp.body.decode())
self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
self.assertEqual(data["error"]["param"], "messages")
self.assertEqual(data["error"]["code"], ErrorCode.MISSING_REQUIRED_PARAMETER)
self.assertIn("Field required", data["error"]["message"])
async def test_handle_request_validation_invalid_value(self):
"""参数非法时,应返回 invalid_value"""
exc = RequestValidationError(
[{"loc": ("body", "top_p"), "msg": "Input should be less than or equal to 1", "type": "value_error"}]
)
resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
data = json.loads(resp.body.decode())
self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
self.assertEqual(data["error"]["param"], "top_p")
self.assertEqual(data["error"]["code"], ErrorCode.INVALID_VALUE)
self.assertIn("less than or equal to 1", data["error"]["message"])
if __name__ == "__main__":
unittest.main()