FastDeploy/fastdeploy/entrypoints/openai/serving_engine.py
[Feature] support reward model (#5301)

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import time
import traceback
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from fastdeploy.engine.request import PoolingRequestOutput, RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
    ErrorInfo,
    ErrorResponse,
    InvalidParameterException,
)
from fastdeploy.utils import ErrorCode, ErrorType, api_server_logger

RequestT = TypeVar("RequestT")


class ServeContext(
    BaseModel,
    Generic[RequestT],
):
    # Shared across all requests
    request: RequestT
    model_name: str
    request_id: str
    created_time: int = Field(default_factory=lambda: int(time.time()))
    preprocess_requests: Optional[list[dict]] = None
    request_output: Optional[Union[RequestOutput, PoolingRequestOutput]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )
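
# A ``ServeContext`` is typically built once per incoming request. A minimal
# sketch (``ChatCompletionRequest`` stands in for a concrete RequestT here;
# the actual type depends on the serving subclass):
#
#   ctx = ServeContext[ChatCompletionRequest](
#       request=chat_request,
#       model_name=chat_request.model,
#       request_id="chatcmpl-<uuid>",
#   )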


class OpenAIServing(ABC, Generic[RequestT]):
    """
    Base pipeline for OpenAI-style serving implementations
    """

    request_id_prefix: ClassVar[str]

    def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time):
        self.engine_client = engine_client
        self.models = models
        self.cfg = cfg
        self.pid = pid
        self.max_waiting_time = max_waiting_time
        # Parse master IP: ``ips`` may be a list of addresses or a
        # comma-separated string; the first entry is treated as the master.
        if ips is not None:
            if isinstance(ips, list):
                self.master_ip = ips[0]
            else:
                self.master_ip = ips.split(",")[0]
        else:
            self.master_ip = "0.0.0.0"
        api_server_logger.info(f"master ip: {self.master_ip}")

    def _check_master(self) -> bool:
        """Check if current node is master"""
        return self.engine_client.is_master

    def _check_supported_model(self, model_name: str) -> tuple[bool, str]:
        """Check if model is supported and return adjusted model name"""
        if not self.models:
            return True, model_name
        is_supported, adjusted_name = self.models.is_supported_model(model_name)
        if not is_supported:
            err_msg = f"Unsupported model: [{model_name}]"
            api_server_logger.error(err_msg)
        return is_supported, adjusted_name
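
    # The engine-client semaphore bounds the number of requests processed
    # concurrently; a negative ``max_waiting_time`` means callers wait
    # indefinitely for a slot (see ``_acquire_semaphore`` below).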

    async def _acquire_semaphore(self, request_id: str) -> bool:
        """Acquire engine client semaphore with timeout"""
        try:
            api_server_logger.info(f"Acquire request:{request_id} status:{self.engine_client.semaphore.status()}")
            if self.max_waiting_time < 0:
                await self.engine_client.semaphore.acquire()
            else:
                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
            return True
        except asyncio.TimeoutError:
            self._release_semaphore(request_id)
            error_msg = f"Request waiting timeout, request:{request_id} max waiting time:{self.max_waiting_time}"
            api_server_logger.error(error_msg)
            return False

    def _release_semaphore(self, request_id: str) -> None:
        """Release engine client semaphore"""
        self.engine_client.semaphore.release()
        api_server_logger.info(f"Release request:{request_id} status:{self.engine_client.semaphore.status()}")

    def _create_error_response(
        self,
        message: str,
        error_type: ErrorType = ErrorType.INTERNAL_ERROR,
        code: Optional[ErrorCode] = ErrorCode.INTERNAL_ERROR,
        param: Optional[str] = None,
    ) -> ErrorResponse:
        """Create standardized error response"""
        traceback.print_exc()
        api_server_logger.error(message)
        return ErrorResponse(error=ErrorInfo(message=message, type=error_type, code=code, param=param))

    def _generate_request_id(self, user: Optional[str] = None) -> str:
        """Generate a unique request ID"""
        if user is not None:
            return f"{self.request_id_prefix}-{user}-{uuid.uuid4()}"
        return f"{self.request_id_prefix}-{uuid.uuid4()}"

    def _validate_request(self, ctx: ServeContext):
        """Validate the request before processing"""
        pass

    @abstractmethod
    async def _preprocess(self, ctx: ServeContext):
        """Preprocess the request into engine format"""
        pass

    @abstractmethod
    async def _prepare_generators(self, ctx: ServeContext) -> Any:
        """Prepare an async generator that processes engine responses into the final format"""
        pass

    @abstractmethod
    def _build_response(self, ctx: ServeContext) -> Any:
        """Generate the final response object"""
        pass

    async def handle(self, ctx: ServeContext) -> AsyncGenerator[Union[Any, ErrorResponse], None]:
        """Handle incoming requests"""
        generation = self._pipeline(ctx)
        async for response in generation:
            yield response
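
    # Typical call site (illustrative sketch): the HTTP layer builds a
    # ``ServeContext`` and streams whatever ``handle`` yields back to the
    # client, checking for error responses along the way:
    #
    #   async for item in serving.handle(ctx):
    #       if isinstance(item, ErrorResponse):
    #           ...  # surface the error to the client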

    async def _pipeline(self, ctx: ServeContext) -> AsyncGenerator[Union[Any, ErrorResponse], None]:
        """
        Pipeline for handling requests

        Args:
            request: The request to be handled

        Returns:
            A generator that yields responses
        """
        # Step 1: Request validation
        # Step 1.1: Check if current node is master
        if not self._check_master():
            yield self._create_error_response(
                f"Only master node can accept request, please send to master node: {self.master_ip}"
            )
            return

        request = ctx.request
        # Step 1.2: Check supported model
        is_supported, request.model = self._check_supported_model(ctx.model_name)
        if not is_supported:
            yield self._create_error_response(
                f"Unsupported model: [{request.model}]", ErrorType.API_CONNECTION_ERROR, ErrorCode.MODEL_NOT_SUPPORT
            )
            return

        # Step 1.3: Validate request
        self._validate_request(ctx)

        request_id = self._generate_request_id(getattr(request, "user", None))
        api_server_logger.info(f"Initialize request {request_id}: {request}")

        # Step 2: Semaphore acquisition
        if not await self._acquire_semaphore(request_id):
            yield self._create_error_response("Request waiting timeout", ErrorType.TIMEOUT_ERROR, ErrorCode.TIMEOUT)
            return

        try:
            # Step 3: Preprocessing
            await self._preprocess(ctx)

            # Step 4: Response processing
            generators = self._prepare_generators(ctx)

            # Step 5: Final response build
            async for request_output in generators:
                ctx.request_output = request_output
                yield self._build_response(ctx)
        except InvalidParameterException as e:
            traceback.print_exc()
            yield self._create_error_response(str(e.message), ErrorType.INVALID_REQUEST_ERROR, param=e.param)
        except Exception as e:
            traceback.print_exc()
            yield self._create_error_response(str(e))
        finally:
            self._release_semaphore(request_id)


class ZmqOpenAIServing(OpenAIServing):
    """
    OpenAI-style service architecture using ZeroMQ as the communication mechanism.
    """

    def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time, chat_template):
        super().__init__(engine_client, models, cfg, pid, ips, max_waiting_time)
        self.chat_template = chat_template

    def _request_to_dict(self, ctx: ServeContext):
        """Convert a single request into the dictionary form expected by the engine"""
        request = ctx.request
        if hasattr(request, "to_dict_for_infer"):
            request_dict = request.to_dict_for_infer(ctx.request_id)
        else:
            request_dict = request.dict()
            request_dict["request_id"] = ctx.request_id
        request_dict["arrival_time"] = time.time()
        self._process_chat_template_kwargs(request_dict)
        return request_dict

    def _request_to_batch_dicts(self, ctx: ServeContext):
        """Convert multiple requests to dictionary form"""
        return [self._request_to_dict(ctx)]
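
    # By default one API request maps to a single engine request; subclasses
    # that fan a request out into several engine requests (e.g. multiple
    # sampled choices) would presumably override ``_request_to_batch_dicts``.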

    @override
    async def _preprocess(self, ctx: ServeContext):
        """Preprocess the request into engine format"""
        request_dicts = self._request_to_batch_dicts(ctx)
        ctx.preprocess_requests = request_dicts
        for request_dict in request_dicts:
            api_server_logger.info(f"batch add request_id: {request_dict['request_id']}, request: {request_dict}")
            await self.engine_client.format_and_add_data(request_dict)

    def _process_chat_template_kwargs(self, request_dict):
        """Add default values to chat template kwargs"""
        if "chat_template" not in request_dict:
            request_dict["chat_template"] = self.chat_template
        chat_template_kwargs = request_dict.get("chat_template_kwargs") or {}
        chat_template_kwargs.update(
            {
                "chat_template": request_dict.get("chat_template"),
                "add_stop_sequences": request_dict.get("add_stop_sequences"),
            }
        )
        request_dict["chat_template_kwargs"] = chat_template_kwargs

    @override
    async def _prepare_generators(self, ctx: ServeContext) -> AsyncGenerator[dict, None]:
        """Prepare a generator of responses"""
        request_id = ctx.request_id
        try:
            num_choices = len(ctx.preprocess_requests)
            dealer, request_output_queue = await self.engine_client.connection_manager.get_connection(
                request_id, num_choices
            )
            for pr in ctx.preprocess_requests:
                dealer.write([b"", pr["request_id"].encode("utf-8")])
            # if self.engine_client.check_model_weight_status():
            #     raise ValueError("Engine is clearing model weight")
            while num_choices > 0:
                request_output_dicts = await asyncio.wait_for(request_output_queue.get(), timeout=60)
                for request_output_dict in request_output_dicts:
                    api_server_logger.debug(f"Received RequestOutput: {request_output_dict}")
                    if request_output_dict["finished"] is True:
                        num_choices -= 1
                    yield request_output_dict
        except Exception as e:
            raise ValueError(f"Error processing response: {str(e)}") from e
        finally:
            await self.engine_client.connection_manager.cleanup_request(request_id)
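

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete serving class (illustrative only; the names
# ``OpenAIServingReward`` and ``RewardResponse`` are hypothetical and not
# defined in this module). ``ZmqOpenAIServing`` already supplies
# ``_preprocess`` and ``_prepare_generators``, so a subclass mainly needs to
# build the public response:
#
#   class OpenAIServingReward(ZmqOpenAIServing):
#       request_id_prefix = "reward"
#
#       @override
#       def _build_response(self, ctx: ServeContext) -> "RewardResponse":
#           # Turn ctx.request_output into the public response schema.
#           ...
# ---------------------------------------------------------------------------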