[feature] support reward api (#4518)

Co-authored-by: SunLei <sunlei5788@gmail.com>
Author: xiaolei373
Date: 2025-10-29 00:20:28 +08:00
Committed by: GitHub
Commit: 14e7d88ea4 (parent a012e3608b)
9 changed files with 362 additions and 17 deletions
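
This change adds an OpenAI-style POST /v1/reward endpoint on top of the existing pooling path. Below is a minimal client sketch against the new endpoint; it assumes an API server is already running locally, and the host, port, and model name are illustrative, not part of this change.

# Minimal client sketch for the new /v1/reward endpoint (illustrative values).
import requests

payload = {
    "model": "default",  # placeholder for a served reward-capable model
    "messages": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ],
}
resp = requests.post("http://localhost:8000/v1/reward", json=payload, timeout=30)
resp.raise_for_status()
body = resp.json()  # ChatRewardResponse: id, object, created, model, data, usage
for item in body["data"]:  # each entry is a ChatRewardData with index/object/score
    print(item["index"], item["score"])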


@@ -886,24 +886,27 @@ class EngineService:
for request_id, contents in results.items():
    new_contents = []
    for content in contents:
-       decode_type = content.outputs.decode_type
-       delta_text = ""
-       if decode_type == 0:
-           delta_text, token_ids = self._decode_token(
-               token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
-           )
-       else:
-           token_ids = content.outputs.token_ids
-       if len(token_ids):
-           content.outputs.token_ids = token_ids
-           content.outputs.text = delta_text
-           new_contents.append(content)
-       elif content.finished:
-           new_contents.append(content)
-       else:
-           llm_logger.warning(
-               f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
-           )
+       if isinstance(content, RequestOutput):
+           decode_type = content.outputs.decode_type
+           delta_text = ""
+           if decode_type == 0:
+               delta_text, token_ids = self._decode_token(
+                   token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
+               )
+           else:
+               token_ids = content.outputs.token_ids
+           if len(token_ids):
+               content.outputs.token_ids = token_ids
+               content.outputs.text = delta_text
+               new_contents.append(content)
+           elif content.finished:
+               new_contents.append(content)
+           else:
+               llm_logger.warning(
+                   f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
+               )
+       else:
+           new_contents.append(content)
    if len(new_contents):
        llm_logger.info(f"Send response for request id: {request_id}")
        self.send_response_server.send_response(request_id, new_contents)


@@ -86,6 +86,7 @@ class PoolingParams(
return {
"embed": ["dimensions", "normalize"],
"encode": ["softmax", "step_tag_id", "returned_token_ids"],
"reward": ["dimensions", "normalize"],
}
def to_dict(self) -> Dict[str, Any]:
@@ -161,6 +162,9 @@ class PoolingParams(
elif self.task == "encode":
if self.softmax is None:
self.softmax = True
elif self.task == "reward":
if self.normalize is None:
self.normalize = True
else:
raise ValueError(f"Unknown pooling task: {self.task}")
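
The second hunk mirrors the existing task branches: for the new "reward" task, normalize falls back to True when the caller leaves it unset. A tiny standalone sketch of just that defaulting rule (illustration only, not the library code):

# Illustration only: restates the new "reward" branch of the pooling-param defaults.
from typing import Optional

def default_reward_normalize(normalize: Optional[bool]) -> Optional[bool]:
    # Unset -> True; an explicit value from the request is kept as-is.
    return True if normalize is None else normalize

print(default_reward_normalize(None))   # True
print(default_reward_normalize(False))  # False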


@@ -729,3 +729,44 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class RewardOutput:
"""The output data of one reward output of a request.
Args:
score: The reward score, which is a list of floats.
Its length depends on the hidden dimension of the model.
"""
score: list[float]
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
# if pooled_data.ndim != 1:
# raise ValueError("pooled_data should be a 1-D embedding vector")
if isinstance(pooled_data, list):
return RewardOutput(pooled_data)
return RewardOutput(pooled_data.tolist())
@property
def hidden_size(self) -> int:
return len(self.score)
def __repr__(self) -> str:
return f"RewardOutput(hidden_size={self.hidden_size})"
class RewardRequestOutput(PoolingRequestOutput[RewardOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return RewardRequestOutput(
request_id=request_output.request_id,
outputs=RewardOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
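
The serving layer further below converts a generic pooling result into this reward-specific shape via from_base. A small standalone sketch of that conversion, built from the same constructors the unit test at the end of this diff uses; the scores and request id are made up for illustration:

# Sketch: wrap a pooling result in the new reward output types (illustrative values).
import time

from fastdeploy.engine.request import (
    PoolingOutput,
    PoolingRequestOutput,
    RequestMetrics,
    RewardRequestOutput,
)

pooled = PoolingRequestOutput(
    request_id="reward-demo_0",
    prompt_token_ids=[1, 2, 3],
    finished=True,
    outputs=PoolingOutput(data=[0.1, 0.2, 0.3]),
    metrics=RequestMetrics(arrival_time=time.time()),
)

reward = RewardRequestOutput.from_base(pooled)
print(reward.outputs.score)        # [0.1, 0.2, 0.3]
print(reward.outputs.hidden_size)  # 3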


@@ -19,7 +19,7 @@ from typing import Literal, get_args
GenerationTask = Literal["generate"]
GENERATION_TASKS = get_args(GenerationTask)
-PoolingTask = Literal["encode", "embed"]
+PoolingTask = Literal["encode", "embed", "reward"]
POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]


@@ -41,6 +41,7 @@ from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatRewardRequest,
CompletionRequest,
CompletionResponse,
ControlSchedulerRequest,
@@ -53,6 +54,7 @@ from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
from fastdeploy.envs import environment_variables
@@ -232,12 +234,16 @@ async def lifespan(app: FastAPI):
args.max_waiting_time,
chat_template,
)
reward_handler = OpenAIServingReward(
engine_client, app.state.model_handler, config, pid, args.ips, args.max_waiting_time, chat_template
)
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
app.state.engine_client = engine_client
app.state.chat_handler = chat_handler
app.state.completion_handler = completion_handler
app.state.embedding_handler = embedding_handler
app.state.reward_handler = reward_handler
global llm_engine
if llm_engine is not None:
llm_engine.engine.data_processor = engine_client.data_processor
@@ -447,6 +453,20 @@ async def list_models() -> Response:
return JSONResponse(content=models.model_dump())
@app.post("/v1/reward")
async def create_reward(request: ChatRewardRequest):
"""
Create reward for the input texts
"""
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive()
if not status:
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
generator = await app.state.reward_handler.create_reward(request)
return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
"""


@@ -974,3 +974,89 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
class ChatRewardRequest(BaseModel):
model: Optional[str] = None # Target model, e.g. "default" or a chat model that supports embeddings
messages: Union[List[Any], List[int]] # List of chat messages (required)
user: Optional[str] = None # Identifier of the caller
dimensions: Optional[int] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. " "Will be accessible by the chat template."
),
)
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=lambda: f"{uuid.uuid4().hex}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a uuid.uuid4().hex will be generated. This id is used "
"through out the inference process and return in response."
),
)
normalize: Optional[bool] = None
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, normalize=self.normalize
)
class ChatRewardData(BaseModel):
index: Optional[int] = None # Index of this entry (optional)
object: str = "reward" # Fixed as "reward"
score: List[float] # Reward score (list of floats)
class ChatRewardResponse(BaseModel):
id: str # Response ID, e.g. chat-reward-<uuid>
object: str = "object" # Fixed as "object"
created: int # Creation time (Unix timestamp)
model: str # Name of the model used
data: List[ChatRewardData] # List of reward results
usage: Optional[UsageInfo] = None # Token usage
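
A quick sketch of how these request fields map onto the engine-side pooling parameters; the model name and messages are illustrative:

# Sketch: build a reward request and derive its pooling parameters.
from fastdeploy.entrypoints.openai.protocol import ChatRewardRequest

req = ChatRewardRequest(
    model="default",  # placeholder model name
    messages=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ],
)

pooling_params = req.to_pooling_params()
# The serving layer then calls pooling_params.verify("reward", model_config),
# which defaults normalize to True for the reward task (see pooling_params.py above).
print(pooling_params.to_dict())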


@@ -0,0 +1,117 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from collections.abc import AsyncGenerator
from typing_extensions import override
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.request import PoolingRequestOutput, RewardRequestOutput
from fastdeploy.entrypoints.openai.protocol import (
ChatRewardData,
ChatRewardRequest,
ChatRewardResponse,
UsageInfo,
)
from fastdeploy.entrypoints.openai.serving_engine import ServeContext, ZmqOpenAIServing
from fastdeploy.utils import api_server_logger
class OpenAIServingReward(ZmqOpenAIServing):
request_id_prefix = "reward"
"""
OpenAI-style reward serving using pipeline pattern
"""
def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time, chat_template):
super().__init__(engine_client, models, cfg, pid, ips, max_waiting_time, chat_template)
@override
def _request_to_dict(self, ctx: ServeContext):
request: ChatRewardRequest = ctx.request
request_dict = super()._request_to_dict(ctx)
if hasattr(request, "to_pooling_params"):
pooling_params: PoolingParams = request.to_pooling_params()
pooling_params.verify("reward", self.cfg.model_config)
request_dict["pooling_params"] = pooling_params.to_dict()
return request_dict
@override
def _request_to_batch_dicts(self, ctx: ServeContext):
"""
Convert the request into dictionary format that can be sent to the inference server
"""
request_dict = self._request_to_dict(ctx)
request_dict["request_id"] = f"{ctx.request_id}_0"
request_dicts = [request_dict]
return request_dicts
async def create_reward(self, request: ChatRewardRequest):
"""
Create reward scores for the input messages using the pipeline pattern
"""
request_id = self._generate_request_id(getattr(request, "user", None))
ctx = ServeContext[ChatRewardRequest](
request=request,
model_name=request.model,
request_id=request_id,
)
idx = 0
response: ChatRewardResponse = None
generators: AsyncGenerator[ChatRewardResponse, None] = self.handle(ctx)
async for r in generators:
r.data[0].index = idx
idx += 1
if response is None:
response = r
else:
response.data.append(r.data[0])
response.usage.prompt_tokens += r.usage.prompt_tokens
response.usage.total_tokens += r.usage.total_tokens
return response
@override
def _build_response(self, ctx: ServeContext):
"""Generate final reward response"""
api_server_logger.info(f"[{ctx.request_id}] Reward RequestOutput received:{ctx.request_output}")
base = PoolingRequestOutput.from_dict(ctx.request_output)
reward_res = RewardRequestOutput.from_base(base)
data = ChatRewardData(
index=0,
score=reward_res.outputs.score,
)
num_prompt_tokens = 0
if reward_res.prompt_token_ids:
num_prompt_tokens = len(reward_res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ChatRewardResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=[data],
usage=usage,
)


@@ -125,6 +125,8 @@ class DealerConnectionManager:
request_id = response[-1]["request_id"]
if request_id[:4] in ["cmpl", "embd"]:
request_id = request_id.rsplit("_", 1)[0]
elif "reward" == request_id[:6]:
request_id = request_id.rsplit("_", 1)[0]
elif "chatcmpl" == request_id[:8]:
request_id = request_id.rsplit("_", 1)[0]
async with self.lock:
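
The serving layer fans a reward request out with a "_0" batch suffix (see _request_to_batch_dicts above), so the connection manager strips that suffix here to route responses back to the original request id. A tiny sketch with a hypothetical id:

# Hypothetical id: "reward" prefix from OpenAIServingReward plus the "_0" batch suffix.
request_id = "reward-1f2e3d_0"
if request_id[:6] == "reward":
    request_id = request_id.rsplit("_", 1)[0]
print(request_id)  # -> "reward-1f2e3d"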


@@ -0,0 +1,72 @@
import time
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastdeploy.engine.request import (
PoolingOutput,
PoolingRequestOutput,
RequestMetrics,
)
from fastdeploy.entrypoints.openai.protocol import ChatRewardRequest, ChatRewardResponse
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
class TestOpenAIServingReward(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_engine_client = MagicMock()
self.mock_engine_client.semaphore.acquire = AsyncMock()
self.mock_engine_client.semaphore.release = MagicMock()
self.mock_engine_client.check_model_weight_status = AsyncMock(return_value=False)
mock_dealer = MagicMock()
mock_response_queue = MagicMock()
self.response_data: PoolingRequestOutput = PoolingRequestOutput(
request_id="test_request_id",
prompt_token_ids=[1, 2, 3],
finished=True,
outputs=PoolingOutput(data=[0.1, 0.2, 0.3]),
metrics=RequestMetrics(arrival_time=time.time()),
)
mock_response_queue.get = AsyncMock(
return_value=[
self.response_data.to_dict(),
]
)
self.mock_engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
self.mock_engine_client.connection_manager.cleanup_request = AsyncMock()
self.mock_engine_client.format_and_add_data = AsyncMock(return_value=[[1, 2, 3]])
models = MagicMock()
models.is_supported_model = MagicMock(return_value=(True, "ERNIE"))
pid = 123
ips = ["127.0.0.1"]
max_waiting_time = 30
chat_template = MagicMock()
cfg = MagicMock()
self.reward_service = OpenAIServingReward(
self.mock_engine_client, models, cfg, pid, ips, max_waiting_time, chat_template
)
async def test_create_reward_success(self):
# Setup
request = ChatRewardRequest(
model="text-reward-ada-002",
messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
],
)
# Execute
result: ChatRewardResponse = await self.reward_service.create_reward(request)
# Assert
self.assertEqual(result.data[0].score, self.response_data.outputs.data)
if __name__ == "__main__":
unittest.main()