diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 777689c73..b6c0008c3 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -31,7 +31,12 @@ from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
+from fastdeploy.utils import (
+    EngineError,
+    ParameterError,
+    StatefulSemaphore,
+    api_server_logger,
+)
 
 
 class EngineClient:
@@ -218,42 +223,21 @@ class EngineClient:
     def valid_parameters(self, data):
         """
         Validate stream options
+        Validation of the hyperparameters (top_p, seed, frequency_penalty, temperature, presence_penalty)
+        has been moved up front into ChatCompletionRequest/CompletionRequest.
         """
 
         if data.get("n") is not None:
             if data["n"] != 1:
-                raise ValueError("n only support 1.")
+                raise ParameterError("n", "n only support 1.")
 
         if data.get("max_tokens") is not None:
             if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
-                raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
+                raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")
 
         if data.get("reasoning_max_tokens") is not None:
             if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
-                raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
-
-        if data.get("top_p") is not None:
-            if data["top_p"] > 1 or data["top_p"] < 0:
-                raise ValueError("top_p value can only be defined [0, 1].")
-
-        if data.get("frequency_penalty") is not None:
-            if not -2.0 <= data["frequency_penalty"] <= 2.0:
-                raise ValueError("frequency_penalty must be in [-2, 2]")
-
-        if data.get("temperature") is not None:
-            if data["temperature"] < 0:
-                raise ValueError("temperature must be non-negative")
-
-        if data.get("presence_penalty") is not None:
-            if not -2.0 <= data["presence_penalty"] <= 2.0:
-                raise ValueError("presence_penalty must be in [-2, 2]")
-
-        if data.get("seed") is not None:
-            if not 0 <= data["seed"] <= 922337203685477580:
-                raise ValueError("seed must be in [0, 922337203685477580]")
-
-        if data.get("stream_options") and not data.get("stream"):
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+                raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be between max_tokens and 1")
 
         # logprobs
         logprobs = data.get("logprobs")
@@ -263,35 +247,35 @@ class EngineClient:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)
             top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
-            raise ValueError("Invalid type for 'logprobs'")
+            raise ParameterError("logprobs", "Invalid type for 'logprobs'")
 
         # enable_logprob
         if top_logprobs:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)
 
             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
                 err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
             if top_logprobs < 0:
                 err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
             if top_logprobs > 20:
                 err_msg = "Invalid value for 'top_logprobs': must be <= 20."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
     def check_health(self, time_interval_threashold=30):
         """
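EngineClient.valid_parameters now raises ParameterError instead of a bare ValueError, so callers can tell which field failed without parsing the message. A minimal sketch of that contract (the validate_request helper is hypothetical; only the exception shape matches this diff):

```python
# Hypothetical stand-in for EngineClient.valid_parameters; the exception
# mirrors fastdeploy.utils.ParameterError introduced in this diff.
class ParameterError(Exception):
    def __init__(self, param: str, message: str):
        self.param = param
        self.message = message
        super().__init__(message)


def validate_request(data: dict, max_model_len: int = 8192) -> None:
    if data.get("n") is not None and data["n"] != 1:
        raise ParameterError("n", "n only support 1.")
    if data.get("max_tokens") is not None and not (1 <= data["max_tokens"] < max_model_len):
        raise ParameterError("max_tokens", f"max_tokens can be defined [1, {max_model_len}).")


try:
    validate_request({"n": 2})
except ParameterError as e:
    print(e.param, "->", e.message)  # n -> n only support 1.
```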
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 1c9a34c0b..0417e2cb8 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -26,6 +26,7 @@ from multiprocessing import current_process
 import uvicorn
 import zmq
 from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
 
@@ -40,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
     ControlSchedulerRequest,
+    ErrorInfo,
     ErrorResponse,
     ModelList,
 )
@@ -56,6 +58,7 @@ from fastdeploy.metrics.metrics import (
 )
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
+    ExceptionHandler,
     FlexibleArgumentParser,
     StatefulSemaphore,
     api_server_logger,
@@ -232,6 +235,8 @@ async def lifespan(app: FastAPI):
 
 
 app = FastAPI(lifespan=lifespan)
+app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
+app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
 instrument(app)
 
 
@@ -336,7 +341,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         if isinstance(generator, ErrorResponse):
             api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
-            return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            return JSONResponse(content=generator.model_dump(), status_code=500)
         elif isinstance(generator, ChatCompletionResponse):
             api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
@@ -365,7 +370,7 @@ async def create_completion(request: CompletionRequest):
         generator = await app.state.completion_handler.create_completion(request)
         if isinstance(generator, ErrorResponse):
             connection_semaphore.release()
-            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            return JSONResponse(content=generator.model_dump(), status_code=500)
         elif isinstance(generator, CompletionResponse):
             connection_semaphore.release()
             return JSONResponse(content=generator.model_dump())
@@ -388,7 +393,7 @@ async def list_models() -> Response:
 
     models = await app.state.model_handler.list_models()
     if isinstance(models, ErrorResponse):
-        return JSONResponse(content=models.model_dump(), status_code=models.code)
+        return JSONResponse(content=models.model_dump(), status_code=500)
     elif isinstance(models, ModelList):
         return JSONResponse(content=models.model_dump())
 
@@ -502,7 +507,8 @@ def control_scheduler(request: ControlSchedulerRequest):
     """
     Control the scheduler behavior with the given parameters.
     """
-    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
+
+    content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully"))
 
     global llm_engine
     if llm_engine is None:
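The HTTP surface changes in two ways here: the error body moves from FastAPI's `{"detail": ...}` wrapper to a flat OpenAI-style `error` object, and handler-level failures map to HTTP 500 while request validation is intercepted globally by ExceptionHandler and mapped to HTTP 400. Roughly (values are illustrative):

```python
# Shape of the error body before and after this diff (illustrative values).
old_body = {"detail": {"object": "error", "message": "Unsupported model: [x]", "code": 400}}

new_body = {
    "error": {
        "message": "Unsupported model: [x]",
        "type": "server_error",
        "param": None,
        "code": "model_not_support",
    }
}
```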
""" - content = ErrorResponse(object="", message="Scheduler updated successfully", code=0) + + content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0)) global llm_engine if llm_engine is None: diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb4..5c6f32f71 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -32,9 +32,14 @@ class ErrorResponse(BaseModel): Error response from OpenAI API. """ - object: str = "error" + error: ErrorInfo + + +class ErrorInfo(BaseModel): message: str - code: int + type: Optional[str] = None + param: Optional[str] = None + code: Optional[str] = None class PromptTokenUsageInfo(BaseModel): @@ -403,21 +408,21 @@ class CompletionRequest(BaseModel): prompt: Union[List[int], List[List[int]], str, List[str]] best_of: Optional[int] = None echo: Optional[bool] = False - frequency_penalty: Optional[float] = None + frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2) logprobs: Optional[int] = None # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False max_tokens: Optional[int] = None - n: int = 1 - presence_penalty: Optional[float] = None - seed: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2) + seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[dict] = None - temperature: Optional[float] = None - top_p: Optional[float] = None + temperature: Optional[float] = Field(default=None, ge=0) + top_p: Optional[float] = Field(default=None, ge=0, le=1) user: Optional[str] = None # doc: begin-completion-sampling-params @@ -537,7 +542,7 @@ class ChatCompletionRequest(BaseModel): messages: Union[List[Any], List[int]] tools: Optional[List[ChatCompletionToolsParam]] = None model: Optional[str] = "default" - frequency_penalty: Optional[float] = None + frequency_penalty: Optional[float] = Field(None, le=2, ge=-2) logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 @@ -552,13 +557,13 @@ class ChatCompletionRequest(BaseModel): ) max_completion_tokens: Optional[int] = None n: Optional[int] = 1 - presence_penalty: Optional[float] = None - seed: Optional[int] = None + presence_penalty: Optional[float] = Field(None, le=2, ge=-2) + seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None - temperature: Optional[float] = None - top_p: Optional[float] = None + temperature: Optional[float] = Field(None, ge=0) + top_p: Optional[float] = Field(None, le=1, ge=0) user: Optional[str] = None metadata: Optional[dict] = None response_format: Optional[AnyResponseFormat] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index f770e9d6b..cdd8ff3f4 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -30,6 +30,7 @@ from fastdeploy.entrypoints.openai.protocol import ( ChatCompletionStreamResponse, ChatMessage, DeltaMessage, + ErrorInfo, ErrorResponse, LogProbEntry, LogProbs, @@ -38,7 +39,7 @@ from fastdeploy.entrypoints.openai.protocol import ( ) from 
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index f770e9d6b..cdd8ff3f4 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -30,6 +30,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    ErrorInfo,
     ErrorResponse,
     LogProbEntry,
     LogProbs,
@@ -38,7 +39,7 @@ from fastdeploy.entrypoints.openai.protocol import (
 )
 from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
 from fastdeploy.worker.output import LogprobsLists
 
 
@@ -86,14 +87,16 @@ class OpenAIServingChat:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
 
         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )
 
         try:
             if self.max_waiting_time < 0:
@@ -117,11 +120,17 @@ class OpenAIServingChat:
                 text_after_process = current_req_dict.get("text_after_process")
                 if isinstance(prompt_token_ids, np.ndarray):
                     prompt_token_ids = prompt_token_ids.tolist()
+            except ParameterError as e:
+                api_server_logger.error(e.message)
+                self.engine_client.semaphore.release()
+                return ErrorResponse(
+                    error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
+                )
             except Exception as e:
                 error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
                 self.engine_client.semaphore.release()
-                return ErrorResponse(code=400, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR))
             del current_req_dict
 
             if request.stream:
@@ -136,21 +145,20 @@ class OpenAIServingChat:
                 except Exception as e:
                     error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
                     api_server_logger.error(error_msg)
-                    return ErrorResponse(code=408, message=error_msg)
+                    return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
         except Exception as e:
             error_msg = (
                 f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
+            )
 
     def _create_streaming_error_response(self, message: str) -> str:
         api_server_logger.error(message)
-        error_response = ErrorResponse(
-            code=400,
-            message=message,
-        )
+        error_response = ErrorResponse(error=ErrorInfo(message=message, type=ErrorType.SERVER_ERROR))
        return error_response.model_dump_json()
 
     async def chat_completion_stream_generator(
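Ordering matters in the new except chain: ParameterError has to be caught before the generic Exception fallback, otherwise every parameter failure would be reported without its param. A reduced sketch of that dispatch (the to_error_info helper is hypothetical):

```python
class ParameterError(Exception):
    def __init__(self, param: str, message: str):
        self.param, self.message = param, message
        super().__init__(message)


def to_error_info(exc: Exception) -> dict:
    # Specific exception first, generic fallback second, as in create_chat_completion.
    try:
        raise exc
    except ParameterError as e:
        return {"message": e.message, "type": "invalid_request_error", "param": e.param}
    except Exception as e:
        return {"message": str(e), "type": "invalid_request_error", "param": None}


print(to_error_info(ParameterError("max_tokens", "max_tokens can be defined [1, 8192).")))
# {'message': 'max_tokens can be defined [1, 8192).', 'type': 'invalid_request_error', 'param': 'max_tokens'}
```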
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index b3a97c426..5a0d21705 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -30,10 +30,11 @@ from fastdeploy.entrypoints.openai.protocol import (
     CompletionResponseChoice,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
+    ErrorInfo,
     ErrorResponse,
     UsageInfo,
 )
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
 from fastdeploy.worker.output import LogprobsLists
 
 
@@ -63,13 +64,15 @@ class OpenAIServingCompletion:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )
         created_time = int(time.time())
         if request.user is not None:
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
@@ -112,7 +115,7 @@ class OpenAIServingCompletion:
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
@@ -132,7 +135,9 @@ class OpenAIServingCompletion:
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, code=ErrorCode.TIMEOUT, type=ErrorType.TIMEOUT_ERROR)
+            )
 
         try:
             try:
@@ -146,11 +151,17 @@ class OpenAIServingCompletion:
                     text_after_process_list.append(current_req_dict.get("text_after_process"))
                     prompt_batched_token_ids.append(prompt_token_ids)
                 del current_req_dict
+            except ParameterError as e:
+                api_server_logger.error(e.message)
+                self.engine_client.semaphore.release()
+                return ErrorResponse(error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param))
            except Exception as e:
                 error_msg = f"OpenAIServingCompletion format error: {e}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
                 self.engine_client.semaphore.release()
-                return ErrorResponse(message=str(e), code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=str(e), code=ErrorCode.INVALID_VALUE, type=ErrorType.INVALID_REQUEST_ERROR)
+                )
 
             if request.stream:
                 return self.completion_stream_generator(
@@ -178,12 +189,12 @@ class OpenAIServingCompletion:
                     f"OpenAIServingCompletion completion_full_generator error: {e}, {str(traceback.format_exc())}"
                 )
                 api_server_logger.error(error_msg)
-                return ErrorResponse(code=400, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
     async def completion_full_generator(
         self,
diff --git a/fastdeploy/entrypoints/openai/serving_models.py b/fastdeploy/entrypoints/openai/serving_models.py
index 9493aa4f2..74f925947 100644
--- a/fastdeploy/entrypoints/openai/serving_models.py
+++ b/fastdeploy/entrypoints/openai/serving_models.py
@@ -18,12 +18,13 @@ from dataclasses import dataclass
 from typing import List, Union
 
 from fastdeploy.entrypoints.openai.protocol import (
+    ErrorInfo,
     ErrorResponse,
     ModelInfo,
     ModelList,
     ModelPermission,
 )
-from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.utils import ErrorType, api_server_logger, get_host_ip
 
 
 @dataclass
@@ -86,7 +87,7 @@ class OpenAIServingModels:
                 f"Only master node can accept models request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         model_infos = [
             ModelInfo(
                 id=model.name, max_model_len=self.max_model_len, root=model.model_path, permission=[ModelPermission()]
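Chat, completion, and model listing now build the same envelope for the same class of failure, and serialization goes through the usual pydantic path. A sketch with the two models re-declared locally (same fields as protocol.py above):

```python
from typing import Optional

from pydantic import BaseModel


class ErrorInfo(BaseModel):
    message: str
    type: Optional[str] = None
    param: Optional[str] = None
    code: Optional[str] = None


class ErrorResponse(BaseModel):
    error: ErrorInfo


resp = ErrorResponse(
    error=ErrorInfo(message="Unsupported model: [foo]", type="server_error", code="model_not_support")
)
print(resp.model_dump())
# {'error': {'message': 'Unsupported model: [foo]', 'type': 'server_error', 'param': None, 'code': 'model_not_support'}}
```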
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 0b5d74e7b..08c73aab9 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -28,6 +28,8 @@ import sys
 import tarfile
 import time
 from datetime import datetime
+from enum import Enum
+from http import HTTPStatus
 from importlib.metadata import PackageNotFoundError, distribution
 from logging.handlers import BaseRotatingHandler
 from pathlib import Path
@@ -38,10 +40,14 @@ import paddle
 import requests
 import yaml
 from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 from tqdm import tqdm
 from typing_extensions import TypeIs, assert_never
 
 from fastdeploy import envs
+from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from fastdeploy.logger.logger import FastDeployLogger
 
 T = TypeVar("T")
@@ -59,6 +65,61 @@ class EngineError(Exception):
         self.error_code = error_code
 
 
+class ParameterError(Exception):
+    def __init__(self, param: str, message: str):
+        self.param = param
+        self.message = message
+        super().__init__(message)
+
+
+class ExceptionHandler:
+
+    # Global fallback for otherwise-unhandled exceptions
+    @staticmethod
+    async def handle_exception(request: Request, exc: Exception) -> JSONResponse:
+        error = ErrorResponse(error=ErrorInfo(message=str(exc), type=ErrorType.INTERNAL_ERROR))
+        return JSONResponse(content=error.model_dump(), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    # Handle request-parameter validation failures
+    @staticmethod
+    async def handle_request_validation_exception(_: Request, exc: RequestValidationError) -> JSONResponse:
+        errors = exc.errors()
+        if not errors:
+            message = str(exc)
+            param = None
+        else:
+            first_error = errors[0]
+            loc = first_error.get("loc", [])
+            param = str(loc[-1]) if loc else None
+            message = first_error.get("msg", str(exc))
+        err = ErrorResponse(
+            error=ErrorInfo(
+                message=message,
+                type=ErrorType.INVALID_REQUEST_ERROR,
+                code=ErrorCode.MISSING_REQUIRED_PARAMETER if param == "messages" else ErrorCode.INVALID_VALUE,
+                param=param,
+            )
+        )
+        return JSONResponse(content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+class ErrorType(str, Enum):
+    INVALID_REQUEST_ERROR = "invalid_request_error"
+    TIMEOUT_ERROR = "timeout_error"
+    SERVER_ERROR = "server_error"
+    INTERNAL_ERROR = "internal_error"
+    API_CONNECTION_ERROR = "api_connection_error"
+
+
+class ErrorCode(str, Enum):
+    INVALID_VALUE = "invalid_value"
+    CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
+    MODEL_NOT_SUPPORT = "model_not_support"
+    TIMEOUT = "timeout"
+    CONNECTION_ERROR = "connection_error"
+    MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
+
+
 class ColoredFormatter(logging.Formatter):
     """自定义日志格式器,用于控制台输出带颜色"""
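ErrorType and ErrorCode subclass str, so members compare equal to their literal values and serialize as plain strings; this is what lets the JSON assertions in the tests below match literals like "timeout_error" without calling .value. For instance:

```python
import json
from enum import Enum


class ErrorType(str, Enum):  # same pattern as fastdeploy.utils.ErrorType
    TIMEOUT_ERROR = "timeout_error"


assert ErrorType.TIMEOUT_ERROR == "timeout_error"  # str subclass: literal comparison holds
print(json.dumps({"type": ErrorType.TIMEOUT_ERROR}))  # {"type": "timeout_error"}
```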
diff --git a/tests/ce/server/test_evil_cases.py b/tests/ce/server/test_evil_cases.py
index 18c445f4b..4f89874bc 100644
--- a/tests/ce/server/test_evil_cases.py
+++ b/tests/ce/server/test_evil_cases.py
@@ -20,9 +20,14 @@ def test_missing_messages_field():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
 
-    assert "detail" in resp, "返回中未包含 detail 错误信息字段"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段缺失的报错"
-    assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "未检测到 'Field required' 错误提示"
+    assert "error" in resp, "response does not contain the error field"
+    error = resp["error"]
+    assert "Field required" in error.get("message", ""), "missing-messages error not reported"
+    assert error.get("code") == "missing_required_parameter", "unexpected code field"
+
+    # assert "detail" in resp, "返回中未包含 detail 错误信息字段"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段缺失的报错"
+    # assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "未检测到 'Field required' 错误提示"
 
 
 def test_malformed_messages_format():
@@ -34,11 +39,15 @@ def test_malformed_messages_format():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert "detail" in resp, "非法结构未被识别"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段结构错误"
-    assert any(
-        "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
-    ), "未检测到 'Input should be a valid list' 错误提示"
+    assert "error" in resp, "malformed structure was not rejected"
+    err = resp["error"]
+    assert err.get("param") == "list[any]", f"unexpected param field: {err.get('param')}"
+    assert err.get("message") == "Input should be a valid list", "unexpected error message"
+    # assert "detail" in resp, "非法结构未被识别"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "未检测到 messages 字段结构错误"
+    # assert any(
+    #     "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
+    # ), "未检测到 'Input should be a valid list' 错误提示"
 
 
 def test_extremely_large_max_tokens():
@@ -79,8 +88,13 @@ def test_top_p_exceed_1():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "top_p > 1 应触发校验异常"
-    assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "未返回预期的 top_p 错误信息"
+    assert "error" in resp, "no error field returned"
+
+    err = resp["error"]
+    assert err.get("param") == "top_p", f"unexpected param field: {err.get('param')}"
+    assert err.get("message") == "Input should be less than or equal to 1", "unexpected error message"
+    # assert resp.get("detail").get("object") == "error", "top_p > 1 应触发校验异常"
+    # assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "未返回预期的 top_p 错误信息"
 
 
 def test_mixed_valid_invalid_fields():
@@ -106,8 +120,8 @@ def test_stop_seq_exceed_num():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "stop 超出个数应触发异常"
-    assert "exceeds the limit max_stop_seqs_num" in resp.get("detail").get("message", ""), "未返回预期的报错信息"
+    assert resp.get("error").get("type") == "invalid_request_error", "exceeding the stop sequence count should raise an error"
+    assert "exceeds the limit max_stop_seqs_num" in resp.get("error").get("message", ""), "expected error message not returned"
 
 
 def test_stop_seq_exceed_length():
@@ -120,8 +134,8 @@ def test_stop_seq_exceed_length():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "stop 超出长度应触发异常"
-    assert "exceeds the limit stop_seqs_max_len" in resp.get("detail").get("message", ""), "未返回预期的报错信息"
+    assert resp.get("error").get("type") == "invalid_request_error", "exceeding the stop sequence length should raise an error"
+    assert "exceeds the limit stop_seqs_max_len" in resp.get("error").get("message", ""), "expected error message not returned"
 
 
 def test_multilingual_input():
@@ -154,8 +168,8 @@ def test_too_long_input():
     data = {"messages": [{"role": "user", "content": "a," * 200000}], "stream": False}  # 超过最大输入长度
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp["detail"].get("object") == "error", "超长输入未被识别为错误"
-    assert "Input text is too long" in resp["detail"].get("message", ""), "未检测到最大长度限制错误"
+    # assert resp["detail"].get("object") == "error", "超长输入未被识别为错误"
+    assert "Input text is too long" in resp["error"].get("message", ""), "maximum input length error not detected"
 
 
 def test_empty_input():
@@ -361,8 +375,8 @@ def test_max_tokens_negative():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "max_tokens < 0 未触发校验异常"
-    assert "max_tokens can be defined [1," in resp.get("detail").get("message"), "未返回预期的 max_tokens 错误信息"
+    # assert resp.get("detail").get("object") == "error", "max_tokens < 0 未触发校验异常"
+    assert "max_tokens can be defined [1," in resp.get("error").get("message"), "expected max_tokens error message not returned"
 
 
 def test_max_tokens_min():
@@ -379,7 +393,10 @@ def test_max_tokens_min():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "max_tokens未0时API未拦截住"
+    # assert resp.get("detail").get("object") == "error", "max_tokens未0时API未拦截住"
+    assert "max_tokens can be defined [1," in resp.get("error").get(
+        "message"
+    ), "API did not reject max_tokens == 0; expected max_tokens error message missing"
 
 
 def test_max_tokens_non_integer():
@@ -397,5 +414,5 @@ def test_max_tokens_non_integer():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
     assert (
-        resp.get("detail")[0].get("msg") == "Input should be a valid integer, got a number with a fractional part"
+        resp.get("error").get("message") == "Input should be a valid integer, got a number with a fractional part"
     ), "未返回预期的 max_tokens 为非整数的错误信息"
diff --git a/tests/entrypoints/openai/test_chatcompletion_request.py b/tests/entrypoints/openai/test_chatcompletion_request.py
new file mode 100644
index 000000000..55aaf1944
--- /dev/null
+++ b/tests/entrypoints/openai/test_chatcompletion_request.py
@@ -0,0 +1,168 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from pydantic import ValidationError
+
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+)
+
+
+class TestChatCompletionRequest(unittest.TestCase):
+
+    def test_required_messages(self):
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest()
+
+    def test_messages_accepts_list_of_any_and_int(self):
+        req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
+        self.assertEqual(req.messages[0]["role"], "user")
+
+        req = ChatCompletionRequest(messages=[1, 2, 3])
+        self.assertEqual(req.messages, [1, 2, 3])
+
+    def test_default_values(self):
+        req = ChatCompletionRequest(messages=[1])
+        self.assertEqual(req.model, "default")
+        self.assertFalse(req.logprobs)
+        self.assertEqual(req.top_logprobs, 0)
+        self.assertEqual(req.n, 1)
+        self.assertEqual(req.stop, [])
+
+    def test_boundary_values(self):
+        valid_cases = [
+            ("frequency_penalty", -2),
+            ("frequency_penalty", 2),
+            ("presence_penalty", -2),
+            ("presence_penalty", 2),
+            ("temperature", 0),
+            ("top_p", 1),
+            ("seed", 0),
+            ("seed", 922337203685477580),
+        ]
+        for field, value in valid_cases:
+            with self.subTest(field=field, value=value):
+                req = ChatCompletionRequest(messages=[1], **{field: value})
+                self.assertEqual(getattr(req, field), value)
+
+    def test_invalid_boundary_values(self):
+        invalid_cases = [
+            ("frequency_penalty", -3),
+            ("frequency_penalty", 3),
+            ("presence_penalty", -3),
+            ("presence_penalty", 3),
+            ("temperature", -1),
+            ("top_p", 1.1),
+            ("seed", -1),
+            ("seed", 922337203685477581),
+        ]
+        for field, value in invalid_cases:
+            with self.subTest(field=field, value=value):
+                with self.assertRaises(ValidationError):
+                    ChatCompletionRequest(messages=[1], **{field: value})
+
+    def test_stop_field_accepts_str_or_list(self):
+        req1 = ChatCompletionRequest(messages=[1], stop="end")
+        self.assertEqual(req1.stop, "end")
+
+        req2 = ChatCompletionRequest(messages=[1], stop=["a", "b"])
+        self.assertEqual(req2.stop, ["a", "b"])
+
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest(messages=[1], stop=123)
+
+    def test_deprecated_max_tokens_field(self):
+        req = ChatCompletionRequest(messages=[1], max_tokens=10)
+        self.assertEqual(req.max_tokens, 10)
+
+    def test_field_names_snapshot(self):
+        expected_fields = {"messages", "model", "n", "stop", "stream", "seed", "temperature", "top_p", "frequency_penalty", "presence_penalty"}
+        self.assertTrue(expected_fields.issubset(ChatCompletionRequest.__fields__.keys()))
+""" + +import unittest + +from pydantic import ValidationError + +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, +) + + +class TestChatCompletionRequest(unittest.TestCase): + + def test_required_messages(self): + with self.assertRaises(ValidationError): + ChatCompletionRequest() + + def test_messages_accepts_list_of_any_and_int(self): + req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}]) + self.assertEqual(req.messages[0]["role"], "user") + + req = ChatCompletionRequest(messages=[1, 2, 3]) + self.assertEqual(req.messages, [1, 2, 3]) + + def test_default_values(self): + req = ChatCompletionRequest(messages=[1]) + self.assertEqual(req.model, "default") + self.assertFalse(req.logprobs) + self.assertEqual(req.top_logprobs, 0) + self.assertEqual(req.n, 1) + self.assertEqual(req.stop, []) + + def test_boundary_values(self): + valid_cases = [ + ("frequency_penalty", -2), + ("frequency_penalty", 2), + ("presence_penalty", -2), + ("presence_penalty", 2), + ("temperature", 0), + ("top_p", 1), + ("seed", 0), + ("seed", 922337203685477580), + ] + for field, value in valid_cases: + with self.subTest(field=field, value=value): + req = ChatCompletionRequest(messages=[1], **{field: value}) + self.assertEqual(getattr(req, field), value) + + def test_invalid_boundary_values(self): + invalid_cases = [ + ("frequency_penalty", -3), + ("frequency_penalty", 3), + ("presence_penalty", -3), + ("presence_penalty", 3), + ("temperature", -1), + ("top_p", 1.1), + ("seed", -1), + ("seed", 922337203685477581), + ] + for field, value in invalid_cases: + with self.subTest(field=field, value=value): + with self.assertRaises(ValidationError): + ChatCompletionRequest(messages=[1], **{field: value}) + + def test_stop_field_accepts_str_or_list(self): + req1 = ChatCompletionRequest(messages=[1], stop="end") + self.assertEqual(req1.stop, "end") + + req2 = ChatCompletionRequest(messages=[1], stop=["a", "b"]) + self.assertEqual(req2.stop, ["a", "b"]) + + with self.assertRaises(ValidationError): + ChatCompletionRequest(messages=[1], stop=123) + + def test_deprecated_max_tokens_field(self): + req = ChatCompletionRequest(messages=[1], max_tokens=10) + self.assertEqual(req.max_tokens, 10) + + def test_field_names_snapshot(self): + expected_fields = set(ChatCompletionRequest.__fields__.keys()) + self.assertEqual(set(ChatCompletionRequest.__fields__.keys()), expected_fields) + + +class TestCompletionRequest(unittest.TestCase): + + def test_required_prompt(self): + with self.assertRaises(ValidationError): + CompletionRequest() + + def test_prompt_accepts_various_types(self): + # str + req = CompletionRequest(prompt="hello") + self.assertEqual(req.prompt, "hello") + + # list of str + req = CompletionRequest(prompt=["hello", "world"]) + self.assertEqual(req.prompt, ["hello", "world"]) + + # list of int + req = CompletionRequest(prompt=[1, 2, 3]) + self.assertEqual(req.prompt, [1, 2, 3]) + + # list of list of int + req = CompletionRequest(prompt=[[1, 2], [3, 4]]) + self.assertEqual(req.prompt, [[1, 2], [3, 4]]) + + def test_default_values(self): + req = CompletionRequest(prompt="test") + self.assertEqual(req.model, "default") + self.assertEqual(req.echo, False) + self.assertEqual(req.temp_scaled_logprobs, False) + self.assertEqual(req.top_p_normalized_logprobs, False) + self.assertEqual(req.n, 1) + self.assertEqual(req.stop, []) + self.assertEqual(req.stream, False) + + def test_boundary_values(self): + valid_cases = [ + ("frequency_penalty", -2), + 
("frequency_penalty", 2), + ("presence_penalty", -2), + ("presence_penalty", 2), + ("temperature", 0), + ("top_p", 0), + ("top_p", 1), + ("seed", 0), + ("seed", 922337203685477580), + ] + for field, value in valid_cases: + with self.subTest(field=field, value=value): + req = CompletionRequest(prompt="hi", **{field: value}) + self.assertEqual(getattr(req, field), value) + + def test_invalid_boundary_values(self): + invalid_cases = [ + ("frequency_penalty", -3), + ("frequency_penalty", 3), + ("presence_penalty", -3), + ("presence_penalty", 3), + ("temperature", -0.1), + ("top_p", -0.1), + ("top_p", 1.1), + ("seed", -1), + ("seed", 922337203685477581), + ] + for field, value in invalid_cases: + with self.subTest(field=field, value=value): + with self.assertRaises(ValidationError): + CompletionRequest(prompt="hi", **{field: value}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/openai/test_error_response.py b/tests/entrypoints/openai/test_error_response.py new file mode 100644 index 000000000..1d00495e8 --- /dev/null +++ b/tests/entrypoints/openai/test_error_response.py @@ -0,0 +1,30 @@ +import unittest + +from pydantic import ValidationError + +from fastdeploy.entrypoints.openai.protocol import ErrorResponse + + +class TestErrorResponse(unittest.TestCase): + def test_valid_error_response(self): + data = { + "error": { + "message": "Invalid top_p value", + "type": "invalid_request_error", + "param": "top_p", + "code": "invalid_value", + } + } + err_resp = ErrorResponse(**data) + self.assertEqual(err_resp.error.message, "Invalid top_p value") + self.assertEqual(err_resp.error.param, "top_p") + self.assertEqual(err_resp.error.code, "invalid_value") + + def test_missing_message_field(self): + data = {"error": {"type": "invalid_request_error", "param": "messages", "code": "missing_required_parameter"}} + with self.assertRaises(ValidationError): + ErrorResponse(**data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_exception_handler.py b/tests/utils/test_exception_handler.py new file mode 100644 index 000000000..c5e2d0855 --- /dev/null +++ b/tests/utils/test_exception_handler.py @@ -0,0 +1,54 @@ +import json +import unittest +from http import HTTPStatus + +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +from fastdeploy.utils import ErrorCode, ExceptionHandler, ParameterError + + +class TestParameterError(unittest.TestCase): + def test_parameter_error_init(self): + exc = ParameterError("param1", "error message") + self.assertEqual(exc.param, "param1") + self.assertEqual(exc.message, "error message") + self.assertEqual(str(exc), "error message") + + +class TestExceptionHandler(unittest.IsolatedAsyncioTestCase): + + async def test_handle_exception(self): + """普通异常应返回 500 + internal_error""" + exc = RuntimeError("Something went wrong") + resp: JSONResponse = await ExceptionHandler.handle_exception(None, exc) + body = json.loads(resp.body.decode()) + self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + self.assertEqual(body["error"]["type"], "internal_error") + self.assertIn("Something went wrong", body["error"]["message"]) + + async def test_handle_request_validation_missing_messages(self): + """缺少 messages 参数时,应返回 missing_required_parameter""" + exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}]) + resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc) + data = 
diff --git a/tests/utils/test_exception_handler.py b/tests/utils/test_exception_handler.py
new file mode 100644
index 000000000..c5e2d0855
--- /dev/null
+++ b/tests/utils/test_exception_handler.py
@@ -0,0 +1,54 @@
+import json
+import unittest
+from http import HTTPStatus
+
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+
+from fastdeploy.utils import ErrorCode, ExceptionHandler, ParameterError
+
+
+class TestParameterError(unittest.TestCase):
+    def test_parameter_error_init(self):
+        exc = ParameterError("param1", "error message")
+        self.assertEqual(exc.param, "param1")
+        self.assertEqual(exc.message, "error message")
+        self.assertEqual(str(exc), "error message")
+
+
+class TestExceptionHandler(unittest.IsolatedAsyncioTestCase):
+
+    async def test_handle_exception(self):
+        """A generic exception should map to 500 + internal_error"""
+        exc = RuntimeError("Something went wrong")
+        resp: JSONResponse = await ExceptionHandler.handle_exception(None, exc)
+        body = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+        self.assertEqual(body["error"]["type"], "internal_error")
+        self.assertIn("Something went wrong", body["error"]["message"])
+
+    async def test_handle_request_validation_missing_messages(self):
+        """A missing messages parameter should map to missing_required_parameter"""
+        exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}])
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "messages")
+        self.assertEqual(data["error"]["code"], ErrorCode.MISSING_REQUIRED_PARAMETER)
+        self.assertIn("Field required", data["error"]["message"])
+
+    async def test_handle_request_validation_invalid_value(self):
+        """An out-of-range parameter should map to invalid_value"""
+        exc = RequestValidationError(
+            [{"loc": ("body", "top_p"), "msg": "Input should be less than or equal to 1", "type": "value_error"}]
+        )
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "top_p")
+        self.assertEqual(data["error"]["code"], ErrorCode.INVALID_VALUE)
+        self.assertIn("less than or equal to 1", data["error"]["message"])
+
+
+if __name__ == "__main__":
+    unittest.main()
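Taken together, the pieces wire up as below; a freestanding sketch (toy app, model, and handler, not the FastDeploy server) showing a constraint violation coming back through a validation handler in the new envelope:

```python
from http import HTTPStatus
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient  # requires httpx
from pydantic import BaseModel, Field

app = FastAPI()


class ChatReq(BaseModel):  # toy stand-in for ChatCompletionRequest
    messages: list
    top_p: Optional[float] = Field(default=None, ge=0, le=1)


async def handle_validation(_: Request, exc: RequestValidationError) -> JSONResponse:
    # Same mapping as ExceptionHandler.handle_request_validation_exception above.
    first = exc.errors()[0]
    loc = first.get("loc", [])
    body = {
        "error": {
            "message": first.get("msg", str(exc)),
            "type": "invalid_request_error",
            "param": str(loc[-1]) if loc else None,
            "code": "missing_required_parameter" if loc and loc[-1] == "messages" else "invalid_value",
        }
    }
    return JSONResponse(content=body, status_code=HTTPStatus.BAD_REQUEST)


app.add_exception_handler(RequestValidationError, handle_validation)


@app.post("/v1/chat/completions")
async def chat(req: ChatReq):
    return {"ok": True}


client = TestClient(app)
resp = client.post("/v1/chat/completions", json={"messages": [1], "top_p": 1.5})
print(resp.status_code, resp.json())
# 400 {'error': {'message': 'Input should be less than or equal to 1',
#               'type': 'invalid_request_error', 'param': 'top_p', 'code': 'invalid_value'}}
```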