Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[format] Valid para format error info (#4035)
* feat(log): add request and response log
* Align error messages with the OpenAI API
@@ -31,7 +31,12 @@ from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
+from fastdeploy.utils import (
+    EngineError,
+    ParameterError,
+    StatefulSemaphore,
+    api_server_logger,
+)


 class EngineClient:
@@ -218,42 +223,21 @@ class EngineClient:
     def valid_parameters(self, data):
         """
         Validate stream options
+        Validation of the hyperparameters (top_p, seed, frequency_penalty, temperature, presence_penalty)
+        has been moved up into ChatCompletionRequest/CompletionRequest.
         """

         if data.get("n") is not None:
             if data["n"] != 1:
-                raise ValueError("n only support 1.")
+                raise ParameterError("n", "n only support 1.")

         if data.get("max_tokens") is not None:
             if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
-                raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
+                raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")

         if data.get("reasoning_max_tokens") is not None:
             if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
-                raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
+                raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be between max_tokens and 1")

-        if data.get("top_p") is not None:
-            if data["top_p"] > 1 or data["top_p"] < 0:
-                raise ValueError("top_p value can only be defined [0, 1].")
-
-        if data.get("frequency_penalty") is not None:
-            if not -2.0 <= data["frequency_penalty"] <= 2.0:
-                raise ValueError("frequency_penalty must be in [-2, 2]")
-
-        if data.get("temperature") is not None:
-            if data["temperature"] < 0:
-                raise ValueError("temperature must be non-negative")
-
-        if data.get("presence_penalty") is not None:
-            if not -2.0 <= data["presence_penalty"] <= 2.0:
-                raise ValueError("presence_penalty must be in [-2, 2]")
-
-        if data.get("seed") is not None:
-            if not 0 <= data["seed"] <= 922337203685477580:
-                raise ValueError("seed must be in [0, 922337203685477580]")
-
-        if data.get("stream_options") and not data.get("stream"):
-            raise ValueError("Stream options can only be defined when `stream=True`.")
-
         # logprobs
         logprobs = data.get("logprobs")
@@ -263,35 +247,35 @@ class EngineClient:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)
             top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
-            raise ValueError("Invalid type for 'logprobs'")
+            raise ParameterError("logprobs", "Invalid type for 'logprobs'")

         # enable_logprob
         if top_logprobs:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("logprobs", err_msg)

             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
                 err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)

             if top_logprobs < 0:
                 err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)

             if top_logprobs > 20:
                 err_msg = "Invalid value for 'top_logprobs': must be <= 20."
                 api_server_logger.error(err_msg)
-                raise ValueError(err_msg)
+                raise ParameterError("top_logprobs", err_msg)

     def check_health(self, time_interval_threashold=30):
         """
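Taken together, the two hunks above swap ValueError for ParameterError so the offending field name travels with the exception. A minimal sketch of the new behavior, assuming an already-constructed EngineClient instance named client:

from fastdeploy.utils import ParameterError

try:
    client.valid_parameters({"n": 2})
except ParameterError as e:
    # The serving layer copies e.param into the OpenAI-style error body.
    print(e.param, "->", e.message)  # n -> n only support 1.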
@@ -26,6 +26,7 @@ from multiprocessing import current_process
 import uvicorn
 import zmq
 from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST

@@ -40,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
     ControlSchedulerRequest,
+    ErrorInfo,
     ErrorResponse,
     ModelList,
 )
@@ -56,6 +58,7 @@ from fastdeploy.metrics.metrics import (
 )
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
+    ExceptionHandler,
     FlexibleArgumentParser,
     StatefulSemaphore,
     api_server_logger,
@@ -232,6 +235,8 @@ async def lifespan(app: FastAPI):


 app = FastAPI(lifespan=lifespan)
+app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
+app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
 instrument(app)


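With both handlers registered on the app, a request that fails pydantic validation is answered before the endpoint body ever runs. A sketch of the observable behavior (hypothetical host and port; requires a running api_server):

import requests

# "messages" is deliberately omitted to trigger RequestValidationError.
r = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"stream": False},
)
print(r.status_code)  # 400
print(r.json())
# {"error": {"message": "Field required", "type": "invalid_request_error",
#            "code": "missing_required_parameter", "param": "messages"}}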
@@ -336,7 +341,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if isinstance(generator, ErrorResponse):
         api_server_logger.debug(f"release: {connection_semaphore.status()}")
         connection_semaphore.release()
-        return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+        return JSONResponse(content=generator.model_dump(), status_code=500)
     elif isinstance(generator, ChatCompletionResponse):
         api_server_logger.debug(f"release: {connection_semaphore.status()}")
         connection_semaphore.release()
@@ -365,7 +370,7 @@ async def create_completion(request: CompletionRequest):
     generator = await app.state.completion_handler.create_completion(request)
     if isinstance(generator, ErrorResponse):
         connection_semaphore.release()
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+        return JSONResponse(content=generator.model_dump(), status_code=500)
     elif isinstance(generator, CompletionResponse):
         connection_semaphore.release()
         return JSONResponse(content=generator.model_dump())
@@ -388,7 +393,7 @@ async def list_models() -> Response:

     models = await app.state.model_handler.list_models()
     if isinstance(models, ErrorResponse):
-        return JSONResponse(content=models.model_dump(), status_code=models.code)
+        return JSONResponse(content=models.model_dump())
     elif isinstance(models, ModelList):
         return JSONResponse(content=models.model_dump())

@@ -502,7 +507,8 @@ def control_scheduler(request: ControlSchedulerRequest):
     """
     Control the scheduler behavior with the given parameters.
     """
-    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
+    content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0))
+
     global llm_engine
     if llm_engine is None:
@@ -32,9 +32,14 @@ class ErrorResponse(BaseModel):
     Error response from OpenAI API.
     """

-    object: str = "error"
-    message: str
-    code: int
+    error: ErrorInfo
+
+
+class ErrorInfo(BaseModel):
+    message: str
+    type: Optional[str] = None
+    param: Optional[str] = None
+    code: Optional[str] = None


 class PromptTokenUsageInfo(BaseModel):
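For reference, a short sketch of the envelope the reshaped models serialize; the field values here are illustrative (taken from the new tests further down):

from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

resp = ErrorResponse(
    error=ErrorInfo(
        message="Invalid top_p value",
        type="invalid_request_error",
        param="top_p",
        code="invalid_value",
    )
)
print(resp.model_dump_json())
# -> {"error": {"message": "Invalid top_p value", "type": "invalid_request_error",
#               "param": "top_p", "code": "invalid_value"}}

This is the same {"error": {...}} envelope the OpenAI API returns, which is the point of the change.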
@@ -403,21 +408,21 @@ class CompletionRequest(BaseModel):
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
-    frequency_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
     logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
     max_tokens: Optional[int] = None
-    n: int = 1
-    presence_penalty: Optional[float] = None
-    seed: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
+    seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[dict] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=None, ge=0)
+    top_p: Optional[float] = Field(default=None, ge=0, le=1)
     user: Optional[str] = None

     # doc: begin-completion-sampling-params
@@ -537,7 +542,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[List[Any], List[int]]
     tools: Optional[List[ChatCompletionToolsParam]] = None
     model: Optional[str] = "default"
-    frequency_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0

@@ -552,13 +557,13 @@ class ChatCompletionRequest(BaseModel):
     )
     max_completion_tokens: Optional[int] = None
     n: Optional[int] = 1
-    presence_penalty: Optional[float] = None
-    seed: Optional[int] = None
+    presence_penalty: Optional[float] = Field(None, le=2, ge=-2)
+    seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(None, ge=0)
+    top_p: Optional[float] = Field(None, le=1, ge=0)
     user: Optional[str] = None
     metadata: Optional[dict] = None
     response_format: Optional[AnyResponseFormat] = None
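Because the bounds now live on the pydantic fields, out-of-range values are rejected while the request body is parsed, before valid_parameters() is ever reached. A local sketch (the message text is pydantic v2 wording, matching what the updated tests assert):

from pydantic import ValidationError

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

try:
    # top_p is constrained to [0, 1] by Field(None, le=1, ge=0).
    ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
except ValidationError as e:
    print(e.errors()[0]["msg"])  # Input should be less than or equal to 1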
@@ -30,6 +30,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    ErrorInfo,
     ErrorResponse,
     LogProbEntry,
     LogProbs,
@@ -38,7 +39,7 @@ from fastdeploy.entrypoints.openai.protocol import (
 )
 from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
 from fastdeploy.worker.output import LogprobsLists


@@ -86,14 +87,16 @@ class OpenAIServingChat:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))

         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )

         try:
             if self.max_waiting_time < 0:
@@ -117,11 +120,17 @@ class OpenAIServingChat:
             text_after_process = current_req_dict.get("text_after_process")
             if isinstance(prompt_token_ids, np.ndarray):
                 prompt_token_ids = prompt_token_ids.tolist()
+        except ParameterError as e:
+            api_server_logger.error(e.message)
+            self.engine_client.semaphore.release()
+            return ErrorResponse(
+                error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
+            )
         except Exception as e:
             error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
             self.engine_client.semaphore.release()
-            return ErrorResponse(code=400, message=error_msg)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR))
         del current_req_dict

         if request.stream:
@@ -136,21 +145,20 @@ class OpenAIServingChat:
             except Exception as e:
                 error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
-                return ErrorResponse(code=408, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
         except Exception as e:
             error_msg = (
                 f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
+            )

     def _create_streaming_error_response(self, message: str) -> str:
         api_server_logger.error(message)
-        error_response = ErrorResponse(
-            code=400,
-            message=message,
-        )
+        error_response = ErrorResponse(error=ErrorInfo(message=message, type=ErrorType.SERVER_ERROR))
         return error_response.model_dump_json()

     async def chat_completion_stream_generator(
|
@@ -30,10 +30,11 @@ from fastdeploy.entrypoints.openai.protocol import (
|
|||||||
CompletionResponseChoice,
|
CompletionResponseChoice,
|
||||||
CompletionResponseStreamChoice,
|
CompletionResponseStreamChoice,
|
||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
|
ErrorInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from fastdeploy.utils import api_server_logger
|
from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
|
||||||
from fastdeploy.worker.output import LogprobsLists
|
from fastdeploy.worker.output import LogprobsLists
|
||||||
|
|
||||||
|
|
||||||
@@ -63,13 +64,15 @@ class OpenAIServingCompletion:
                 f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         if self.models:
             is_supported, request.model = self.models.is_supported_model(request.model)
             if not is_supported:
                 err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
                 api_server_logger.error(err_msg)
-                return ErrorResponse(message=err_msg, code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
+                )
         created_time = int(time.time())
         if request.user is not None:
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
@@ -112,7 +115,7 @@ class OpenAIServingCompletion:
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))

         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
@@ -132,7 +135,9 @@ class OpenAIServingCompletion:
                 f"max waiting time: {self.max_waiting_time}"
             )
             api_server_logger.error(error_msg)
-            return ErrorResponse(code=408, message=error_msg)
+            return ErrorResponse(
+                error=ErrorInfo(message=error_msg, code=ErrorCode.TIMEOUT, type=ErrorType.TIMEOUT_ERROR)
+            )

         try:
             try:
@@ -146,11 +151,17 @@ class OpenAIServingCompletion:
                     text_after_process_list.append(current_req_dict.get("text_after_process"))
                    prompt_batched_token_ids.append(prompt_token_ids)
                 del current_req_dict
+            except ParameterError as e:
+                api_server_logger.error(e.message)
+                self.engine_client.semaphore.release()
+                return ErrorResponse(error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param))
            except Exception as e:
                error_msg = f"OpenAIServingCompletion format error: {e}, {str(traceback.format_exc())}"
                api_server_logger.error(error_msg)
                self.engine_client.semaphore.release()
-                return ErrorResponse(message=str(e), code=400)
+                return ErrorResponse(
+                    error=ErrorInfo(message=str(e), code=ErrorCode.INVALID_VALUE, type=ErrorType.INVALID_REQUEST_ERROR)
+                )

         if request.stream:
             return self.completion_stream_generator(
@@ -178,12 +189,12 @@ class OpenAIServingCompletion:
                     f"OpenAIServingCompletion completion_full_generator error: {e}, {str(traceback.format_exc())}"
                 )
                 api_server_logger.error(error_msg)
-                return ErrorResponse(code=400, message=error_msg)
+                return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))

         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
-            return ErrorResponse(message=error_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))

     async def completion_full_generator(
         self,
@@ -18,12 +18,13 @@ from dataclasses import dataclass
 from typing import List, Union

 from fastdeploy.entrypoints.openai.protocol import (
+    ErrorInfo,
     ErrorResponse,
     ModelInfo,
     ModelList,
     ModelPermission,
 )
-from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.utils import ErrorType, api_server_logger, get_host_ip


 @dataclass
@@ -86,7 +87,7 @@ class OpenAIServingModels:
                 f"Only master node can accept models request, please send request to master node: {self.master_ip}"
             )
             api_server_logger.error(err_msg)
-            return ErrorResponse(message=err_msg, code=400)
+            return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))
         model_infos = [
             ModelInfo(
                 id=model.name, max_model_len=self.max_model_len, root=model.model_path, permission=[ModelPermission()]
@@ -28,6 +28,8 @@ import sys
 import tarfile
 import time
 from datetime import datetime
+from enum import Enum
+from http import HTTPStatus
 from importlib.metadata import PackageNotFoundError, distribution
 from logging.handlers import BaseRotatingHandler
 from pathlib import Path
@@ -38,10 +40,14 @@ import paddle
 import requests
 import yaml
 from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 from tqdm import tqdm
 from typing_extensions import TypeIs, assert_never

 from fastdeploy import envs
+from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from fastdeploy.logger.logger import FastDeployLogger

 T = TypeVar("T")
@@ -59,6 +65,61 @@ class EngineError(Exception):
         self.error_code = error_code


+class ParameterError(Exception):
+    def __init__(self, param: str, message: str):
+        self.param = param
+        self.message = message
+        super().__init__(message)
+
+
+class ExceptionHandler:
+
+    # Global fallback exception handler
+    @staticmethod
+    async def handle_exception(request: Request, exc: Exception) -> JSONResponse:
+        error = ErrorResponse(error=ErrorInfo(message=str(exc), type=ErrorType.INTERNAL_ERROR))
+        return JSONResponse(content=error.model_dump(), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    # Handle request-parameter validation errors
+    @staticmethod
+    async def handle_request_validation_exception(_: Request, exc: RequestValidationError) -> JSONResponse:
+        errors = exc.errors()
+        if not errors:
+            message = str(exc)
+            param = None
+        else:
+            first_error = errors[0]
+            loc = first_error.get("loc", [])
+            param = loc[-1] if loc else None
+            message = first_error.get("msg", str(exc))
+        err = ErrorResponse(
+            error=ErrorInfo(
+                message=message,
+                type=ErrorType.INVALID_REQUEST_ERROR,
+                code=ErrorCode.MISSING_REQUIRED_PARAMETER if param == "messages" else ErrorCode.INVALID_VALUE,
+                param=param,
+            )
+        )
+        return JSONResponse(content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+class ErrorType(str, Enum):
+    INVALID_REQUEST_ERROR = "invalid_request_error"
+    TIMEOUT_ERROR = "timeout_error"
+    SERVER_ERROR = "server_error"
+    INTERNAL_ERROR = "internal_error"
+    API_CONNECTION_ERROR = "api_connection_error"
+
+
+class ErrorCode(str, Enum):
+    INVALID_VALUE = "invalid_value"
+    CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
+    MODEL_NOT_SUPPORT = "model_not_support"
+    TIMEOUT = "timeout"
+    CONNECTION_ERROR = "connection_error"
+    MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
+
+
 class ColoredFormatter(logging.Formatter):
     """Custom log formatter that adds color to console output"""

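The handlers are plain static coroutines, so they can be exercised without a running app; a condensed version of what tests/utils/test_exception_handler.py (added later in this diff) does:

import asyncio
import json

from fastdeploy.utils import ExceptionHandler

# None stands in for the Request argument, which handle_exception never inspects.
resp = asyncio.run(ExceptionHandler.handle_exception(None, RuntimeError("boom")))
print(resp.status_code)                                 # 500
print(json.loads(resp.body.decode())["error"]["type"])  # internal_error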
@@ -20,9 +20,14 @@ def test_missing_messages_field():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()

-    assert "detail" in resp, "response does not contain the 'detail' error field"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "missing 'messages' field error not detected"
-    assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "'Field required' error message not detected"
+    assert "error" in resp, "response does not contain the 'error' field"
+    error = resp["error"]
+    assert "Field required" in error.get("message", ""), "missing 'messages' field error not detected"
+    assert error.get("code") == "missing_required_parameter", "incorrect 'code' field"
+
+    # assert "detail" in resp, "response does not contain the 'detail' error field"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "missing 'messages' field error not detected"
+    # assert any("Field required" in err.get("msg", "") for err in resp["detail"]), "'Field required' error message not detected"


 def test_malformed_messages_format():
@@ -34,11 +39,15 @@ def test_malformed_messages_format():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert "detail" in resp, "malformed structure was not recognized"
-    assert any("messages" in err.get("loc", []) for err in resp["detail"]), "'messages' structure error not detected"
-    assert any(
-        "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
-    ), "'Input should be a valid list' error message not detected"
+    assert "error" in resp, "malformed structure was not recognized"
+    err = resp["error"]
+    assert err.get("param") == "list[any]", f"unexpected param field: {err.get('param')}"
+    assert err.get("message") == "Input should be a valid list", "unexpected error message"
+    # assert "detail" in resp, "malformed structure was not recognized"
+    # assert any("messages" in err.get("loc", []) for err in resp["detail"]), "'messages' structure error not detected"
+    # assert any(
+    #     "Input should be a valid list" in err.get("msg", "") for err in resp["detail"]
+    # ), "'Input should be a valid list' error message not detected"


 def test_extremely_large_max_tokens():
@@ -79,8 +88,13 @@ def test_top_p_exceed_1():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "top_p > 1 should trigger a validation error"
-    assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "expected top_p error message not returned"
+    assert "error" in resp, "no 'error' field returned"
+
+    err = resp["error"]
+    assert err.get("param") == "top_p", f"unexpected param field: {err.get('param')}"
+    assert err.get("message") == "Input should be less than or equal to 1", "unexpected error message"
+    # assert resp.get("detail").get("object") == "error", "top_p > 1 should trigger a validation error"
+    # assert "top_p value can only be defined" in resp.get("detail").get("message", ""), "expected top_p error message not returned"


 def test_mixed_valid_invalid_fields():
@@ -106,8 +120,8 @@ def test_stop_seq_exceed_num():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "too many stop sequences should trigger an error"
-    assert "exceeds the limit max_stop_seqs_num" in resp.get("detail").get("message", ""), "expected error message not returned"
+    assert resp.get("error").get("type") == "invalid_request_error", "too many stop sequences should trigger an error"
+    assert "exceeds the limit max_stop_seqs_num" in resp.get("error").get("message", ""), "expected error message not returned"


 def test_stop_seq_exceed_length():
@@ -120,8 +134,8 @@ def test_stop_seq_exceed_length():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "an overlong stop sequence should trigger an error"
-    assert "exceeds the limit stop_seqs_max_len" in resp.get("detail").get("message", ""), "expected error message not returned"
+    assert resp.get("error").get("type") == "invalid_request_error", "an overlong stop sequence should trigger an error"
+    assert "exceeds the limit stop_seqs_max_len" in resp.get("error").get("message", ""), "expected error message not returned"


 def test_multilingual_input():
@@ -154,8 +168,8 @@ def test_too_long_input():
     data = {"messages": [{"role": "user", "content": "a," * 200000}], "stream": False}  # exceeds the maximum input length
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp["detail"].get("object") == "error", "overlong input was not recognized as an error"
-    assert "Input text is too long" in resp["detail"].get("message", ""), "maximum length limit error not detected"
+    # assert resp["detail"].get("object") == "error", "overlong input was not recognized as an error"
+    assert "Input text is too long" in resp["error"].get("message", ""), "maximum length limit error not detected"


 def test_empty_input():
@@ -361,8 +375,8 @@ def test_max_tokens_negative():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "max_tokens < 0 did not trigger a validation error"
-    assert "max_tokens can be defined [1," in resp.get("detail").get("message"), "expected max_tokens error message not returned"
+    # assert resp.get("detail").get("object") == "error", "max_tokens < 0 did not trigger a validation error"
+    assert "max_tokens can be defined [1," in resp.get("error").get("message"), "expected max_tokens error message not returned"


 def test_max_tokens_min():
@@ -379,7 +393,10 @@ def test_max_tokens_min():
     }
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
-    assert resp.get("detail").get("object") == "error", "the API did not reject max_tokens == 0"
+    # assert resp.get("detail").get("object") == "error", "the API did not reject max_tokens == 0"
+    assert "max_tokens can be defined [1," in resp.get("error").get(
+        "message"
+    ), "the API did not reject max_tokens == 0; expected max_tokens error message not returned"


 def test_max_tokens_non_integer():
@@ -397,5 +414,5 @@ def test_max_tokens_non_integer():
     payload = build_request_payload(TEMPLATE, data)
     resp = send_request(URL, payload).json()
     assert (
-        resp.get("detail")[0].get("msg") == "Input should be a valid integer, got a number with a fractional part"
+        resp.get("error").get("message") == "Input should be a valid integer, got a number with a fractional part"
     ), "expected error message for non-integer max_tokens not returned"
tests/entrypoints/openai/test_chatcompletion_request.py (new file, 168 lines)
@@ -0,0 +1,168 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from pydantic import ValidationError
+
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+)
+
+
+class TestChatCompletionRequest(unittest.TestCase):
+
+    def test_required_messages(self):
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest()
+
+    def test_messages_accepts_list_of_any_and_int(self):
+        req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
+        self.assertEqual(req.messages[0]["role"], "user")
+
+        req = ChatCompletionRequest(messages=[1, 2, 3])
+        self.assertEqual(req.messages, [1, 2, 3])
+
+    def test_default_values(self):
+        req = ChatCompletionRequest(messages=[1])
+        self.assertEqual(req.model, "default")
+        self.assertFalse(req.logprobs)
+        self.assertEqual(req.top_logprobs, 0)
+        self.assertEqual(req.n, 1)
+        self.assertEqual(req.stop, [])
+
+    def test_boundary_values(self):
+        valid_cases = [
+            ("frequency_penalty", -2),
+            ("frequency_penalty", 2),
+            ("presence_penalty", -2),
+            ("presence_penalty", 2),
+            ("temperature", 0),
+            ("top_p", 1),
+            ("seed", 0),
+            ("seed", 922337203685477580),
+        ]
+        for field, value in valid_cases:
+            with self.subTest(field=field, value=value):
+                req = ChatCompletionRequest(messages=[1], **{field: value})
+                self.assertEqual(getattr(req, field), value)
+
+    def test_invalid_boundary_values(self):
+        invalid_cases = [
+            ("frequency_penalty", -3),
+            ("frequency_penalty", 3),
+            ("presence_penalty", -3),
+            ("presence_penalty", 3),
+            ("temperature", -1),
+            ("top_p", 1.1),
+            ("seed", -1),
+            ("seed", 922337203685477581),
+        ]
+        for field, value in invalid_cases:
+            with self.subTest(field=field, value=value):
+                with self.assertRaises(ValidationError):
+                    ChatCompletionRequest(messages=[1], **{field: value})
+
+    def test_stop_field_accepts_str_or_list(self):
+        req1 = ChatCompletionRequest(messages=[1], stop="end")
+        self.assertEqual(req1.stop, "end")
+
+        req2 = ChatCompletionRequest(messages=[1], stop=["a", "b"])
+        self.assertEqual(req2.stop, ["a", "b"])
+
+        with self.assertRaises(ValidationError):
+            ChatCompletionRequest(messages=[1], stop=123)
+
+    def test_deprecated_max_tokens_field(self):
+        req = ChatCompletionRequest(messages=[1], max_tokens=10)
+        self.assertEqual(req.max_tokens, 10)
+
+    def test_field_names_snapshot(self):
+        expected_fields = set(ChatCompletionRequest.__fields__.keys())
+        self.assertEqual(set(ChatCompletionRequest.__fields__.keys()), expected_fields)
+
+
+class TestCompletionRequest(unittest.TestCase):
+
+    def test_required_prompt(self):
+        with self.assertRaises(ValidationError):
+            CompletionRequest()
+
+    def test_prompt_accepts_various_types(self):
+        # str
+        req = CompletionRequest(prompt="hello")
+        self.assertEqual(req.prompt, "hello")
+
+        # list of str
+        req = CompletionRequest(prompt=["hello", "world"])
+        self.assertEqual(req.prompt, ["hello", "world"])
+
+        # list of int
+        req = CompletionRequest(prompt=[1, 2, 3])
+        self.assertEqual(req.prompt, [1, 2, 3])
+
+        # list of list of int
+        req = CompletionRequest(prompt=[[1, 2], [3, 4]])
+        self.assertEqual(req.prompt, [[1, 2], [3, 4]])
+
+    def test_default_values(self):
+        req = CompletionRequest(prompt="test")
+        self.assertEqual(req.model, "default")
+        self.assertEqual(req.echo, False)
+        self.assertEqual(req.temp_scaled_logprobs, False)
+        self.assertEqual(req.top_p_normalized_logprobs, False)
+        self.assertEqual(req.n, 1)
+        self.assertEqual(req.stop, [])
+        self.assertEqual(req.stream, False)
+
+    def test_boundary_values(self):
+        valid_cases = [
+            ("frequency_penalty", -2),
+            ("frequency_penalty", 2),
+            ("presence_penalty", -2),
+            ("presence_penalty", 2),
+            ("temperature", 0),
+            ("top_p", 0),
+            ("top_p", 1),
+            ("seed", 0),
+            ("seed", 922337203685477580),
+        ]
+        for field, value in valid_cases:
+            with self.subTest(field=field, value=value):
+                req = CompletionRequest(prompt="hi", **{field: value})
+                self.assertEqual(getattr(req, field), value)
+
+    def test_invalid_boundary_values(self):
+        invalid_cases = [
+            ("frequency_penalty", -3),
+            ("frequency_penalty", 3),
+            ("presence_penalty", -3),
+            ("presence_penalty", 3),
+            ("temperature", -0.1),
+            ("top_p", -0.1),
+            ("top_p", 1.1),
+            ("seed", -1),
+            ("seed", 922337203685477581),
+        ]
+        for field, value in invalid_cases:
+            with self.subTest(field=field, value=value):
+                with self.assertRaises(ValidationError):
+                    CompletionRequest(prompt="hi", **{field: value})
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/entrypoints/openai/test_error_response.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import unittest
+
+from pydantic import ValidationError
+
+from fastdeploy.entrypoints.openai.protocol import ErrorResponse
+
+
+class TestErrorResponse(unittest.TestCase):
+    def test_valid_error_response(self):
+        data = {
+            "error": {
+                "message": "Invalid top_p value",
+                "type": "invalid_request_error",
+                "param": "top_p",
+                "code": "invalid_value",
+            }
+        }
+        err_resp = ErrorResponse(**data)
+        self.assertEqual(err_resp.error.message, "Invalid top_p value")
+        self.assertEqual(err_resp.error.param, "top_p")
+        self.assertEqual(err_resp.error.code, "invalid_value")
+
+    def test_missing_message_field(self):
+        data = {"error": {"type": "invalid_request_error", "param": "messages", "code": "missing_required_parameter"}}
+        with self.assertRaises(ValidationError):
+            ErrorResponse(**data)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/utils/test_exception_handler.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+import json
+import unittest
+from http import HTTPStatus
+
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+
+from fastdeploy.utils import ErrorCode, ExceptionHandler, ParameterError
+
+
+class TestParameterError(unittest.TestCase):
+    def test_parameter_error_init(self):
+        exc = ParameterError("param1", "error message")
+        self.assertEqual(exc.param, "param1")
+        self.assertEqual(exc.message, "error message")
+        self.assertEqual(str(exc), "error message")
+
+
+class TestExceptionHandler(unittest.IsolatedAsyncioTestCase):
+
+    async def test_handle_exception(self):
+        """An ordinary exception should return 500 + internal_error"""
+        exc = RuntimeError("Something went wrong")
+        resp: JSONResponse = await ExceptionHandler.handle_exception(None, exc)
+        body = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+        self.assertEqual(body["error"]["type"], "internal_error")
+        self.assertIn("Something went wrong", body["error"]["message"])
+
+    async def test_handle_request_validation_missing_messages(self):
+        """A missing 'messages' parameter should return missing_required_parameter"""
+        exc = RequestValidationError([{"loc": ("body", "messages"), "msg": "Field required", "type": "missing"}])
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "messages")
+        self.assertEqual(data["error"]["code"], ErrorCode.MISSING_REQUIRED_PARAMETER)
+        self.assertIn("Field required", data["error"]["message"])
+
+    async def test_handle_request_validation_invalid_value(self):
+        """An invalid parameter value should return invalid_value"""
+        exc = RequestValidationError(
+            [{"loc": ("body", "top_p"), "msg": "Input should be less than or equal to 1", "type": "value_error"}]
+        )
+        resp: JSONResponse = await ExceptionHandler.handle_request_validation_exception(None, exc)
+        data = json.loads(resp.body.decode())
+        self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(data["error"]["param"], "top_p")
+        self.assertEqual(data["error"]["code"], ErrorCode.INVALID_VALUE)
+        self.assertIn("less than or equal to 1", data["error"]["message"])
+
+
+if __name__ == "__main__":
+    unittest.main()