Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00

[format] Valid para format error info (#4035)

* feat(log): add_request_and_response_log
* Align error messages with the OpenAI API
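For orientation before the diff: this change converges every failure path on the OpenAI-style error envelope. A sketch of the JSON body a client now receives for a missing required field, reconstructed from the ErrorResponse/ErrorInfo models and tests further down (values illustrative):

{
    "error": {
        "message": "Field required",
        "type": "invalid_request_error",
        "param": "messages",
        "code": "missing_required_parameter"
    }
}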
@@ -31,7 +31,12 @@ from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
+from fastdeploy.utils import (
+    EngineError,
+    ParameterError,
+    StatefulSemaphore,
+    api_server_logger,
+)
 
 
 class EngineClient:
@@ -218,42 +223,21 @@ class EngineClient:
     def valid_parameters(self, data):
         """
         Validate stream options
+        Validation logic for the hyperparameters (top_p, seed, frequency_penalty, temperature, presence_penalty)
+        has been moved up into ChatCompletionRequest/CompletionRequest
         """
 
         if data.get("n") is not None:
             if data["n"] != 1:
-                raise ValueError("n only support 1.")
+                raise ParameterError("n", "n only support 1.")
 
         if data.get("max_tokens") is not None:
             if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
-                raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
+                raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")
 
         if data.get("reasoning_max_tokens") is not None:
             if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
-                raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
-
-        if data.get("top_p") is not None:
-            if data["top_p"] > 1 or data["top_p"] < 0:
-                raise ValueError("top_p value can only be defined [0, 1].")
-
-        if data.get("frequency_penalty") is not None:
-            if not -2.0 <= data["frequency_penalty"] <= 2.0:
-                raise ValueError("frequency_penalty must be in [-2, 2]")
-
-        if data.get("temperature") is not None:
-            if data["temperature"] < 0:
-                raise ValueError("temperature must be non-negative")
-
-        if data.get("presence_penalty") is not None:
-            if not -2.0 <= data["presence_penalty"] <= 2.0:
-                raise ValueError("presence_penalty must be in [-2, 2]")
-
-        if data.get("seed") is not None:
-            if not 0 <= data["seed"] <= 922337203685477580:
-                raise ValueError("seed must be in [0, 922337203685477580]")
-
-        if data.get("stream_options") and not data.get("stream"):
-            raise ValueError("Stream options can only be defined when `stream=True`.")
+                raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be between max_tokens and 1")
 
         # logprobs
         logprobs = data.get("logprobs")
@@ -263,35 +247,35 @@ class EngineClient:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)
             top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
-            raise ValueError("Invalid type for 'logprobs'")
+            raise ParameterError("logprobs", "Invalid type for 'logprobs'")
 
         # enable_logprob
         if top_logprobs:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)
 
             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
                 err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
             if top_logprobs < 0:
                 err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
             if top_logprobs > 20:
                 err_msg = "Invalid value for 'top_logprobs': must be <= 20."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)
 
     def check_health(self, time_interval_threashold=30):
         """
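That closes out the EngineClient changes. The practical effect: validation failures now carry the offending parameter name instead of surfacing as a bare ValueError. A minimal sketch of what callers see (the engine_client object here is hypothetical; ParameterError is added to fastdeploy/utils.py later in this diff):

# Hypothetical usage sketch, not part of the commit.
from fastdeploy.utils import ParameterError

try:
    engine_client.valid_parameters({"n": 2})  # n != 1 is rejected
except ParameterError as e:
    print(e.param)    # "n"
    print(e.message)  # "n only support 1."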
@@ -26,6 +26,7 @@ from multiprocessing import current_process
 import uvicorn
 import zmq
 from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
 
@@ -40,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
     ControlSchedulerRequest,
+    ErrorInfo,
     ErrorResponse,
     ModelList,
 )
@@ -56,6 +58,7 @@ from fastdeploy.metrics.metrics import (
 )
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
+    ExceptionHandler,
     FlexibleArgumentParser,
     StatefulSemaphore,
     api_server_logger,
@@ -232,6 +235,8 @@ async def lifespan(app: FastAPI):
 
 
 app = FastAPI(lifespan=lifespan)
+app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
+app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
 instrument(app)
 
 
@@ -336,7 +341,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         if isinstance(generator, ErrorResponse):
             api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
-            return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            return JSONResponse(content=generator.model_dump(), status_code=500)
         elif isinstance(generator, ChatCompletionResponse):
             api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
@@ -365,7 +370,7 @@ async def create_completion(request: CompletionRequest):
         generator = await app.state.completion_handler.create_completion(request)
         if isinstance(generator, ErrorResponse):
             connection_semaphore.release()
-            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            return JSONResponse(content=generator.model_dump(), status_code=500)
         elif isinstance(generator, CompletionResponse):
             connection_semaphore.release()
             return JSONResponse(content=generator.model_dump())
@@ -388,7 +393,7 @@ async def list_models() -> Response:
 
     models = await app.state.model_handler.list_models()
     if isinstance(models, ErrorResponse):
-        return JSONResponse(content=models.model_dump(), status_code=models.code)
+        return JSONResponse(content=models.model_dump())
    elif isinstance(models, ModelList):
         return JSONResponse(content=models.model_dump())
 
@@ -502,7 +507,8 @@ def control_scheduler(request: ControlSchedulerRequest):
     """
     Control the scheduler behavior with the given parameters.
     """
-    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
+
+    content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0))
 
     global llm_engine
     if llm_engine is None:
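That is all of the api_server changes; the two add_exception_handler calls are what route Pydantic validation failures and uncaught exceptions through the new formatter. A self-contained sketch of the same registration pattern, with simplified handler bodies (this is not the FastDeploy code, just the FastAPI mechanism it relies on):

# Minimal standalone sketch of the exception-handler registration pattern.
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

app = FastAPI()

async def on_validation_error(request: Request, exc: RequestValidationError) -> JSONResponse:
    # Report the first Pydantic error in an OpenAI-style envelope.
    first = exc.errors()[0] if exc.errors() else {}
    return JSONResponse(
        status_code=400,
        content={"error": {"message": first.get("msg", str(exc)), "type": "invalid_request_error"}},
    )

async def on_any_error(request: Request, exc: Exception) -> JSONResponse:
    # Fallback for anything the route handlers did not catch.
    return JSONResponse(status_code=500, content={"error": {"message": str(exc), "type": "internal_error"}})

app.add_exception_handler(RequestValidationError, on_validation_error)
app.add_exception_handler(Exception, on_any_error)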
@@ -32,9 +32,14 @@ class ErrorResponse(BaseModel):
     Error response from OpenAI API.
     """
 
-    object: str = "error"
-    message: str
-    code: int
+    error: ErrorInfo
+
+
+class ErrorInfo(BaseModel):
+    message: str
+    type: Optional[str] = None
+    param: Optional[str] = None
+    code: Optional[str] = None
 
 
 class PromptTokenUsageInfo(BaseModel):
@@ -403,21 +408,21 @@ class CompletionRequest(BaseModel):
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
-    frequency_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
     logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
     max_tokens: Optional[int] = None
-    n: int = 1
-    presence_penalty: Optional[float] = None
-    seed: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
+    seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[dict] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=None, ge=0)
+    top_p: Optional[float] = Field(default=None, ge=0, le=1)
     user: Optional[str] = None
 
     # doc: begin-completion-sampling-params
@@ -537,7 +542,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[List[Any], List[int]]
     tools: Optional[List[ChatCompletionToolsParam]] = None
     model: Optional[str] = "default"
-    frequency_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
 
@@ -552,13 +557,13 @@ class ChatCompletionRequest(BaseModel):
     )
     max_completion_tokens: Optional[int] = None
     n: Optional[int] = 1
-    presence_penalty: Optional[float] = None
-    seed: Optional[int] = None
+    presence_penalty: Optional[float] = Field(None, le=2, ge=-2)
+    seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(None, ge=0)
+    top_p: Optional[float] = Field(None, le=1, ge=0)
     user: Optional[str] = None
     metadata: Optional[dict] = None
     response_format: Optional[AnyResponseFormat] = None
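With those Field(ge=..., le=...) bounds in place, out-of-range sampling parameters are rejected by Pydantic at parse time, before any handler code runs. A small sketch of what that buys (assumes pydantic v2; the model names come from this diff):

# Sketch: Pydantic now enforces the bounds declared on the request models.
from pydantic import ValidationError

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
except ValidationError as e:
    # The first error carries the field name and a message such as
    # "Input should be less than or equal to 1".
    print(e.errors()[0]["loc"], e.errors()[0]["msg"])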
@@ -30,6 +30,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    ErrorInfo,
     ErrorResponse,
     LogProbEntry,
     LogProbs,
@@ -38,7 +39,7 @@ from fastdeploy.entrypoints.openai.protocol import (
 )
 from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
 from fastdeploy.worker.output import LogprobsLists
 
 
@@ -86,14 +87,16 @@ class OpenAIServingChat:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
 
         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )
 
         try:
             if self.max_waiting_time < 0:
@@ -117,11 +120,17 @@ class OpenAIServingChat:
             text_after_process = current_req_dict.get("text_after_process")
             if isinstance(prompt_token_ids, np.ndarray):
                 prompt_token_ids = prompt_token_ids.tolist()
+        except ParameterError as e:
+            api_server_logger.error(e.message)
+            self.engine_client.semaphore.release()
+            return ErrorResponse(
+                error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
+            )
         except Exception as e:
             error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
             self.engine_client.semaphore.release()
-            return ErrorResponse(code=400, message=error_msg)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR))
         del current_req_dict
 
         if request.stream:
@@ -136,21 +145,20 @@ class OpenAIServingChat:
             except Exception as e:
                 error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
-                return ErrorResponse(code=408, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
         except Exception as e:
             error_msg = (
                 f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
+            )
 
     def _create_streaming_error_response(self, message: str) -> str:
         api_server_logger.error(message)
-        error_response = ErrorResponse(
-            code=400,
-            message=message,
-        )
+        error_response = ErrorResponse(error=ErrorInfo(message=message, type=ErrorType.SERVER_ERROR))
         return error_response.model_dump_json()
 
     async def chat_completion_stream_generator(
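That completes the serving_chat changes: every failure path now returns the same envelope, distinguished by a machine-readable type and, for timeouts, a code. A rough sketch of what the timeout branch above produces (the request id and message are illustrative; the serialized field order follows the ErrorInfo model):

# Illustrative sketch, not part of the commit: serializing the timeout envelope.
from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from fastdeploy.utils import ErrorCode, ErrorType

resp = ErrorResponse(
    error=ErrorInfo(message="request[abc] waiting error: ...", type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
)
print(resp.model_dump_json())
# expected roughly:
# {"error":{"message":"request[abc] waiting error: ...","type":"timeout_error","param":null,"code":"timeout"}}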
@@ -30,10 +30,11 @@ from fastdeploy.entrypoints.openai.protocol import (
     CompletionResponseChoice,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
+    ErrorInfo,
     ErrorResponse,
     UsageInfo,
 )
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
 from fastdeploy.worker.output import LogprobsLists
 
 
@@ -63,13 +64,15 @@ class OpenAIServingCompletion:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )
         created_time = int(time.time())
         if request.user is not None:
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
@@ -112,7 +115,7 @@ class OpenAIServingCompletion:
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
@@ -132,7 +135,9 @@ class OpenAIServingCompletion:
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, code=ErrorCode.TIMEOUT, type=ErrorType.TIMEOUT_ERROR)
+            )
 
         try:
             try:
@@ -146,11 +151,17 @@ class OpenAIServingCompletion:
                     text_after_process_list.append(current_req_dict.get("text_after_process"))
                     prompt_batched_token_ids.append(prompt_token_ids)
                 del current_req_dict
+            except ParameterError as e:
+                api_server_logger.error(e.message)
+                self.engine_client.semaphore.release()
+                return ErrorResponse(code=400, message=str(e.message), type="invalid_request", param=e.param)
             except Exception as e:
                 error_msg = f"OpenAIServingCompletion format error: {e}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
                 self.engine_client.semaphore.release()
-                return ErrorResponse(message=str(e), code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=str(e), code=ErrorCode.INVALID_VALUE, type=ErrorType.INVALID_REQUEST_ERROR)
+                )
 
             if request.stream:
                 return self.completion_stream_generator(
@@ -178,12 +189,12 @@ class OpenAIServingCompletion:
                     f"OpenAIServingCompletion completion_full_generator error: {e}, {str(traceback.format_exc())}"
                 )
                 api_server_logger.error(error_msg)
-                return ErrorResponse(code=400, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
 
     async def completion_full_generator(
         self,
@@ -18,12 +18,13 @@ from dataclasses import dataclass
 from typing import List, Union
 
 from fastdeploy.entrypoints.openai.protocol import (
+    ErrorInfo,
     ErrorResponse,
     ModelInfo,
     ModelList,
     ModelPermission,
 )
-from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.utils import ErrorType, api_server_logger, get_host_ip
 
 
 @dataclass
@@ -86,7 +87,7 @@ class OpenAIServingModels:
                 f"Only master node can accept models request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         model_infos = [
             ModelInfo(
                 id=model.name, max_model_len=self.max_model_len, root=model.model_path, permission=[ModelPermission()]
@@ -28,6 +28,8 @@ import sys
 import tarfile
 import time
 from datetime import datetime
+from enum import Enum
+from http import HTTPStatus
 from importlib.metadata import PackageNotFoundError, distribution
 from logging.handlers import BaseRotatingHandler
 from pathlib import Path
@@ -38,10 +40,14 @@ import paddle
 import requests
 import yaml
 from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 from tqdm import tqdm
 from typing_extensions import TypeIs, assert_never
 
 from fastdeploy import envs
+from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from fastdeploy.logger.logger import FastDeployLogger
 
 T = TypeVar("T")
@@ -59,6 +65,61 @@ class EngineError(Exception):
         self.error_code = error_code
 
 
+class ParameterError(Exception):
+    def __init__(self, param: str, message: str):
+        self.param = param
+        self.message = message
+        super().__init__(message)
+
+
+class ExceptionHandler:
+
+    # Global fallback exception handler
+    @staticmethod
+    async def handle_exception(request: Request, exc: Exception) -> JSONResponse:
+        error = ErrorResponse(error=ErrorInfo(message=str(exc), type=ErrorType.INTERNAL_ERROR))
+        return JSONResponse(content=error.model_dump(), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    # Handle request-parameter validation errors
+    @staticmethod
+    async def handle_request_validation_exception(_: Request, exc: RequestValidationError) -> JSONResponse:
+        errors = exc.errors()
+        if not errors:
+            message = str(exc)
+            param = None
+        else:
+            first_error = errors[0]
+            loc = first_error.get("loc", [])
+            param = loc[-1] if loc else None
+            message = first_error.get("msg", str(exc))
+        err = ErrorResponse(
+            error=ErrorInfo(
+                message=message,
+                type=ErrorType.INVALID_REQUEST_ERROR,
+                code=ErrorCode.MISSING_REQUIRED_PARAMETER if param == "messages" else ErrorCode.INVALID_VALUE,
+                param=param,
+            )
+        )
+        return JSONResponse(content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+class ErrorType(str, Enum):
+    INVALID_REQUEST_ERROR = "invalid_request_error"
+    TIMEOUT_ERROR = "timeout_error"
+    SERVER_ERROR = "server_error"
+    INTERNAL_ERROR = "internal_error"
+    API_CONNECTION_ERROR = "api_connection_error"
+
+
+class ErrorCode(str, Enum):
+    INVALID_VALUE = "invalid_value"
+    CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
+    MODEL_NOT_SUPPORT = "model_not_support"
+    TIMEOUT = "timeout"
+    CONNECTION_ERROR = "connection_error"
+    MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
+
+
 class ColoredFormatter(logging.Formatter):
     """Custom log formatter for colored console output"""
 
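Taken together, ParameterError plus ExceptionHandler give the server one place where request problems become OpenAI-style bodies. A quick sketch of exercising the validation handler directly, mirroring the new unit tests further down (the synchronous asyncio.run wrapper is just for illustration):

# Sketch: invoking the new validation handler outside the server.
import asyncio

from fastapi.exceptions import RequestValidationError

from fastdeploy.utils import ExceptionHandler

exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}])
resp = asyncio.run(ExceptionHandler.handle_request_validation_exception(None, exc))
print(resp.status_code)  # 400
# Body: {"error": {"message": "Field required", "type": "invalid_request_error",
#                  "param": "messages", "code": "missing_required_parameter"}}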
@@ -20,9 +20,14 @@ def test_missing_messages_field():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
 
-    assert "detail" in resp, "response does not contain the detail error field"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "missing messages field error not detected"
-    assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "'Field required' error message not detected"
+    assert "error" in resp, "response does not contain the error field"
+    error = resp["error"]
+    assert "Field required" in error.get("message", ""), "missing messages field error not detected"
+    assert error.get("code") == "missing_required_parameter", "incorrect code field"
+
+    # assert "detail" in resp, "response does not contain the detail error field"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "missing messages field error not detected"
+    # assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "'Field required' error message not detected"
 
 
 def test_malformed_messages_format():
@@ -34,11 +39,15 @@ def test_malformed_messages_format():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert "detail" in resp, "malformed structure was not rejected"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "messages structure error not detected"
-    assert any(
-        "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
-    ), "'Input should be a valid list' error message not detected"
+    assert "error" in resp, "malformed structure was not rejected"
+    err = resp["error"]
+    assert err.get("param") == "list[any]", f"wrong param field: {err.get('param')}"
+    assert err.get("message") == "Input should be a valid list", "error message does not match expectations"
+    # assert "detail" in resp, "malformed structure was not rejected"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "messages structure error not detected"
+    # assert any(
+    #     "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
+    # ), "'Input should be a valid list' error message not detected"
 
 
 def test_extremely_large_max_tokens():
@@ -79,8 +88,13 @@ def test_top_p_exceed_1():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "top_p > 1 should trigger a validation error"
-    assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "expected top_p error message not returned"
+    assert "error" in resp, "error field not returned"
+
+    err = resp["error"]
+    assert err.get("param") == "top_p", f"wrong param field: {err.get('param')}"
+    assert err.get("message") == "Input should be less than or equal to 1", "error message does not match expectations"
+    # assert resp.get("detail").get("object") == "error", "top_p > 1 should trigger a validation error"
+    # assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "expected top_p error message not returned"
 
 
 def test_mixed_valid_invalid_fields():
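For concreteness, the raw body the updated top_p test above asserts against looks roughly like this (reconstructed from the assertions; the message text comes from Pydantic):

{
    "error": {
        "message": "Input should be less than or equal to 1",
        "type": "invalid_request_error",
        "param": "top_p",
        "code": "invalid_value"
    }
}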
@@ -106,8 +120,8 @@ def test_stop_seq_exceed_num():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "too many stop sequences should trigger an error"
-    assert "exceeds the limit max_stop_seqs_num" in resp.get("detail").get("message", ""), "expected error message not returned"
+    assert resp.get("error").get("type") == "invalid_request_error", "too many stop sequences should trigger an error"
+    assert "exceeds the limit max_stop_seqs_num" in resp.get("error").get("message", ""), "expected error message not returned"
 
 
 def test_stop_seq_exceed_length():
@@ -120,8 +134,8 @@ def test_stop_seq_exceed_length():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "an overlong stop sequence should trigger an error"
-    assert "exceeds the limit stop_seqs_max_len" in resp.get("detail").get("message", ""), "expected error message not returned"
+    assert resp.get("error").get("type") == "invalid_request_error", "an overlong stop sequence should trigger an error"
+    assert "exceeds the limit stop_seqs_max_len" in resp.get("error").get("message", ""), "expected error message not returned"
 
 
 def test_multilingual_input():
@@ -154,8 +168,8 @@ def test_too_long_input():
     data = {"messages": [{"role": "user", "content": "a," * 200000}], "stream": False}  # exceeds the maximum input length
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp["detail"].get("object") == "error", "overlong input was not recognized as an error"
-    assert "Input text is too long" in resp["detail"].get("message", ""), "maximum length limit error not detected"
+    # assert resp["detail"].get("object") == "error", "overlong input was not recognized as an error"
+    assert "Input text is too long" in resp["error"].get("message", ""), "maximum length limit error not detected"
 
 
 def test_empty_input():
@@ -361,8 +375,8 @@ def test_max_tokens_negative():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "max_tokens < 0 did not trigger a validation error"
-    assert "max_tokens can be defined [1," in resp.get("detail").get("message"), "expected max_tokens error message not returned"
+    # assert resp.get("detail").get("object") == "error", "max_tokens < 0 did not trigger a validation error"
+    assert "max_tokens can be defined [1," in resp.get("error").get("message"), "expected max_tokens error message not returned"
 
 
 def test_max_tokens_min():
@@ -379,7 +393,10 @@ def test_max_tokens_min():
     }
     payload = build_request_payload(TEMPLATE, data)
    resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "the API did not block max_tokens == 0"
+    # assert resp.get("detail").get("object") == "error", "the API did not block max_tokens == 0"
+    assert "max_tokens can be defined [1," in resp.get("error").get(
+        "message"
+    ), "the API did not block max_tokens == 0; expected max_tokens error message not returned"
 
 
 def test_max_tokens_non_integer():
@@ -397,5 +414,5 @@ def test_max_tokens_non_integer():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
     assert (
-        resp.get("detail")[0].get("msg") == "Input should be a valid integer, got a number with a fractional part"
+        resp.get("error").get("message") == "Input should be a valid integer, got a number with a fractional part"
     ), "expected non-integer max_tokens error message not returned"
tests/entrypoints/openai/test_chatcompletion_request.py (new file, 168 lines)
@@ -0,0 +1,168 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from pydantic import ValidationError
+
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+)
+
+
+class TestChatCompletionRequest(unittest.TestCase):
+
+    def test_required_messages(self):
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest()
+
+    def test_messages_accepts_list_of_any_and_int(self):
+        req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
+        self.assertEqual(req.messages[0]["role"], "user")
+
+        req = ChatCompletionRequest(messages=[1, 2, 3])
+        self.assertEqual(req.messages, [1, 2, 3])
+
+    def test_default_values(self):
+        req = ChatCompletionRequest(messages=[1])
+        self.assertEqual(req.model, "default")
+        self.assertFalse(req.logprobs)
+        self.assertEqual(req.top_logprobs, 0)
+        self.assertEqual(req.n, 1)
+        self.assertEqual(req.stop, [])
+
+    def test_boundary_values(self):
+        valid_cases = [
+            ("frequency_penalty", -2),
+            ("frequency_penalty", 2),
+            ("presence_penalty", -2),
+            ("presence_penalty", 2),
+            ("temperature", 0),
+            ("top_p", 1),
+            ("seed", 0),
+            ("seed", 922337203685477580),
+        ]
+        for field, value in valid_cases:
+            with self.subTest(field=field, value=value):
+                req = ChatCompletionRequest(messages=[1], **{field: value})
+                self.assertEqual(getattr(req, field), value)
+
+    def test_invalid_boundary_values(self):
+        invalid_cases = [
+            ("frequency_penalty", -3),
+            ("frequency_penalty", 3),
+            ("presence_penalty", -3),
+            ("presence_penalty", 3),
+            ("temperature", -1),
+            ("top_p", 1.1),
+            ("seed", -1),
+            ("seed", 922337203685477581),
+        ]
+        for field, value in invalid_cases:
+            with self.subTest(field=field, value=value):
+                with self.assertRaises(ValidationError):
+                    ChatCompletionRequest(messages=[1], **{field: value})
+
+    def test_stop_field_accepts_str_or_list(self):
+        req1 = ChatCompletionRequest(messages=[1], stop="end")
+        self.assertEqual(req1.stop, "end")
+
+        req2 = ChatCompletionRequest(messages=[1], stop=["a", "b"])
+        self.assertEqual(req2.stop, ["a", "b"])
+
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest(messages=[1], stop=123)
+
+    def test_deprecated_max_tokens_field(self):
+        req = ChatCompletionRequest(messages=[1], max_tokens=10)
+        self.assertEqual(req.max_tokens, 10)
+
+    def test_field_names_snapshot(self):
+        expected_fields = set(ChatCompletionRequest.__fields__.keys())
+        self.assertEqual(set(ChatCompletionRequest.__fields__.keys()), expected_fields)
+
+
+class TestCompletionRequest(unittest.TestCase):
+
+    def test_required_prompt(self):
+        with self.assertRaises(ValidationError):
+            CompletionRequest()
+
+    def test_prompt_accepts_various_types(self):
+        # str
+        req = CompletionRequest(prompt="hello")
+        self.assertEqual(req.prompt, "hello")
+
+        # list of str
+        req = CompletionRequest(prompt=["hello", "world"])
+        self.assertEqual(req.prompt, ["hello", "world"])
+
+        # list of int
+        req = CompletionRequest(prompt=[1, 2, 3])
+        self.assertEqual(req.prompt, [1, 2, 3])
+
+        # list of list of int
+        req = CompletionRequest(prompt=[[1, 2], [3, 4]])
+        self.assertEqual(req.prompt, [[1, 2], [3, 4]])
+
+    def test_default_values(self):
+        req = CompletionRequest(prompt="test")
+        self.assertEqual(req.model, "default")
+        self.assertEqual(req.echo, False)
+        self.assertEqual(req.temp_scaled_logprobs, False)
+        self.assertEqual(req.top_p_normalized_logprobs, False)
+        self.assertEqual(req.n, 1)
+        self.assertEqual(req.stop, [])
+        self.assertEqual(req.stream, False)
+
+    def test_boundary_values(self):
+        valid_cases = [
+            ("frequency_penalty", -2),
+            ("frequency_penalty", 2),
+            ("presence_penalty", -2),
+            ("presence_penalty", 2),
+            ("temperature", 0),
+            ("top_p", 0),
+            ("top_p", 1),
+            ("seed", 0),
+            ("seed", 922337203685477580),
+        ]
+        for field, value in valid_cases:
+            with self.subTest(field=field, value=value):
+                req = CompletionRequest(prompt="hi", **{field: value})
+                self.assertEqual(getattr(req, field), value)
+
+    def test_invalid_boundary_values(self):
+        invalid_cases = [
+            ("frequency_penalty", -3),
+            ("frequency_penalty", 3),
+            ("presence_penalty", -3),
+            ("presence_penalty", 3),
+            ("temperature", -0.1),
+            ("top_p", -0.1),
+            ("top_p", 1.1),
+            ("seed", -1),
+            ("seed", 922337203685477581),
+        ]
+        for field, value in invalid_cases:
+            with self.subTest(field=field, value=value):
+                with self.assertRaises(ValidationError):
+                    CompletionRequest(prompt="hi", **{field: value})
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/entrypoints/openai/test_error_response.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import unittest
+
+from pydantic import ValidationError
+
+from fastdeploy.entrypoints.openai.protocol import ErrorResponse
+
+
+class TestErrorResponse(unittest.TestCase):
+    def test_valid_error_response(self):
+        data = {
+            "error": {
+                "message": "Invalid top_p value",
+                "type": "invalid_request_error",
+                "param": "top_p",
+                "code": "invalid_value",
+            }
+        }
+        err_resp = ErrorResponse(**data)
+        self.assertEqual(err_resp.error.message, "Invalid top_p value")
+        self.assertEqual(err_resp.error.param, "top_p")
+        self.assertEqual(err_resp.error.code, "invalid_value")
+
+    def test_missing_message_field(self):
+        data = {"error": {"type": "invalid_request_error", "param": "messages", "code": "missing_required_parameter"}}
+        with self.assertRaises(ValidationError):
+            ErrorResponse(**data)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/utils/test_exception_handler.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+import json
+import unittest
+from http import HTTPStatus
+
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+
+from fastdeploy.utils import ErrorCode, ExceptionHandler, ParameterError
+
+
+class TestParameterError(unittest.TestCase):
+    def test_parameter_error_init(self):
+        exc = ParameterError("param1", "error message")
+        self.assertEqual(exc.param, "param1")
+        self.assertEqual(exc.message, "error message")
+        self.assertEqual(str(exc), "error message")
+
+
+class TestExceptionHandler(unittest.IsolatedAsyncioTestCase):
+
+    async def test_handle_exception(self):
+        """An ordinary exception should yield 500 + internal_error"""
+        exc = RuntimeError("Something went wrong")
+        resp: JSONResponse = await ExceptionHandler.handle_exception(None, exc)
+        body = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+        self.assertEqual(body["error"]["type"], "internal_error")
+        self.assertIn("Something went wrong", body["error"]["message"])
+
+    async def test_handle_request_validation_missing_messages(self):
+        """A missing messages parameter should yield missing_required_parameter"""
+        exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}])
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "messages")
+        self.assertEqual(data["error"]["code"], ErrorCode.MISSING_REQUIRED_PARAMETER)
+        self.assertIn("Field required", data["error"]["message"])
+
+    async def test_handle_request_validation_invalid_value(self):
+        """An invalid parameter value should yield invalid_value"""
+        exc = RequestValidationError(
+            [{"loc": ("body", "top_p"), "msg": "Input should be less than or equal to 1", "type": "value_error"}]
+        )
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "top_p")
+        self.assertEqual(data["error"]["code"], ErrorCode.INVALID_VALUE)
+        self.assertIn("less than or equal to 1", data["error"]["message"])
+
+
+if __name__ == "__main__":
+    unittest.main()