[Feature] Add AsyncTokenizerClient & ChatResponseProcessor with remote encode & decode support. (#3674)

* [Feature] add AsyncTokenizerClient

* add decode_image

* Add response_processors with remote decode support.

* [Feature] add tokenizer_base_url startup argument

* Revert comment removal and restore original content.

* [Feature] Non-streaming requests now support remote image decoding.

* Fix parameter type issue in decode_image call.

* Keep completion_token_ids when return_token_ids = False.

* add copyright
Author: SunLei
Date: 2025-08-30 17:06:26 +08:00
Committed by: GitHub
Parent: 9a7c231f2c
Commit: b9af95cf1c
13 changed files with 757 additions and 25 deletions


@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
from fastdeploy.input.tokenzier_client import (
AsyncTokenizerClient,
ImageDecodeRequest,
ImageEncodeRequest,
VideoEncodeRequest,
)
async def main():
"""
    Test the AsyncTokenizerClient class.
"""
base_url = "http://example.com/"
client = AsyncTokenizerClient(base_url=base_url)
    # Test an image encode request
image_encode_request = ImageEncodeRequest(
version="v1", req_id="req_image_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
)
image_encode_ret = await client.encode_image(image_encode_request)
print(f"Image encode result:{image_encode_ret}")
    # Test a video encode request
video_encode_req = VideoEncodeRequest(
version="v1",
req_id="req_video_001",
video_url="http://example.com/video.mp4",
is_gen=False,
resolution=1024,
start_ts=0,
end_ts=5,
frames=1,
)
video_encode_result = await client.encode_video(video_encode_req)
print(f"Video Encode Result:{video_encode_result}")
    # Test an image decode request
with open("./image_decode_demo.json", "r", encoding="utf-8") as file:
import json
import time
start_time = time.time()
        start_process_time = time.process_time()  # record the start time
json_data = json.load(file)
image_decoding_request = ImageDecodeRequest(req_id="req_image_001", data=json_data.get("data"))
# import pdb; pdb.set_trace()
image_decode_result = await client.decode_image(image_decoding_request)
print(f"Image decode result:{image_decode_result}")
elapsed_time = time.time() - start_time
elapsed_process_time = time.process_time() - start_process_time
print(f"decode elapsed_time: {elapsed_time:.6f}s, elapsed_process_time: {elapsed_process_time:.6f}s")
if __name__ == "__main__":
asyncio.run(main())


@@ -71,6 +71,10 @@ class EngineArgs:
     """
     The name or path of the tokenizer (defaults to model path if not provided).
     """
+    tokenizer_base_url: str = None
+    """
+    The base URL of the remote tokenizer service (used instead of local tokenizer if provided).
+    """
     max_model_len: int = 2048
     """
     Maximum context length supported by the model.

@@ -426,6 +430,12 @@ class EngineArgs:
             default=EngineArgs.tokenizer,
             help="Tokenizer name or path (defaults to model path if not specified).",
         )
+        model_group.add_argument(
+            "--tokenizer-base-url",
+            type=nullable_str,
+            default=EngineArgs.tokenizer_base_url,
+            help="The base URL of the remote tokenizer service (used instead of local tokenizer if provided).",
+        )
         model_group.add_argument(
             "--max-model-len",
             type=int,
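For context, here is a standalone sketch of how a nullable string flag such as --tokenizer-base-url behaves with argparse. The nullable_str helper below is a local stand-in for FastDeploy's helper of the same name, and the parsed values are illustrative only.

import argparse


def nullable_str(value: str):
    """Treat an empty string or the literal "None" as not provided."""
    return None if value in ("", "None") else value


parser = argparse.ArgumentParser()
parser.add_argument(
    "--tokenizer-base-url",
    type=nullable_str,
    default=None,
    help="Base URL of the remote tokenizer service.",
)
parser.add_argument("--enable-mm-output", action="store_true")

args = parser.parse_args(["--tokenizer-base-url", "http://localhost:8012", "--enable-mm-output"])
print(args.tokenizer_base_url, args.enable_mm_output)  # http://localhost:8012 True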


@@ -77,6 +77,9 @@ parser.add_argument(
     help="max waiting time for connection, if set value -1 means no waiting time limit",
 )
 parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
+parser.add_argument(
+    "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output."
+)
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)

@@ -176,7 +179,14 @@ async def lifespan(app: FastAPI):
     )
     app.state.model_handler = model_handler
     chat_handler = OpenAIServingChat(
-        engine_client, app.state.model_handler, pid, args.ips, args.max_waiting_time, chat_template
+        engine_client,
+        app.state.model_handler,
+        pid,
+        args.ips,
+        args.max_waiting_time,
+        chat_template,
+        args.enable_mm_output,
+        args.tokenizer_base_url,
     )
     completion_handler = OpenAIServingCompletion(
         engine_client,


@@ -163,8 +163,9 @@ class ChatMessage(BaseModel):
     Chat message.
     """
-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
     prompt_token_ids: Optional[List[int]] = None

@@ -226,6 +227,7 @@ class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     reasoning_content: Optional[str] = None
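The protocol change above makes role and content optional and adds an optional multimodal_content list, so one message can carry typed text and image parts. Below is a minimal pydantic sketch of how such a message serializes; it is not the actual protocol module and assumes pydantic v2.

from typing import Any, List, Optional

from pydantic import BaseModel


class DeltaMessageSketch(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    multimodal_content: Optional[List[Any]] = None


text_delta = DeltaMessageSketch(role="assistant", content="hello")
mm_delta = DeltaMessageSketch(
    role="assistant",
    multimodal_content=[
        {"type": "text", "text": "hello"},
        {"type": "image", "url": "http://example.com/image.png"},
    ],
)
print(text_delta.model_dump(exclude_none=True))
print(mm_delta.model_dump(exclude_none=True))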


@@ -0,0 +1,145 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, List, Optional
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
class ChatResponseProcessor:
"""
A decoder class to build multimodal content (text/image) from token_ids.
Attributes:
eoi_token_id: Token ID indicating the end of an image (<eoi>).
"""
def __init__(
self,
data_processor,
enable_mm_output: Optional[bool] = False,
eoi_token_id: Optional[int] = 101032,
eos_token_id: Optional[int] = 2,
decoder_base_url: Optional[str] = None,
):
self.data_processor = data_processor
self.enable_mm_output = enable_mm_output
self.eoi_token_id = eoi_token_id
self.eos_token_id = eos_token_id
        self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url) if decoder_base_url else None
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
self._end_image_code_request_output: Optional[Any] = None
self._multipart_buffer = []
def enable_multimodal_content(self):
return self.enable_mm_output
def accumulate_token_ids(self, request_output):
decode_type = request_output["outputs"].get("decode_type", 0)
if not self._multipart_buffer:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
else:
last_part = self._multipart_buffer[-1]
if last_part["decode_type"] == decode_type:
last_token_ids = last_part["request_output"]["outputs"]["token_ids"]
last_token_ids.extend(request_output["outputs"]["token_ids"])
request_output["outputs"]["token_ids"] = last_token_ids
last_part["request_output"] = request_output
else:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
"""
Process a list of responses into a generator that yields each processed response as it's generated.
Args:
request_outputs: The list of outputs to be processed.
stream: Whether or not to stream the output.
enable_thinking: Whether or not to show thinking messages.
include_stop_str_in_output: Whether or not to include stop strings in the output.
"""
for request_output in request_outputs:
if not self.enable_mm_output:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
elif stream:
decode_type = request_output["outputs"].get("decode_type", 0)
token_ids = request_output["outputs"]["token_ids"]
if decode_type == 0:
if self.eoi_token_id and self.eoi_token_id in token_ids:
if self._mm_buffer:
all_tokens = self._mm_buffer
self._mm_buffer = []
image = {"type": "image"}
if self.decoder_client:
req_id = request_output["request_id"]
image_ret = await self.decoder_client.decode_image(
request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
)
image["url"] = image_ret["http_url"]
image_output = self._end_image_code_request_output
image_output["outputs"]["multipart"] = [image]
image_output["outputs"]["token_ids"] = all_tokens
yield image_output
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": request_output["outputs"]["text"]}
request_output["outputs"]["multipart"] = [text]
yield request_output
elif decode_type == 1:
self._mm_buffer.extend(token_ids)
self._end_image_code_request_output = request_output
else:
self.accumulate_token_ids(request_output)
token_ids = request_output["outputs"]["token_ids"]
if token_ids[-1] == self.eos_token_id:
multipart = []
for part in self._multipart_buffer:
if part["decode_type"] == 0:
self.data_processor.process_response_dict(
response_dict=part["request_output"],
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
multipart.append(text)
elif part["decode_type"] == 1:
image = {"type": "image"}
if self.decoder_client:
req_id = part["request_output"]["request_id"]
all_tokens = part["request_output"]["outputs"]["token_ids"]
image_ret = await self.decoder_client.decode_image(
request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
)
image["url"] = image_ret["http_url"]
multipart.append(image)
                    last_request_output = self._multipart_buffer[-1]["request_output"]
                    last_request_output["outputs"]["multipart"] = multipart
                    yield last_request_output
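To make the buffering logic above easier to follow, here is a self-contained sketch of the same idea: consecutive outputs with the same decode_type are merged, and the merged parts become a text/image multipart list once the end-of-sequence token arrives. The token values and the image part are simplified stand-ins; a real run would send the accumulated ids to decode_image.

from typing import Any, Dict, List

EOS_TOKEN_ID = 2  # mirrors the processor's default eos_token_id


def accumulate(parts: List[Dict[str, Any]], output: Dict[str, Any]) -> None:
    """Merge `output` into the last part when the decode_type matches, otherwise start a new part."""
    decode_type = output.get("decode_type", 0)
    if parts and parts[-1]["decode_type"] == decode_type:
        parts[-1]["token_ids"].extend(output["token_ids"])
        parts[-1]["text"] = parts[-1]["text"] + output.get("text", "")
    else:
        parts.append(
            {
                "decode_type": decode_type,
                "token_ids": list(output["token_ids"]),
                "text": output.get("text", ""),
            }
        )


def to_multipart(parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Map accumulated parts to the multipart structure carried in responses."""
    multipart = []
    for part in parts:
        if part["decode_type"] == 0:
            multipart.append({"type": "text", "text": part["text"]})
        else:
            # A real implementation would call the remote decode_image endpoint here.
            multipart.append({"type": "image", "token_ids": part["token_ids"]})
    return multipart


parts: List[Dict[str, Any]] = []
for out in [
    {"decode_type": 0, "token_ids": [10], "text": "hello"},
    {"decode_type": 1, "token_ids": [[55, 66]]},
    {"decode_type": 0, "token_ids": [EOS_TOKEN_ID], "text": "bye"},
]:
    accumulate(parts, out)

if parts and parts[-1]["token_ids"][-1] == EOS_TOKEN_ID:
    print(to_multipart(parts))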


@@ -36,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     PromptTokenUsageInfo,
     UsageInfo,
 )
+from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.utils import api_server_logger
 from fastdeploy.worker.output import LogprobsLists

@@ -46,12 +47,24 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """
-    def __init__(self, engine_client, models, pid, ips, max_waiting_time, chat_template):
+    def __init__(
+        self,
+        engine_client,
+        models,
+        pid,
+        ips,
+        max_waiting_time,
+        chat_template,
+        enable_mm_output: Optional[bool] = False,
+        tokenizer_base_url: Optional[str] = None,
+    ):
         self.engine_client = engine_client
         self.models = models
         self.pid = pid
         self.max_waiting_time = max_waiting_time
         self.chat_template = chat_template
+        self.enable_mm_output = enable_mm_output
+        self.tokenizer_base_url = tokenizer_base_url
         if ips is not None:
             if isinstance(ips, list):
                 self.master_ip = ips[0]
@@ -198,6 +211,11 @@ class OpenAIServingChat:
             dealer.write([b"", request_id.encode("utf-8")])
             choices = []
             current_waiting_time = 0
+            response_processor = ChatResponseProcessor(
+                data_processor=self.engine_client.data_processor,
+                enable_mm_output=self.enable_mm_output,
+                decoder_base_url=self.tokenizer_base_url,
+            )
             while num_choices > 0:
                 try:
                     response = await asyncio.wait_for(response_queue.get(), timeout=10)

@@ -215,17 +233,18 @@ class OpenAIServingChat:
                         current_waiting_time = 0
                     await asyncio.sleep(0.01)
                     continue
-                for res in response:
+                generator = response_processor.process_response_chat(
+                    response,
+                    stream=True,
+                    enable_thinking=enable_thinking,
+                    include_stop_str_in_output=include_stop_str_in_output,
+                )
+                async for res in generator:
                     if res.get("error_code", 200) != 200:
                         raise ValueError("{}".format(res["error_msg"]))
-                    self.engine_client.data_processor.process_response_dict(
-                        res,
-                        stream=True,
-                        enable_thinking=enable_thinking,
-                        include_stop_str_in_output=include_stop_str_in_output,
-                    )
                     if res["metrics"]["first_token_time"] is not None:
                         arrival_time = res["metrics"]["first_token_time"]
                         inference_start_time = res["metrics"]["inference_start_time"]
@@ -239,13 +258,22 @@ class OpenAIServingChat:
                                 index=i,
                                 delta=DeltaMessage(
                                     role="assistant",
-                                    content="",
                                     reasoning_content="",
                                     tool_calls=None,
                                     prompt_token_ids=None,
                                     completion_token_ids=None,
                                 ),
                             )
+                            if response_processor.enable_multimodal_content():
+                                choice.delta.multimodal_content = [
+                                    {
+                                        "type": "text",
+                                        "text": "",
+                                    }
+                                ]
+                            else:
+                                choice.delta.content = ""
                             if request.return_token_ids:
                                 choice.delta.prompt_token_ids = list(prompt_token_ids)
                                 choice.delta.text_after_process = text_after_process

@@ -269,7 +297,6 @@ class OpenAIServingChat:
                     first_iteration = False
                     output = res["outputs"]
-                    delta_text = output["text"]
                     output_top_logprobs = output["top_logprobs"]
                     previous_num_tokens += len(output["token_ids"])
                     logprobs_res: Optional[LogProbs] = None
@@ -279,12 +306,17 @@ class OpenAIServingChat:
                         )
                     delta_message = DeltaMessage(
-                        content=delta_text,
                         reasoning_content="",
                         prompt_token_ids=None,
-                        completion_token_ids=None,
                         tool_calls=None,
+                        completion_token_ids=None,
                     )
+                    if response_processor.enable_multimodal_content():
+                        delta_message.multimodal_content = output["multipart"]
+                    else:
+                        delta_message.content = output["text"]
                     if not res["finished"] and "delta_message" in output:
                         delta_message_output = output["delta_message"]
                         if delta_message_output is None:

@@ -317,7 +349,10 @@ class OpenAIServingChat:
                             choice.finish_reason = "recover_stop"
                     if request.return_token_ids:
-                        choice.delta.completion_token_ids = list(output["token_ids"])
+                        if response_processor.enable_multimodal_content():
+                            choice.delta.multimodal_content[0]["completion_token_ids"] = list(output["token_ids"])
+                        else:
+                            choice.delta.completion_token_ids = list(output["token_ids"])
                     choice.delta.raw_prediction = output.get("raw_prediction")
                     choice.delta.completion_tokens = output.get("raw_prediction")
                     if include_continuous_usage:
@@ -395,6 +430,11 @@ class OpenAIServingChat:
         current_waiting_time = 0
         logprob_contents = []
         completion_token_ids = []
+        response_processor = ChatResponseProcessor(
+            data_processor=self.engine_client.data_processor,
+            enable_mm_output=self.enable_mm_output,
+            decoder_base_url=self.tokenizer_base_url,
+        )
         while True:
             try:
                 response = await asyncio.wait_for(response_queue.get(), timeout=10)

@@ -411,15 +451,16 @@ class OpenAIServingChat:
                     continue
                 task_is_finished = False
-                for data in response:
+                generator = response_processor.process_response_chat(
+                    response,
+                    stream=False,
+                    enable_thinking=enable_thinking,
+                    include_stop_str_in_output=include_stop_str_in_output,
+                )
+                async for data in generator:
                     if data.get("error_code", 200) != 200:
                         raise ValueError("{}".format(data["error_msg"]))
-                    data = self.engine_client.data_processor.process_response_dict(
-                        data,
-                        stream=False,
-                        enable_thinking=enable_thinking,
-                        include_stop_str_in_output=include_stop_str_in_output,
-                    )
                     # api_server_logger.debug(f"Client {request_id} received: {data}")
                     previous_num_tokens += len(data["outputs"]["token_ids"])
                     completion_token_ids.extend(data["outputs"]["token_ids"])
@@ -447,7 +488,6 @@ class OpenAIServingChat:
         output = final_res["outputs"]
         message = ChatMessage(
             role="assistant",
-            content=output["text"],
             reasoning_content=output.get("reasoning_content"),
             tool_calls=output.get("tool_call"),
             prompt_token_ids=prompt_token_ids if request.return_token_ids else None,

@@ -457,6 +497,12 @@ class OpenAIServingChat:
             raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
             completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
         )
+        if response_processor.enable_multimodal_content():
+            message.multimodal_content = output.get("multipart")
+        else:
+            message.content = output["text"]
         logprobs_full_res = None
         if logprob_contents:
             logprobs_full_res = LogProbs(content=logprob_contents)
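The serving-side change above replaces the plain `for res in response:` loop with an async generator that yields already-processed outputs, so per-output post-processing (including awaiting the remote decode) happens inside the iteration. A minimal standalone sketch of that pattern follows; the processor and response shapes here are illustrative, not FastDeploy's actual classes.

import asyncio
from typing import Any, AsyncGenerator, Dict, List


async def process_response_chat(
    request_outputs: List[Dict[str, Any]],
    stream: bool,
) -> AsyncGenerator[Dict[str, Any], None]:
    for request_output in request_outputs:
        # A remote call (e.g. image decoding) can be awaited here without
        # blocking the event loop that is serving other requests.
        await asyncio.sleep(0)
        request_output["processed"] = True
        yield request_output


async def main() -> None:
    response = [{"outputs": {"text": "hi"}}, {"outputs": {"text": "there"}}]
    async for res in process_response_chat(response, stream=True):
        print(res)


asyncio.run(main())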


@@ -0,0 +1,163 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
from typing import Any, Optional, Union
import httpx
from pydantic import BaseModel, HttpUrl
from fastdeploy.utils import data_processor_logger
class BaseEncodeRequest(BaseModel):
version: str
req_id: str
is_gen: bool
resolution: int
class ImageEncodeRequest(BaseEncodeRequest):
image_url: Union[str, HttpUrl]
class VideoEncodeRequest(BaseEncodeRequest):
video_url: Union[str, HttpUrl]
start_ts: int
end_ts: int
frames: int
class ImageDecodeRequest(BaseModel):
req_id: str
data: list[Any]
class AsyncTokenizerClient:
def __init__(
self,
base_url: Optional[str] = None,
timeout: float = 5.0,
poll_interval: float = 0.5,
max_wait: float = 60.0,
):
"""
:param mode: 'local''remote'
:param base_url: 远程服务地址
:param timeout: 单次 HTTP 请求超时(秒)
:param poll_interval: 查询结果的轮询间隔(秒)
:param max_wait: 最大等待时间(秒)
"""
self.base_url = base_url
self.timeout = timeout
self.poll_interval = poll_interval
self.max_wait = max_wait
async def encode_image(self, request: ImageEncodeRequest):
return await self._async_encode_request("image", request.__dict__)
async def encode_video(self, request: VideoEncodeRequest):
return await self._async_encode_request("video", request.__dict__)
async def decode_image(self, request: ImageDecodeRequest):
return await self._async_decode_request("image", request.__dict__)
async def log_request(self, request):
data_processor_logger.debug(f">>> Request: {request.method} {request.url}")
data_processor_logger.debug(f">>> Headers: {request.headers}")
if request.content:
data_processor_logger.debug(f">>> Content: {request.content.decode('utf-8')}")
async def log_response(self, response):
data_processor_logger.debug(f"<<< Response status: {response.status_code}")
data_processor_logger.debug(f"<<< Headers: {response.headers}")
async def _async_encode_request(self, type: str, request: dict):
if not self.base_url:
raise ValueError("Missing base_url")
async with httpx.AsyncClient(
timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
) as client:
req_id = request.get("req_id")
try:
url = None
if type == "image":
url = f"{self.base_url}/image/encode"
elif type == "video":
url = f"{self.base_url}/video/encode"
else:
raise ValueError("Invalid type")
resp = await client.post(url, json=request)
resp.raise_for_status()
except httpx.RequestError as e:
raise RuntimeError(f"Failed to create tokenize task: {e}") from e
task_info = resp.json()
if task_info.get("code") != 0:
raise RuntimeError(f"Tokenize task creation failed, {task_info.get('message')}")
task_tag = task_info.get("task_tag")
if not task_tag:
raise RuntimeError("No task_tag returned from server")
            # 2. Poll for the result
start_time = asyncio.get_event_loop().time()
while True:
try:
r = await client.get(
f"{self.base_url}/encode/get", params={"task_tag": task_tag, "req_id": req_id}
)
r.raise_for_status()
data = r.json()
                        # Current state of the async encode task: Processing, Finished, Error
if data.get("state") == "Finished":
return data.get("result")
elif data.get("state") == "Error":
raise RuntimeError(f"Tokenize task failed: {data.get('message')}")
except httpx.RequestError:
                        # Keep polling on transient network errors
pass
                # Timeout check
if asyncio.get_event_loop().time() - start_time > self.max_wait:
raise TimeoutError(f"Tokenize task {task_tag} timed out after {self.max_wait}s")
await asyncio.sleep(self.poll_interval)
async def _async_decode_request(self, type: str, request: dict):
if not self.base_url:
raise ValueError("Missing base_url")
async with httpx.AsyncClient(
timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
) as client:
try:
url = None
if type == "image":
url = f"{self.base_url}/image/decode"
else:
raise ValueError("Invalid type")
resp = await client.post(url, json=request)
resp.raise_for_status()
if resp.json().get("code") != 0:
raise RuntimeError(f"Tokenize task creation failed, {resp.json().get('message')}")
return resp.json().get("result")
except httpx.RequestError as e:
raise RuntimeError(f"Failed to decode: {e}") from e


@@ -0,0 +1,84 @@
import httpx
import pytest
import respx
from fastdeploy.input.tokenzier_client import (
AsyncTokenizerClient,
ImageEncodeRequest,
VideoEncodeRequest,
)
@pytest.mark.asyncio
@respx.mock
async def test_encode_image_success():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url)
    # Mock the task-creation endpoint
respx.post(f"{base_url}/image/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task123"})
)
    # Mock the polling endpoint to return a finished state
mock_get_ret = {
"state": "Finished",
"result": {"feature_url": "bos://host:port/key", "feature_shape": [80, 45, 1563]},
}
respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json=mock_get_ret))
request = ImageEncodeRequest(
version="v1", req_id="req_img_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
)
result = await client.encode_image(request)
assert result["feature_url"] == "bos://host:port/key"
assert result["feature_shape"] == [80, 45, 1563]
@pytest.mark.asyncio
@respx.mock
async def test_encode_video_failure():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url, max_wait=1)
respx.post(f"{base_url}/video/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_vid_001"})
)
    # Simulate the polling endpoint returning an error state
respx.get(f"{base_url}/encode/get").mock(
return_value=httpx.Response(200, json={"state": "Error", "message": "Encode failed"})
)
request = VideoEncodeRequest(
version="v1",
req_id="req_vid_001",
is_gen=True,
resolution=720,
video_url="http://example.com/video.mp4",
start_ts=0.0,
end_ts=10.0,
frames=30,
)
with pytest.raises(RuntimeError, match="Encode failed"):
await client.encode_video(request)
@pytest.mark.asyncio
@respx.mock
async def test_encode_timeout():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url, max_wait=1, poll_interval=0.1)
respx.post(f"{base_url}/image/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_timeout"})
)
    # Simulate the polling endpoint always returning a pending state, causing a timeout
respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json={"status": "processing"}))
request = ImageEncodeRequest(
version="v1", req_id="req_img_timeout", is_gen=False, resolution=256, image_url="http://example.com/image.jpg"
)
with pytest.raises(TimeoutError):
await client.encode_image(request)


@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
 import unittest
 from unittest.mock import MagicMock, patch


@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
 import unittest
 from unittest.mock import MagicMock, patch


@@ -0,0 +1,134 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_data_processor = MagicMock()
self.mock_data_processor.process_response_dict = MagicMock(
side_effect=lambda response_dict, **_: {"processed": True, "raw": response_dict}
)
async def asyncSetUp(self):
self.processor_mm = ChatResponseProcessor(
data_processor=self.mock_data_processor,
enable_mm_output=True,
eoi_token_id=101032,
eos_token_id=2,
decoder_base_url="http://fake-decoder",
)
self.processor_mm.decoder_client.decode_image = AsyncMock(
return_value={"http_url": "http://image.url/test.png"}
)
async def test_text_only_mode(self):
"""不开启 multimodal 时,直接走 data_processor"""
processor = ChatResponseProcessor(self.mock_data_processor)
request_outputs = [{"outputs": {"text": "hello"}}]
results = [
r
async for r in processor.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
)
]
self.mock_data_processor.process_response_dict.assert_called_once()
self.assertEqual(results[0]["processed"], True)
self.assertEqual(results[0]["raw"]["outputs"]["text"], "hello")
async def test_streaming_text_and_image(self):
"""流式模式下text → image → text"""
request_outputs = [
{"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [1], "text": "hi"}},
{"request_id": "req1", "outputs": {"decode_type": 1, "token_ids": [[11, 22]]}},
{"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [101032], "text": "done"}},
]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
)
]
        # First yield: text
text_part = results[0]["outputs"]["multipart"][0]
self.assertEqual(text_part["type"], "text")
self.assertEqual(text_part["text"], "hi")
        # Second yield: image (the buffered token_ids are stitched together)
image_part = results[1]["outputs"]["multipart"][0]
self.assertEqual(image_part["type"], "image")
self.assertEqual(image_part["url"], "http://image.url/test.png")
self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]])
        # Third yield: text
text_part = results[2]["outputs"]["multipart"][0]
self.assertEqual(text_part["type"], "text")
self.assertEqual(text_part["text"], "done")
async def test_streaming_buffer_accumulation(self):
"""流式模式decode_type=1 只累积 buffer不 yield"""
request_outputs = [{"request_id": "req2", "outputs": {"decode_type": 1, "token_ids": [[33, 44]]}}]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
)
]
self.assertEqual(results, [])
self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]])
async def test_non_streaming_accumulate_and_emit(self):
"""非流式模式:等 eos_token_id 才输出 multiparttext+image"""
request_outputs = [
{"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [10], "text": "hello"}},
{"request_id": "req3", "outputs": {"decode_type": 1, "token_ids": [[55, 66]]}},
{"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [2], "text": "bye"}}, # eos_token_id
]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
)
]
        # Only the final output is yielded
self.assertEqual(len(results), 1)
multipart = results[0]["outputs"]["multipart"]
self.assertEqual(multipart[0]["type"], "text")
self.assertEqual(multipart[0]["text"], "hello")
self.assertEqual(multipart[1]["type"], "image")
self.assertEqual(multipart[1]["url"], "http://image.url/test.png")
self.assertEqual(multipart[2]["type"], "text")
self.assertEqual(multipart[2]["text"], "bye")
if __name__ == "__main__":
unittest.main()


@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
 import unittest
 from typing import List
 from unittest.mock import Mock


@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
 import asyncio
 import unittest