Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-11-02 12:44:20 +08:00)
[feature] support reward api (#4518)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Co-authored-by: SunLei <sunlei5788@gmail.com>
@@ -886,24 +886,27 @@ class EngineService:
             for request_id, contents in results.items():
                 new_contents = []
                 for content in contents:
-                    decode_type = content.outputs.decode_type
-                    delta_text = ""
-                    if decode_type == 0:
-                        delta_text, token_ids = self._decode_token(
-                            token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
-                        )
-                    else:
-                        token_ids = content.outputs.token_ids
-                    if len(token_ids):
-                        content.outputs.token_ids = token_ids
-                        content.outputs.text = delta_text
-                        new_contents.append(content)
-                    elif content.finished:
-                        new_contents.append(content)
-                    else:
-                        llm_logger.warning(
-                            f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
-                        )
+                    if isinstance(content, RequestOutput):
+                        decode_type = content.outputs.decode_type
+                        delta_text = ""
+                        if decode_type == 0:
+                            delta_text, token_ids = self._decode_token(
+                                token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
+                            )
+                        else:
+                            token_ids = content.outputs.token_ids
+                        if len(token_ids):
+                            content.outputs.token_ids = token_ids
+                            content.outputs.text = delta_text
+                            new_contents.append(content)
+                        elif content.finished:
+                            new_contents.append(content)
+                        else:
+                            llm_logger.warning(
+                                f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
+                            )
+                    else:
+                        new_contents.append(content)
                 if len(new_contents):
                     llm_logger.info(f"Send response for request id: {request_id}")
                     self.send_response_server.send_response(request_id, new_contents)
@@ -86,6 +86,7 @@ class PoolingParams(
         return {
             "embed": ["dimensions", "normalize"],
             "encode": ["softmax", "step_tag_id", "returned_token_ids"],
+            "reward": ["dimensions", "normalize"],
         }
 
     def to_dict(self) -> Dict[str, Any]:
@@ -161,6 +162,9 @@ class PoolingParams(
         elif self.task == "encode":
            if self.softmax is None:
                self.softmax = True
+        elif self.task == "reward":
+            if self.normalize is None:
+                self.normalize = True
         else:
             raise ValueError(f"Unknown pooling task: {self.task}")
 
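Note: the reward task reuses the embedding-style pooling knobs. A minimal sketch of how the serving layer builds these parameters; it mirrors ChatRewardRequest.to_pooling_params() and OpenAIServingReward._request_to_dict() further down, and the verify() call is left commented out because it needs a real model config:

from fastdeploy.engine.pooling_params import PoolingParams

# Same construction as ChatRewardRequest.to_pooling_params() in this PR.
params = PoolingParams(truncate_prompt_tokens=None, dimensions=None, normalize=None)

# OpenAIServingReward._request_to_dict() then runs, roughly:
#     params.verify("reward", model_config)   # with the defaulting added above, normalize becomes True
#     request_dict["pooling_params"] = params.to_dict()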
@@ -729,3 +729,44 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
             prompt_token_ids=request_output.prompt_token_ids,
             finished=request_output.finished,
         )
+
+
+@dataclass
+class RewardOutput:
+    """The output data of one reward output of a request.
+
+    Args:
+        score: The reward score, which is a list of floats.
+            Its length depends on the hidden dimension of the model.
+    """
+
+    score: list[float]
+
+    @staticmethod
+    def from_base(pooling_output: PoolingOutput):
+        pooled_data = pooling_output.data
+        # if pooled_data.ndim != 1:
+        #     raise ValueError("pooled_data should be a 1-D embedding vector")
+
+        if isinstance(pooled_data, list):
+            return RewardOutput(pooled_data)
+
+        return RewardOutput(pooled_data.tolist())
+
+    @property
+    def hidden_size(self) -> int:
+        return len(self.score)
+
+    def __repr__(self) -> str:
+        return f"RewardOutput(hidden_size={self.hidden_size})"
+
+
+class RewardRequestOutput(PoolingRequestOutput[RewardOutput]):
+    @staticmethod
+    def from_base(request_output: PoolingRequestOutput):
+        return RewardRequestOutput(
+            request_id=request_output.request_id,
+            outputs=RewardOutput.from_base(request_output.outputs),
+            prompt_token_ids=request_output.prompt_token_ids,
+            finished=request_output.finished,
+        )
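Note: a quick usage sketch of the new output classes; PoolingOutput is built the same way as in the unit test at the end of this diff, and tensor-valued data would take the tolist() branch instead:

from fastdeploy.engine.request import PoolingOutput, RewardOutput

pooled = PoolingOutput(data=[0.1, 0.2, 0.3])
reward = RewardOutput.from_base(pooled)

print(reward.score)        # [0.1, 0.2, 0.3]
print(reward.hidden_size)  # 3
print(reward)              # RewardOutput(hidden_size=3)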
@@ -19,7 +19,7 @@ from typing import Literal, get_args
 GenerationTask = Literal["generate"]
 GENERATION_TASKS = get_args(GenerationTask)
 
-PoolingTask = Literal["encode", "embed"]
+PoolingTask = Literal["encode", "embed", "reward"]
 POOLING_TASKS = get_args(PoolingTask)
 
 SupportedTask = Literal[GenerationTask, PoolingTask]
@@ -41,6 +41,7 @@ from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatRewardRequest,
     CompletionRequest,
     CompletionResponse,
     ControlSchedulerRequest,
@@ -53,6 +54,7 @@ from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
+from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
 from fastdeploy.envs import environment_variables
@@ -232,12 +234,16 @@ async def lifespan(app: FastAPI):
         args.max_waiting_time,
         chat_template,
     )
+    reward_handler = OpenAIServingReward(
+        engine_client, app.state.model_handler, config, pid, args.ips, args.max_waiting_time, chat_template
+    )
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
     app.state.chat_handler = chat_handler
     app.state.completion_handler = completion_handler
     app.state.embedding_handler = embedding_handler
+    app.state.reward_handler = reward_handler
     global llm_engine
     if llm_engine is not None:
         llm_engine.engine.data_processor = engine_client.data_processor
@@ -447,6 +453,20 @@ async def list_models() -> Response:
     return JSONResponse(content=models.model_dump())
 
 
+@app.post("/v1/reward")
+async def create_reward(request: ChatRewardRequest):
+    """
+    Create reward for the input texts
+    """
+    if app.state.dynamic_load_weight:
+        status, msg = app.state.engine_client.is_workers_alive()
+        if not status:
+            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
+
+    generator = await app.state.reward_handler.create_reward(request)
+    return JSONResponse(content=generator.model_dump())
+
+
 @app.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest):
     """
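Note: a hypothetical client call against the new endpoint; host, port and model name are placeholders, and the payload fields follow ChatRewardRequest defined in protocol.py below:

import requests

resp = requests.post(
    "http://localhost:8000/v1/reward",
    json={
        "model": "default",
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ],
    },
)
print(resp.json())  # a serialized ChatRewardResponse, see protocol.py below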
@@ -974,3 +974,89 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
 
 PoolingCompletionRequest = EmbeddingCompletionRequest
 PoolingChatRequest = EmbeddingChatRequest
+
+
+class ChatRewardRequest(BaseModel):
+    model: Optional[str] = None  # Model to use, e.g. "default" or a chat model that supports embedding
+    messages: Union[List[Any], List[int]]  # List of chat messages (required)
+    user: Optional[str] = None  # Identifier of the caller
+
+    dimensions: Optional[int] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
+
+    # --8<-- [start:chat-embedding-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+    chat_template_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. " "Will be accessible by the chat template."
+        ),
+    )
+    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{uuid.uuid4().hex}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a uuid.uuid4().hex will be generated. This id is used "
+            "throughout the inference process and returned in the response."
+        ),
+    )
+    normalize: Optional[bool] = None
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, normalize=self.normalize
+        )
+
+
+class ChatRewardData(BaseModel):
+    index: Optional[int] = None  # Index of this result (optional)
+    object: str = "reward"  # Always "reward"
+    score: List[float]  # Reward score (list of floats)
+
+
+class ChatRewardResponse(BaseModel):
+    id: str  # Response ID, e.g. chat-reward-<uuid>
+    object: str = "object"  # Fixed to "object"
+    created: int  # Creation time (Unix timestamp)
+    model: str  # Name of the model used
+    data: List[ChatRewardData]  # List of reward results
+    usage: Optional[UsageInfo] = None  # Token usage information
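Note: for illustration, the response models can be built by hand the same way OpenAIServingReward._build_response() does in the new file below (the score values here are made up):

import time

from fastdeploy.entrypoints.openai.protocol import (
    ChatRewardData,
    ChatRewardResponse,
    UsageInfo,
)

response = ChatRewardResponse(
    id="chat-reward-123",
    created=int(time.time()),
    model="default",
    data=[ChatRewardData(index=0, score=[0.1, 0.2, 0.3])],
    usage=UsageInfo(prompt_tokens=3, total_tokens=3),
)
print(response.model_dump())  # the JSON body returned by /v1/reward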
fastdeploy/entrypoints/openai/serving_reward.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from collections.abc import AsyncGenerator
+
+from typing_extensions import override
+
+from fastdeploy.engine.pooling_params import PoolingParams
+from fastdeploy.engine.request import PoolingRequestOutput, RewardRequestOutput
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatRewardData,
+    ChatRewardRequest,
+    ChatRewardResponse,
+    UsageInfo,
+)
+from fastdeploy.entrypoints.openai.serving_engine import ServeContext, ZmqOpenAIServing
+from fastdeploy.utils import api_server_logger
+
+
+class OpenAIServingReward(ZmqOpenAIServing):
+    request_id_prefix = "reward"
+
+    """
+    OpenAI-style reward serving using the pipeline pattern
+    """
+
+    def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time, chat_template):
+        super().__init__(engine_client, models, cfg, pid, ips, max_waiting_time, chat_template)
+
+    @override
+    def _request_to_dict(self, ctx: ServeContext):
+        request: ChatRewardRequest = ctx.request
+        request_dict = super()._request_to_dict(ctx)
+        if hasattr(request, "to_pooling_params"):
+            pooling_params: PoolingParams = request.to_pooling_params()
+            pooling_params.verify("reward", self.cfg.model_config)
+            request_dict["pooling_params"] = pooling_params.to_dict()
+        return request_dict
+
+    @override
+    def _request_to_batch_dicts(self, ctx: ServeContext):
+        """
+        Convert the request into dictionary format that can be sent to the inference server
+        """
+        request_dict = self._request_to_dict(ctx)
+        request_dict["request_id"] = f"{ctx.request_id}_0"
+        request_dicts = [request_dict]
+        return request_dicts
+
+    async def create_reward(self, request: ChatRewardRequest):
+        """
+        Create reward scores for the input messages using the pipeline pattern
+        """
+        request_id = self._generate_request_id(getattr(request, "user", None))
+
+        ctx = ServeContext[ChatRewardRequest](
+            request=request,
+            model_name=request.model,
+            request_id=request_id,
+        )
+        idx = 0
+        response: ChatRewardResponse = None
+        generators: AsyncGenerator[ChatRewardResponse, None] = self.handle(ctx)
+        async for r in generators:
+            r.data[0].index = idx
+            idx += 1
+            if response is None:
+                response = r
+            else:
+                response.data.append(r.data[0])
+                response.usage.prompt_tokens += r.usage.prompt_tokens
+                response.usage.total_tokens += r.usage.total_tokens
+
+        return response
+
+    @override
+    def _build_response(self, ctx: ServeContext):
+        """Generate the final reward response"""
+        api_server_logger.info(f"[{ctx.request_id}] Reward RequestOutput received:{ctx.request_output}")
+
+        base = PoolingRequestOutput.from_dict(ctx.request_output)
+        reward_res = RewardRequestOutput.from_base(base)
+
+        data = ChatRewardData(
+            index=0,
+            score=reward_res.outputs.score,
+        )
+
+        num_prompt_tokens = 0
+        if reward_res.prompt_token_ids:
+            num_prompt_tokens = len(reward_res.prompt_token_ids)
+
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            total_tokens=num_prompt_tokens,
+        )
+
+        return ChatRewardResponse(
+            id=ctx.request_id,
+            created=ctx.created_time,
+            model=ctx.model_name,
+            data=[data],
+            usage=usage,
+        )
@@ -125,6 +125,8 @@ class DealerConnectionManager:
                 request_id = response[-1]["request_id"]
                 if request_id[:4] in ["cmpl", "embd"]:
                     request_id = request_id.rsplit("_", 1)[0]
+                elif "reward" == request_id[:6]:
+                    request_id = request_id.rsplit("_", 1)[0]
                 elif "chatcmpl" == request_id[:8]:
                     request_id = request_id.rsplit("_", 1)[0]
                 async with self.lock:
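Note: OpenAIServingReward sets request_id_prefix = "reward" and _request_to_batch_dicts() appends "_0", so the dealer strips that batch suffix to route each response back to the waiting request. A tiny sketch with a hypothetical id:

request_id = "reward-3f2a9c_0"  # hypothetical id; the exact format comes from _generate_request_id()
if "reward" == request_id[:6]:
    request_id = request_id.rsplit("_", 1)[0]
print(request_id)  # "reward-3f2a9c"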
tests/entrypoints/openai/test_serving_reward.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import time
+import unittest
+from unittest.mock import AsyncMock, MagicMock
+
+from fastdeploy.engine.request import (
+    PoolingOutput,
+    PoolingRequestOutput,
+    RequestMetrics,
+)
+from fastdeploy.entrypoints.openai.protocol import ChatRewardRequest, ChatRewardResponse
+from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
+
+
+class TestOpenAIServingReward(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.mock_engine_client = MagicMock()
+        self.mock_engine_client.semaphore.acquire = AsyncMock()
+        self.mock_engine_client.semaphore.release = MagicMock()
+
+        self.mock_engine_client.check_model_weight_status = AsyncMock(return_value=False)
+
+        mock_dealer = MagicMock()
+        mock_response_queue = MagicMock()
+        self.response_data: PoolingRequestOutput = PoolingRequestOutput(
+            request_id="test_request_id",
+            prompt_token_ids=[1, 2, 3],
+            finished=True,
+            outputs=PoolingOutput(data=[0.1, 0.2, 0.3]),
+            metrics=RequestMetrics(arrival_time=time.time()),
+        )
+        mock_response_queue.get = AsyncMock(
+            return_value=[
+                self.response_data.to_dict(),
+            ]
+        )
+
+        self.mock_engine_client.connection_manager.get_connection = AsyncMock(
+            return_value=(mock_dealer, mock_response_queue)
+        )
+
+        self.mock_engine_client.connection_manager.cleanup_request = AsyncMock()
+        self.mock_engine_client.format_and_add_data = AsyncMock(return_value=[[1, 2, 3]])
+        models = MagicMock()
+        models.is_supported_model = MagicMock(return_value=(True, "ERNIE"))
+        pid = 123
+        ips = ["127.0.0.1"]
+        max_waiting_time = 30
+        chat_template = MagicMock()
+        cfg = MagicMock()
+        self.reward_service = OpenAIServingReward(
+            self.mock_engine_client, models, cfg, pid, ips, max_waiting_time, chat_template
+        )
+
+    async def test_create_reward_success(self):
+        # Setup
+        request = ChatRewardRequest(
+            model="text-reward-ada-002",
+            messages=[
+                {"role": "user", "content": "Hello"},
+                {"role": "assistant", "content": "Hi there!"},
+            ],
+        )
+
+        # Execute
+        result: ChatRewardResponse = await self.reward_service.create_reward(request)
+
+        # Assert
+        self.assertEqual(result.data[0].score, self.response_data.outputs.data)
+
+
+if __name__ == "__main__":
+    unittest.main()