[Feature] Add AsyncTokenizerClient&ChatResponseProcessor with remote encode&decode support. (#3674)
* [Feature] add AsyncTokenizerClient
* add decode_image
* Add response_processors with remote decode support.
* [Feature] add tokenizer_base_url startup argument
* Revert comment removal and restore original content.
* [Feature] Non-streaming requests now support remote image decoding.
* Fix parameter type issue in decode_image call.
* Keep completion_token_ids when return_token_ids = False.
* add copyright
fastdeploy/demo/tokenzier_client_demo.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import json
import time

from fastdeploy.input.tokenzier_client import (
    AsyncTokenizerClient,
    ImageDecodeRequest,
    ImageEncodeRequest,
    VideoEncodeRequest,
)


async def main():
    """
    Exercise the AsyncTokenizerClient class.
    """
    base_url = "http://example.com/"

    client = AsyncTokenizerClient(base_url=base_url)

    # Test an image encode request
    image_encode_request = ImageEncodeRequest(
        version="v1", req_id="req_image_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
    )

    image_encode_ret = await client.encode_image(image_encode_request)
    print(f"Image encode result: {image_encode_ret}")

    # Test a video encode request
    video_encode_req = VideoEncodeRequest(
        version="v1",
        req_id="req_video_001",
        video_url="http://example.com/video.mp4",
        is_gen=False,
        resolution=1024,
        start_ts=0,
        end_ts=5,
        frames=1,
    )
    video_encode_result = await client.encode_video(video_encode_req)
    print(f"Video encode result: {video_encode_result}")

    # Test an image decode request
    with open("./image_decode_demo.json", "r", encoding="utf-8") as file:
        start_time = time.time()
        start_process_time = time.process_time()  # record the start time
        json_data = json.load(file)
        image_decoding_request = ImageDecodeRequest(req_id="req_image_001", data=json_data.get("data"))
        image_decode_result = await client.decode_image(image_decoding_request)
        print(f"Image decode result: {image_decode_result}")
        elapsed_time = time.time() - start_time
        elapsed_process_time = time.process_time() - start_process_time
        print(f"decode elapsed_time: {elapsed_time:.6f}s, elapsed_process_time: {elapsed_process_time:.6f}s")


if __name__ == "__main__":
    asyncio.run(main())
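The demo reads a local `image_decode_demo.json` that is not part of this commit; its exact payload comes from a prior encode/generation step. Going by `ImageDecodeRequest` (whose `data` field is a `list[Any]` of accumulated image token IDs), a minimal stand-in so the demo can run might be written like this — the token values are placeholders, not real encoder output:

# Hypothetical helper: writes a minimal image_decode_demo.json for the demo above.
# The "data" payload is fabricated placeholder token IDs, shape assumed from
# ImageDecodeRequest.data: list[Any].
import json

placeholder = {"data": [[101, 202, 303], [404, 505, 606]]}

with open("image_decode_demo.json", "w", encoding="utf-8") as f:
    json.dump(placeholder, f)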
@@ -71,6 +71,10 @@ class EngineArgs:
    """
    The name or path of the tokenizer (defaults to model path if not provided).
    """
+    tokenizer_base_url: str = None
+    """
+    The base URL of the remote tokenizer service (used instead of local tokenizer if provided).
+    """
    max_model_len: int = 2048
    """
    Maximum context length supported by the model.
@@ -426,6 +430,12 @@ class EngineArgs:
            default=EngineArgs.tokenizer,
            help="Tokenizer name or path (defaults to model path if not specified).",
        )
+        model_group.add_argument(
+            "--tokenizer-base-url",
+            type=nullable_str,
+            default=EngineArgs.tokenizer_base_url,
+            help="The base URL of the remote tokenizer service (used instead of local tokenizer if provided).",
+        )
        model_group.add_argument(
            "--max-model-len",
            type=int,
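A minimal sketch of how the new flag surfaces through the CLI plumbing, assuming the `fastdeploy.engine.args_utils` import path and using a placeholder model name and URL:

# Sketch: after EngineArgs.add_cli_args(parser), --tokenizer-base-url parses like
# any other model_group option. Import path and argument values are assumptions.
import argparse

from fastdeploy.engine.args_utils import EngineArgs  # import path assumed

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "my-model", "--tokenizer-base-url", "http://tokenizer-service:8080"])
assert args.tokenizer_base_url == "http://tokenizer-service:8080"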
@@ -77,6 +77,9 @@ parser.add_argument(
    help="max waiting time for connection, if set value -1 means no waiting time limit",
)
parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
+parser.add_argument(
+    "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output."
+)
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision)
@@ -176,7 +179,14 @@ async def lifespan(app: FastAPI):
    )
    app.state.model_handler = model_handler
    chat_handler = OpenAIServingChat(
-        engine_client, app.state.model_handler, pid, args.ips, args.max_waiting_time, chat_template
+        engine_client,
+        app.state.model_handler,
+        pid,
+        args.ips,
+        args.max_waiting_time,
+        chat_template,
+        args.enable_mm_output,
+        args.tokenizer_base_url,
    )
    completion_handler = OpenAIServingCompletion(
        engine_client,
@@ -163,8 +163,9 @@ class ChatMessage(BaseModel):
    Chat message.
    """

-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
    prompt_token_ids: Optional[List[int]] = None
@@ -226,6 +227,7 @@ class DeltaMessage(BaseModel):

    role: Optional[str] = None
    content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
    prompt_token_ids: Optional[List[int]] = None
    completion_token_ids: Optional[List[int]] = None
    reasoning_content: Optional[str] = None
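With these fields in place, a streamed delta that carries a decoded image serializes roughly as below — a sketch built from the `multipart` entries assembled in response_processors.py, with a placeholder URL:

# Hypothetical wire-level shape of a DeltaMessage carrying multimodal content.
delta = {
    "role": "assistant",
    "content": None,
    "multimodal_content": [
        {"type": "text", "text": "Here is the picture:"},
        {"type": "image", "url": "http://example.com/generated/img_001.png"},  # placeholder URL
    ],
}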
fastdeploy/entrypoints/openai/response_processors.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any, List, Optional

from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest


class ChatResponseProcessor:
    """
    A decoder class to build multimodal content (text/image) from token_ids.

    Attributes:
        eoi_token_id: Token ID indicating the end of an image (<eoi>).
    """

    def __init__(
        self,
        data_processor,
        enable_mm_output: Optional[bool] = False,
        eoi_token_id: Optional[int] = 101032,
        eos_token_id: Optional[int] = 2,
        decoder_base_url: Optional[str] = None,
    ):
        self.data_processor = data_processor
        self.enable_mm_output = enable_mm_output
        self.eoi_token_id = eoi_token_id
        self.eos_token_id = eos_token_id
        if decoder_base_url is not None:
            self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
        else:
            self.decoder_client = None  # later branches read this attribute, so it must always exist
        self._mm_buffer: List[Any] = []  # Buffer for accumulating image token_ids
        self._end_image_code_request_output: Optional[Any] = None
        self._multipart_buffer = []

    def enable_multimodal_content(self):
        return self.enable_mm_output

    def accumulate_token_ids(self, request_output):
        """Merge consecutive outputs that share a decode_type into the multipart buffer."""
        decode_type = request_output["outputs"].get("decode_type", 0)

        if not self._multipart_buffer:
            self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
        else:
            last_part = self._multipart_buffer[-1]

            if last_part["decode_type"] == decode_type:
                last_token_ids = last_part["request_output"]["outputs"]["token_ids"]
                last_token_ids.extend(request_output["outputs"]["token_ids"])
                request_output["outputs"]["token_ids"] = last_token_ids
                last_part["request_output"] = request_output
            else:
                self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})

    async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
        """
        Process a list of responses into a generator that yields each processed response as it's generated.
        Args:
            request_outputs: The list of outputs to be processed.
            stream: Whether or not to stream the output.
            enable_thinking: Whether or not to show thinking messages.
            include_stop_str_in_output: Whether or not to include stop strings in the output.
        """
        for request_output in request_outputs:
            if not self.enable_mm_output:
                yield self.data_processor.process_response_dict(
                    response_dict=request_output,
                    stream=stream,
                    enable_thinking=enable_thinking,
                    include_stop_str_in_output=include_stop_str_in_output,
                )
            elif stream:
                decode_type = request_output["outputs"].get("decode_type", 0)
                token_ids = request_output["outputs"]["token_ids"]
                if decode_type == 0:
                    # Text tokens: when <eoi> arrives, flush the buffered image tokens first.
                    if self.eoi_token_id and self.eoi_token_id in token_ids:
                        if self._mm_buffer:
                            all_tokens = self._mm_buffer
                            self._mm_buffer = []
                            image = {"type": "image"}
                            if self.decoder_client:
                                req_id = request_output["request_id"]
                                image_ret = await self.decoder_client.decode_image(
                                    request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
                                )
                                image["url"] = image_ret["http_url"]
                            image_output = self._end_image_code_request_output
                            image_output["outputs"]["multipart"] = [image]
                            image_output["outputs"]["token_ids"] = all_tokens
                            yield image_output

                    self.data_processor.process_response_dict(
                        response_dict=request_output,
                        stream=stream,
                        enable_thinking=enable_thinking,
                        include_stop_str_in_output=include_stop_str_in_output,
                    )
                    text = {"type": "text", "text": request_output["outputs"]["text"]}
                    request_output["outputs"]["multipart"] = [text]
                    yield request_output

                elif decode_type == 1:
                    # Image tokens: accumulate until the closing <eoi> shows up in the text stream.
                    self._mm_buffer.extend(token_ids)
                    self._end_image_code_request_output = request_output
            else:
                # Non-streaming: accumulate every part and emit once on eos_token_id.
                self.accumulate_token_ids(request_output)
                token_ids = request_output["outputs"]["token_ids"]
                if token_ids[-1] == self.eos_token_id:
                    multipart = []
                    for part in self._multipart_buffer:
                        if part["decode_type"] == 0:
                            self.data_processor.process_response_dict(
                                response_dict=part["request_output"],
                                stream=False,
                                enable_thinking=enable_thinking,
                                include_stop_str_in_output=include_stop_str_in_output,
                            )
                            text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
                            multipart.append(text)
                        elif part["decode_type"] == 1:
                            image = {"type": "image"}
                            if self.decoder_client:
                                req_id = part["request_output"]["request_id"]
                                all_tokens = part["request_output"]["outputs"]["token_ids"]
                                image_ret = await self.decoder_client.decode_image(
                                    request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
                                )
                                image["url"] = image_ret["http_url"]
                            multipart.append(image)

                    last_request_output = self._multipart_buffer[-1]["request_output"]
                    last_request_output["outputs"]["multipart"] = multipart
                    yield last_request_output
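A minimal usage sketch, assuming a stubbed data_processor and fabricated token IDs (101032 stands in for <eoi>, matching the default eoi_token_id; with no decoder_base_url the image part is yielded without a url):

# Sketch: drive ChatResponseProcessor in streaming mode with a stubbed data_processor.
import asyncio
from unittest.mock import MagicMock

from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor

async def demo():
    data_processor = MagicMock()  # stands in for the engine's real data processor
    data_processor.process_response_dict = MagicMock(side_effect=lambda response_dict, **_: response_dict)

    processor = ChatResponseProcessor(data_processor=data_processor, enable_mm_output=True)
    outputs = [
        {"request_id": "r1", "outputs": {"decode_type": 1, "token_ids": [[7, 8, 9]]}},             # image tokens buffered
        {"request_id": "r1", "outputs": {"decode_type": 0, "token_ids": [101032], "text": "ok"}},  # <eoi> flushes them
    ]
    async for part in processor.process_response_chat(
        outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
    ):
        print(part["outputs"]["multipart"])  # first the image part, then the text part

asyncio.run(demo())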
@@ -36,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
    PromptTokenUsageInfo,
    UsageInfo,
)
+from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -46,12 +47,24 @@ class OpenAIServingChat:
    OpenAI-style chat completions serving
    """

-    def __init__(self, engine_client, models, pid, ips, max_waiting_time, chat_template):
+    def __init__(
+        self,
+        engine_client,
+        models,
+        pid,
+        ips,
+        max_waiting_time,
+        chat_template,
+        enable_mm_output: Optional[bool] = False,
+        tokenizer_base_url: Optional[str] = None,
+    ):
        self.engine_client = engine_client
        self.models = models
        self.pid = pid
        self.max_waiting_time = max_waiting_time
        self.chat_template = chat_template
+        self.enable_mm_output = enable_mm_output
+        self.tokenizer_base_url = tokenizer_base_url
        if ips is not None:
            if isinstance(ips, list):
                self.master_ip = ips[0]
@@ -198,6 +211,11 @@ class OpenAIServingChat:
            dealer.write([b"", request_id.encode("utf-8")])
            choices = []
            current_waiting_time = 0
+            response_processor = ChatResponseProcessor(
+                data_processor=self.engine_client.data_processor,
+                enable_mm_output=self.enable_mm_output,
+                decoder_base_url=self.tokenizer_base_url,
+            )
            while num_choices > 0:
                try:
                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
@@ -215,17 +233,18 @@ class OpenAIServingChat:
                    current_waiting_time = 0
                    await asyncio.sleep(0.01)
                    continue
-                for res in response:
+                generator = response_processor.process_response_chat(
+                    response,
+                    stream=True,
+                    enable_thinking=enable_thinking,
+                    include_stop_str_in_output=include_stop_str_in_output,
+                )
+
+                async for res in generator:
                    if res.get("error_code", 200) != 200:
                        raise ValueError("{}".format(res["error_msg"]))

-                    self.engine_client.data_processor.process_response_dict(
-                        res,
-                        stream=True,
-                        enable_thinking=enable_thinking,
-                        include_stop_str_in_output=include_stop_str_in_output,
-                    )

                    if res["metrics"]["first_token_time"] is not None:
                        arrival_time = res["metrics"]["first_token_time"]
                        inference_start_time = res["metrics"]["inference_start_time"]
@@ -239,13 +258,22 @@ class OpenAIServingChat:
                            index=i,
                            delta=DeltaMessage(
                                role="assistant",
-                                content="",
                                reasoning_content="",
                                tool_calls=None,
                                prompt_token_ids=None,
                                completion_token_ids=None,
                            ),
                        )
+                        if response_processor.enable_multimodal_content():
+                            choice.delta.multimodal_content = [
+                                {
+                                    "type": "text",
+                                    "text": "",
+                                }
+                            ]
+                        else:
+                            choice.delta.content = ""
+
                        if request.return_token_ids:
                            choice.delta.prompt_token_ids = list(prompt_token_ids)
                            choice.delta.text_after_process = text_after_process
@@ -269,7 +297,6 @@ class OpenAIServingChat:
                        first_iteration = False

                    output = res["outputs"]
-                    delta_text = output["text"]
                    output_top_logprobs = output["top_logprobs"]
                    previous_num_tokens += len(output["token_ids"])
                    logprobs_res: Optional[LogProbs] = None
@@ -279,12 +306,17 @@ class OpenAIServingChat:
                    )

                    delta_message = DeltaMessage(
-                        content=delta_text,
                        reasoning_content="",
                        prompt_token_ids=None,
-                        completion_token_ids=None,
                        tool_calls=None,
+                        completion_token_ids=None,
                    )
+
+                    if response_processor.enable_multimodal_content():
+                        delta_message.multimodal_content = output["multipart"]
+                    else:
+                        delta_message.content = output["text"]

                    if not res["finished"] and "delta_message" in output:
                        delta_message_output = output["delta_message"]
                        if delta_message_output is None:
@@ -317,7 +349,10 @@ class OpenAIServingChat:
                            choice.finish_reason = "recover_stop"

                        if request.return_token_ids:
-                            choice.delta.completion_token_ids = list(output["token_ids"])
+                            if response_processor.enable_multimodal_content():
+                                choice.delta.multimodal_content[0]["completion_token_ids"] = list(output["token_ids"])
+                            else:
+                                choice.delta.completion_token_ids = list(output["token_ids"])
                            choice.delta.raw_prediction = output.get("raw_prediction")
                            choice.delta.completion_tokens = output.get("raw_prediction")
                        if include_continuous_usage:
@@ -395,6 +430,11 @@ class OpenAIServingChat:
                        current_waiting_time = 0
            logprob_contents = []
            completion_token_ids = []
+            response_processor = ChatResponseProcessor(
+                data_processor=self.engine_client.data_processor,
+                enable_mm_output=self.enable_mm_output,
+                decoder_base_url=self.tokenizer_base_url,
+            )
            while True:
                try:
                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
@@ -411,15 +451,16 @@ class OpenAIServingChat:
                    continue

                task_is_finished = False
-                for data in response:
+                generator = response_processor.process_response_chat(
+                    response,
+                    stream=False,
+                    enable_thinking=enable_thinking,
+                    include_stop_str_in_output=include_stop_str_in_output,
+                )
+                async for data in generator:
                    if data.get("error_code", 200) != 200:
                        raise ValueError("{}".format(data["error_msg"]))
-                    data = self.engine_client.data_processor.process_response_dict(
-                        data,
-                        stream=False,
-                        enable_thinking=enable_thinking,
-                        include_stop_str_in_output=include_stop_str_in_output,
-                    )
                    # api_server_logger.debug(f"Client {request_id} received: {data}")
                    previous_num_tokens += len(data["outputs"]["token_ids"])
                    completion_token_ids.extend(data["outputs"]["token_ids"])
@@ -447,7 +488,6 @@ class OpenAIServingChat:
        output = final_res["outputs"]
        message = ChatMessage(
            role="assistant",
-            content=output["text"],
            reasoning_content=output.get("reasoning_content"),
            tool_calls=output.get("tool_call"),
            prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
@@ -457,6 +497,12 @@ class OpenAIServingChat:
            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
            completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
        )
+
+        if response_processor.enable_multimodal_content():
+            message.multimodal_content = output.get("multipart")
+        else:
+            message.content = output["text"]
+
        logprobs_full_res = None
        if logprob_contents:
            logprobs_full_res = LogProbs(content=logprob_contents)
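The `return_token_ids` branch added above routes the token IDs differently depending on the output mode: with multimodal output on, completion_token_ids ride inside the first multimodal_content part instead of the top-level delta field. A self-contained sketch with plain dicts standing in for the DeltaMessage/choice objects:

# Sketch of the streaming return_token_ids branch; all values are placeholders.
output = {"token_ids": [5, 6, 7]}
delta = {"multimodal_content": [{"type": "text", "text": ""}], "completion_token_ids": None}

enable_mm = True  # response_processor.enable_multimodal_content() in the real code
if enable_mm:
    delta["multimodal_content"][0]["completion_token_ids"] = list(output["token_ids"])
else:
    delta["completion_token_ids"] = list(output["token_ids"])

assert delta["multimodal_content"][0]["completion_token_ids"] == [5, 6, 7]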
fastdeploy/input/tokenzier_client.py (new file, 163 lines)
@@ -0,0 +1,163 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
from typing import Any, Optional, Union

import httpx
from pydantic import BaseModel, HttpUrl

from fastdeploy.utils import data_processor_logger


class BaseEncodeRequest(BaseModel):
    version: str
    req_id: str
    is_gen: bool
    resolution: int


class ImageEncodeRequest(BaseEncodeRequest):
    image_url: Union[str, HttpUrl]


class VideoEncodeRequest(BaseEncodeRequest):
    video_url: Union[str, HttpUrl]
    start_ts: int
    end_ts: int
    frames: int


class ImageDecodeRequest(BaseModel):
    req_id: str
    data: list[Any]


class AsyncTokenizerClient:
    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: float = 5.0,
        poll_interval: float = 0.5,
        max_wait: float = 60.0,
    ):
        """
        :param base_url: address of the remote service
        :param timeout: timeout for a single HTTP request, in seconds
        :param poll_interval: polling interval when querying results, in seconds
        :param max_wait: maximum total wait time, in seconds
        """
        self.base_url = base_url
        self.timeout = timeout
        self.poll_interval = poll_interval
        self.max_wait = max_wait

    async def encode_image(self, request: ImageEncodeRequest):
        return await self._async_encode_request("image", request.__dict__)

    async def encode_video(self, request: VideoEncodeRequest):
        return await self._async_encode_request("video", request.__dict__)

    async def decode_image(self, request: ImageDecodeRequest):
        return await self._async_decode_request("image", request.__dict__)

    async def log_request(self, request):
        data_processor_logger.debug(f">>> Request: {request.method} {request.url}")
        data_processor_logger.debug(f">>> Headers: {request.headers}")
        if request.content:
            data_processor_logger.debug(f">>> Content: {request.content.decode('utf-8')}")

    async def log_response(self, response):
        data_processor_logger.debug(f"<<< Response status: {response.status_code}")
        data_processor_logger.debug(f"<<< Headers: {response.headers}")

    async def _async_encode_request(self, type: str, request: dict):
        if not self.base_url:
            raise ValueError("Missing base_url")

        async with httpx.AsyncClient(
            timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
        ) as client:
            req_id = request.get("req_id")
            try:
                url = None
                if type == "image":
                    url = f"{self.base_url}/image/encode"
                elif type == "video":
                    url = f"{self.base_url}/video/encode"
                else:
                    raise ValueError("Invalid type")

                resp = await client.post(url, json=request)
                resp.raise_for_status()
            except httpx.RequestError as e:
                raise RuntimeError(f"Failed to create tokenize task: {e}") from e

            task_info = resp.json()
            if task_info.get("code") != 0:
                raise RuntimeError(f"Tokenize task creation failed, {task_info.get('message')}")

            task_tag = task_info.get("task_tag")
            if not task_tag:
                raise RuntimeError("No task_tag returned from server")

            # Poll for the result
            start_time = asyncio.get_event_loop().time()
            while True:
                try:
                    r = await client.get(
                        f"{self.base_url}/encode/get", params={"task_tag": task_tag, "req_id": req_id}
                    )
                    r.raise_for_status()
                    data = r.json()

                    # Current state of the async encode task: Processing, Finished, Error
                    if data.get("state") == "Finished":
                        return data.get("result")
                    elif data.get("state") == "Error":
                        raise RuntimeError(f"Tokenize task failed: {data.get('message')}")

                except httpx.RequestError:
                    # Keep polling through transient network errors
                    pass

                # Timeout check
                if asyncio.get_event_loop().time() - start_time > self.max_wait:
                    raise TimeoutError(f"Tokenize task {task_tag} timed out after {self.max_wait}s")

                await asyncio.sleep(self.poll_interval)

    async def _async_decode_request(self, type: str, request: dict):
        if not self.base_url:
            raise ValueError("Missing base_url")

        async with httpx.AsyncClient(
            timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
        ) as client:
            try:
                url = None
                if type == "image":
                    url = f"{self.base_url}/image/decode"
                else:
                    raise ValueError("Invalid type")
                resp = await client.post(url, json=request)
                resp.raise_for_status()
                if resp.json().get("code") != 0:
                    raise RuntimeError(f"Tokenize task creation failed, {resp.json().get('message')}")
                return resp.json().get("result")
            except httpx.RequestError as e:
                raise RuntimeError(f"Failed to decode: {e}") from e
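The client implies a server contract: POST {base_url}/image/encode or {base_url}/video/encode returns {"code": 0, "task_tag": ...}; GET {base_url}/encode/get returns {"state": "Processing"|"Finished"|"Error", ...}; POST {base_url}/image/decode returns {"code": 0, "result": ...}. A minimal in-memory stand-in satisfying that contract could look like this — FastAPI is chosen purely for illustration, and every endpoint body here is an assumption, not part of the PR:

# Hypothetical tokenizer service matching AsyncTokenizerClient's expectations.
from fastapi import FastAPI, Request

app = FastAPI()
tasks: dict[str, dict] = {}  # task_tag -> result payload

@app.post("/image/encode")
async def image_encode(req: Request):
    body = await req.json()
    tag = f"task-{body['req_id']}"
    tasks[tag] = {"feature_url": "bos://bucket/key", "feature_shape": [80, 45, 1563]}  # placeholder
    return {"code": 0, "task_tag": tag}

@app.post("/video/encode")
async def video_encode(req: Request):
    body = await req.json()
    tag = f"task-{body['req_id']}"
    tasks[tag] = {"feature_url": "bos://bucket/video-key", "feature_shape": [16, 80, 45]}  # placeholder
    return {"code": 0, "task_tag": tag}

@app.get("/encode/get")
async def encode_get(task_tag: str, req_id: str):
    if task_tag in tasks:
        return {"state": "Finished", "result": tasks[task_tag]}
    return {"state": "Error", "message": "unknown task"}

@app.post("/image/decode")
async def image_decode(req: Request):
    body = await req.json()
    return {"code": 0, "result": {"http_url": f"http://files.example.com/{body['req_id']}.png"}}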
test/input/test_tokenizer_client.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import httpx
import pytest
import respx

from fastdeploy.input.tokenzier_client import (
    AsyncTokenizerClient,
    ImageEncodeRequest,
    VideoEncodeRequest,
)


@pytest.mark.asyncio
@respx.mock
async def test_encode_image_success():
    base_url = "http://testserver"
    client = AsyncTokenizerClient(base_url=base_url)

    # Mock the task-creation endpoint
    respx.post(f"{base_url}/image/encode").mock(
        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task123"})
    )
    # Mock the polling endpoint to report a finished task
    mock_get_ret = {
        "state": "Finished",
        "result": {"feature_url": "bos://host:port/key", "feature_shape": [80, 45, 1563]},
    }
    respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json=mock_get_ret))

    request = ImageEncodeRequest(
        version="v1", req_id="req_img_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
    )

    result = await client.encode_image(request)
    assert result["feature_url"] == "bos://host:port/key"
    assert result["feature_shape"] == [80, 45, 1563]


@pytest.mark.asyncio
@respx.mock
async def test_encode_video_failure():
    base_url = "http://testserver"
    client = AsyncTokenizerClient(base_url=base_url, max_wait=1)

    respx.post(f"{base_url}/video/encode").mock(
        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_vid_001"})
    )
    # Mock the polling endpoint to report a failed task
    respx.get(f"{base_url}/encode/get").mock(
        return_value=httpx.Response(200, json={"state": "Error", "message": "Encode failed"})
    )

    request = VideoEncodeRequest(
        version="v1",
        req_id="req_vid_001",
        is_gen=True,
        resolution=720,
        video_url="http://example.com/video.mp4",
        start_ts=0,
        end_ts=10,
        frames=30,
    )

    with pytest.raises(RuntimeError, match="Encode failed"):
        await client.encode_video(request)


@pytest.mark.asyncio
@respx.mock
async def test_encode_timeout():
    base_url = "http://testserver"
    client = AsyncTokenizerClient(base_url=base_url, max_wait=1, poll_interval=0.1)

    respx.post(f"{base_url}/image/encode").mock(
        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_timeout"})
    )
    # Mock the polling endpoint to stay pending forever so the client times out
    respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json={"status": "processing"}))

    request = ImageEncodeRequest(
        version="v1", req_id="req_img_timeout", is_gen=False, resolution=256, image_url="http://example.com/image.jpg"
    )

    with pytest.raises(TimeoutError):
        await client.encode_image(request)
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
import unittest
from unittest.mock import MagicMock, patch
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
import unittest
from unittest.mock import MagicMock, patch
tests/entrypoints/openai/test_response_processors.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest
from unittest.mock import AsyncMock, MagicMock

from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor


class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):

    def setUp(self):
        self.mock_data_processor = MagicMock()
        self.mock_data_processor.process_response_dict = MagicMock(
            side_effect=lambda response_dict, **_: {"processed": True, "raw": response_dict}
        )

    async def asyncSetUp(self):
        self.processor_mm = ChatResponseProcessor(
            data_processor=self.mock_data_processor,
            enable_mm_output=True,
            eoi_token_id=101032,
            eos_token_id=2,
            decoder_base_url="http://fake-decoder",
        )
        self.processor_mm.decoder_client.decode_image = AsyncMock(
            return_value={"http_url": "http://image.url/test.png"}
        )

    async def test_text_only_mode(self):
        """With multimodal output disabled, responses go straight through data_processor."""
        processor = ChatResponseProcessor(self.mock_data_processor)
        request_outputs = [{"outputs": {"text": "hello"}}]

        results = [
            r
            async for r in processor.process_response_chat(
                request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
            )
        ]

        self.mock_data_processor.process_response_dict.assert_called_once()
        self.assertEqual(results[0]["processed"], True)
        self.assertEqual(results[0]["raw"]["outputs"]["text"], "hello")

    async def test_streaming_text_and_image(self):
        """Streaming mode: text → image → text."""
        request_outputs = [
            {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [1], "text": "hi"}},
            {"request_id": "req1", "outputs": {"decode_type": 1, "token_ids": [[11, 22]]}},
            {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [101032], "text": "done"}},
        ]

        results = [
            r
            async for r in self.processor_mm.process_response_chat(
                request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
            )
        ]

        # First yield: text
        text_part = results[0]["outputs"]["multipart"][0]
        self.assertEqual(text_part["type"], "text")
        self.assertEqual(text_part["text"], "hi")

        # Second yield: image (the buffered token_ids were stitched together)
        image_part = results[1]["outputs"]["multipart"][0]
        self.assertEqual(image_part["type"], "image")
        self.assertEqual(image_part["url"], "http://image.url/test.png")
        self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]])

        # Third yield: text
        text_part = results[2]["outputs"]["multipart"][0]
        self.assertEqual(text_part["type"], "text")
        self.assertEqual(text_part["text"], "done")

    async def test_streaming_buffer_accumulation(self):
        """Streaming mode: decode_type=1 only accumulates into the buffer, nothing is yielded."""
        request_outputs = [{"request_id": "req2", "outputs": {"decode_type": 1, "token_ids": [[33, 44]]}}]

        results = [
            r
            async for r in self.processor_mm.process_response_chat(
                request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
            )
        ]

        self.assertEqual(results, [])
        self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]])

    async def test_non_streaming_accumulate_and_emit(self):
        """Non-streaming mode: multipart (text + image) is emitted only once eos_token_id arrives."""
        request_outputs = [
            {"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [10], "text": "hello"}},
            {"request_id": "req3", "outputs": {"decode_type": 1, "token_ids": [[55, 66]]}},
            {"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [2], "text": "bye"}},  # eos_token_id
        ]

        results = [
            r
            async for r in self.processor_mm.process_response_chat(
                request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
            )
        ]

        # Only the final output is yielded
        self.assertEqual(len(results), 1)
        multipart = results[0]["outputs"]["multipart"]

        self.assertEqual(multipart[0]["type"], "text")
        self.assertEqual(multipart[0]["text"], "hello")

        self.assertEqual(multipart[1]["type"], "image")
        self.assertEqual(multipart[1]["url"], "http://image.url/test.png")

        self.assertEqual(multipart[2]["type"], "text")
        self.assertEqual(multipart[2]["text"], "bye")


if __name__ == "__main__":
    unittest.main()
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
import unittest
from typing import List
from unittest.mock import Mock
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
import asyncio
import unittest