[Feature] Add AsyncTokenizerClient & ChatResponseProcessor with remote encode & decode support. (#3674)

* [Feature] add AsyncTokenizerClient

* add decode_image

* Add response_processors with remote decode support.

* [Feature] add tokenizer_base_url startup argument

* Revert comment removal and restore original content.

* [Feature] Non-streaming requests now support remote image decoding.

* Fix parameter type issue in decode_image call.

* Keep completion_token_ids when return_token_ids = False.

* add copyright
Author: SunLei, 2025-08-30 17:06:26 +08:00 (committed by GitHub)
parent 9a7c231f2c · commit b9af95cf1c
13 changed files with 757 additions and 25 deletions
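
For a quick end-to-end check, a hedged client-side sketch follows: the host, port, model name, and prompt are placeholders, the /v1/chat/completions route is the standard OpenAI-compatible one, and the server is assumed to have been started with --enable-mm-output (plus --tokenizer-base-url when remote image decoding is wanted). Only the multimodal_content field and its part shapes come from this PR.

import httpx

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Draw a cat and describe it."}],
    "stream": False,
}
resp = httpx.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
message = resp.json()["choices"][0]["message"]
# With --enable-mm-output the parts arrive in multimodal_content instead of content;
# each part is {"type": "text", "text": ...} or {"type": "image", "url": ...}.
for part in message.get("multimodal_content") or []:
    print(part)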

View File

@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
from fastdeploy.input.tokenzier_client import (
AsyncTokenizerClient,
ImageDecodeRequest,
ImageEncodeRequest,
VideoEncodeRequest,
)
async def main():
"""
Test the AsyncTokenizerClient class.
"""
base_url = "http://example.com/"
client = AsyncTokenizerClient(base_url=base_url)
# Test image encode request
image_encode_request = ImageEncodeRequest(
version="v1", req_id="req_image_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
)
image_encode_ret = await client.encode_image(image_encode_request)
print(f"Image encode result:{image_encode_ret}")
# Test video encode request
video_encode_req = VideoEncodeRequest(
version="v1",
req_id="req_video_001",
video_url="http://example.com/video.mp4",
is_gen=False,
resolution=1024,
start_ts=0,
end_ts=5,
frames=1,
)
video_encode_result = await client.encode_video(video_encode_req)
print(f"Video Encode Result:{video_encode_result}")
# Test image decode request
with open("./image_decode_demo.json", "r", encoding="utf-8") as file:
import json
import time
start_time = time.time()
start_process_time = time.process_time()  # record the start time
json_data = json.load(file)
image_decoding_request = ImageDecodeRequest(req_id="req_image_001", data=json_data.get("data"))
image_decode_result = await client.decode_image(image_decoding_request)
print(f"Image decode result:{image_decode_result}")
elapsed_time = time.time() - start_time
elapsed_process_time = time.process_time() - start_process_time
print(f"decode elapsed_time: {elapsed_time:.6f}s, elapsed_process_time: {elapsed_process_time:.6f}s")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -71,6 +71,10 @@ class EngineArgs:
"""
The name or path of the tokenizer (defaults to model path if not provided).
"""
tokenizer_base_url: str = None
"""
The base URL of the remote tokenizer service (used instead of local tokenizer if provided).
"""
max_model_len: int = 2048
"""
Maximum context length supported by the model.
@@ -426,6 +430,12 @@ class EngineArgs:
default=EngineArgs.tokenizer,
help="Tokenizer name or path (defaults to model path if not specified).",
)
model_group.add_argument(
"--tokenizer-base-url",
type=nullable_str,
default=EngineArgs.tokenizer_base_url,
help="The base URL of the remote tokenizer service (used instead of local tokenizer if provided).",
)
model_group.add_argument(
"--max-model-len",
type=int,

View File

@@ -77,6 +77,9 @@ parser.add_argument(
help="max waiting time for connection, if set value -1 means no waiting time limit",
)
parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
parser.add_argument(
"--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. "
)
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision)
@@ -176,7 +179,14 @@ async def lifespan(app: FastAPI):
)
app.state.model_handler = model_handler
chat_handler = OpenAIServingChat(
engine_client, app.state.model_handler, pid, args.ips, args.max_waiting_time, chat_template
engine_client,
app.state.model_handler,
pid,
args.ips,
args.max_waiting_time,
chat_template,
args.enable_mm_output,
args.tokenizer_base_url,
)
completion_handler = OpenAIServingCompletion(
engine_client,

View File

@@ -163,8 +163,9 @@ class ChatMessage(BaseModel):
Chat message.
"""
role: str
content: str
role: Optional[str] = None
content: Optional[str] = None
multimodal_content: Optional[List[Any]] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
prompt_token_ids: Optional[List[int]] = None
@@ -226,6 +227,7 @@ class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
multimodal_content: Optional[List[Any]] = None
prompt_token_ids: Optional[List[int]] = None
completion_token_ids: Optional[List[int]] = None
reasoning_content: Optional[str] = None
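
A minimal sketch of how the new optional field is meant to be populated, assuming the remaining ChatMessage/DeltaMessage fields keep their defaults (the image URL below is a placeholder):

from fastdeploy.entrypoints.openai.protocol import ChatMessage, DeltaMessage

# Non-streaming responses carry the parts on ChatMessage.multimodal_content.
msg = ChatMessage(
    role="assistant",
    multimodal_content=[
        {"type": "text", "text": "Here is the picture:"},
        {"type": "image", "url": "http://example.com/generated.png"},  # placeholder URL
    ],
)

# Streaming responses deliver the same part shapes incrementally on DeltaMessage.
delta = DeltaMessage(role="assistant", multimodal_content=[{"type": "text", "text": ""}])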

View File

@@ -0,0 +1,145 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, List, Optional
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
class ChatResponseProcessor:
"""
A decoder class to build multimodal content (text/image) from token_ids.
Attributes:
eoi_token_id: Token ID indicating the end of an image (<eoi>).
"""
def __init__(
self,
data_processor,
enable_mm_output: Optional[bool] = False,
eoi_token_id: Optional[int] = 101032,
eos_token_id: Optional[int] = 2,
decoder_base_url: Optional[str] = None,
):
self.data_processor = data_processor
self.enable_mm_output = enable_mm_output
self.eoi_token_id = eoi_token_id
self.eos_token_id = eos_token_id
self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url) if decoder_base_url else None
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
self._end_image_code_request_output: Optional[Any] = None
self._multipart_buffer = []
def enable_multimodal_content(self):
return self.enable_mm_output
def accumulate_token_ids(self, request_output):
decode_type = request_output["outputs"].get("decode_type", 0)
if not self._multipart_buffer:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
else:
last_part = self._multipart_buffer[-1]
if last_part["decode_type"] == decode_type:
last_token_ids = last_part["request_output"]["outputs"]["token_ids"]
last_token_ids.extend(request_output["outputs"]["token_ids"])
request_output["outputs"]["token_ids"] = last_token_ids
last_part["request_output"] = request_output
else:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
"""
Process a list of responses into a generator that yields each processed response as it's generated.
Args:
request_outputs: The list of outputs to be processed.
stream: Whether or not to stream the output.
enable_thinking: Whether or not to show thinking messages.
include_stop_str_in_output: Whether or not to include stop strings in the output.
"""
for request_output in request_outputs:
if not self.enable_mm_output:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
elif stream:
decode_type = request_output["outputs"].get("decode_type", 0)
token_ids = request_output["outputs"]["token_ids"]
if decode_type == 0:
if self.eoi_token_id and self.eoi_token_id in token_ids:
if self._mm_buffer:
all_tokens = self._mm_buffer
self._mm_buffer = []
image = {"type": "image"}
if self.decoder_client:
req_id = request_output["request_id"]
image_ret = await self.decoder_client.decode_image(
request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
)
image["url"] = image_ret["http_url"]
image_output = self._end_image_code_request_output
image_output["outputs"]["multipart"] = [image]
image_output["outputs"]["token_ids"] = all_tokens
yield image_output
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": request_output["outputs"]["text"]}
request_output["outputs"]["multipart"] = [text]
yield request_output
elif decode_type == 1:
self._mm_buffer.extend(token_ids)
self._end_image_code_request_output = request_output
else:
self.accumulate_token_ids(request_output)
token_ids = request_output["outputs"]["token_ids"]
if token_ids[-1] == self.eos_token_id:
multipart = []
for part in self._multipart_buffer:
if part["decode_type"] == 0:
self.data_processor.process_response_dict(
response_dict=part["request_output"],
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
multipart.append(text)
elif part["decode_type"] == 1:
image = {"type": "image"}
if self.decoder_client:
req_id = part["request_output"]["request_id"]
all_tokens = part["request_output"]["outputs"]["token_ids"]
image_ret = await self.decoder_client.decode_image(
request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
)
image["url"] = image_ret["http_url"]
multipart.append(image)
last_request_output = self._multipart_buffer[-1]["request_output"]
last_request_output["outputs"]["multipart"] = multipart
yield last_request_output
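
A minimal usage sketch of the processor in streaming mode, using a stub in place of the engine's data_processor (the token IDs and text are placeholders; image parts, i.e. decode_type == 1, additionally require a decoder service reachable via decoder_base_url):

import asyncio
from unittest.mock import MagicMock
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor

async def demo():
    # Stand-in for engine_client.data_processor; only process_response_dict is needed here.
    data_processor = MagicMock()
    data_processor.process_response_dict = MagicMock(side_effect=lambda response_dict, **_: response_dict)
    processor = ChatResponseProcessor(data_processor=data_processor, enable_mm_output=True)
    outputs = [{"request_id": "r1", "outputs": {"decode_type": 0, "token_ids": [10], "text": "hello"}}]
    async for res in processor.process_response_chat(
        outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
    ):
        print(res["outputs"]["multipart"])  # [{'type': 'text', 'text': 'hello'}]

asyncio.run(demo())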

View File

@@ -36,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
PromptTokenUsageInfo,
UsageInfo,
)
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -46,12 +47,24 @@ class OpenAIServingChat:
OpenAI-style chat completions serving
"""
def __init__(self, engine_client, models, pid, ips, max_waiting_time, chat_template):
def __init__(
self,
engine_client,
models,
pid,
ips,
max_waiting_time,
chat_template,
enable_mm_output: Optional[bool] = False,
tokenizer_base_url: Optional[str] = None,
):
self.engine_client = engine_client
self.models = models
self.pid = pid
self.max_waiting_time = max_waiting_time
self.chat_template = chat_template
self.enable_mm_output = enable_mm_output
self.tokenizer_base_url = tokenizer_base_url
if ips is not None:
if isinstance(ips, list):
self.master_ip = ips[0]
@@ -198,6 +211,11 @@ class OpenAIServingChat:
dealer.write([b"", request_id.encode("utf-8")])
choices = []
current_waiting_time = 0
response_processor = ChatResponseProcessor(
data_processor=self.engine_client.data_processor,
enable_mm_output=self.enable_mm_output,
decoder_base_url=self.tokenizer_base_url,
)
while num_choices > 0:
try:
response = await asyncio.wait_for(response_queue.get(), timeout=10)
@@ -215,17 +233,18 @@ class OpenAIServingChat:
current_waiting_time = 0
await asyncio.sleep(0.01)
continue
for res in response:
generator = response_processor.process_response_chat(
response,
stream=True,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
async for res in generator:
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
self.engine_client.data_processor.process_response_dict(
res,
stream=True,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
if res["metrics"]["first_token_time"] is not None:
arrival_time = res["metrics"]["first_token_time"]
inference_start_time = res["metrics"]["inference_start_time"]
@@ -239,13 +258,22 @@ class OpenAIServingChat:
index=i,
delta=DeltaMessage(
role="assistant",
content="",
reasoning_content="",
tool_calls=None,
prompt_token_ids=None,
completion_token_ids=None,
),
)
if response_processor.enable_multimodal_content():
choice.delta.multimodal_content = [
{
"type": "text",
"text": "",
}
]
else:
choice.delta.content = ""
if request.return_token_ids:
choice.delta.prompt_token_ids = list(prompt_token_ids)
choice.delta.text_after_process = text_after_process
@@ -269,7 +297,6 @@ class OpenAIServingChat:
first_iteration = False
output = res["outputs"]
delta_text = output["text"]
output_top_logprobs = output["top_logprobs"]
previous_num_tokens += len(output["token_ids"])
logprobs_res: Optional[LogProbs] = None
@@ -279,12 +306,17 @@ class OpenAIServingChat:
)
delta_message = DeltaMessage(
content=delta_text,
reasoning_content="",
prompt_token_ids=None,
completion_token_ids=None,
tool_calls=None,
completion_token_ids=None,
)
if response_processor.enable_multimodal_content():
delta_message.multimodal_content = output["multipart"]
else:
delta_message.content = output["text"]
if not res["finished"] and "delta_message" in output:
delta_message_output = output["delta_message"]
if delta_message_output is None:
@@ -317,7 +349,10 @@ class OpenAIServingChat:
choice.finish_reason = "recover_stop"
if request.return_token_ids:
choice.delta.completion_token_ids = list(output["token_ids"])
if response_processor.enable_multimodal_content():
choice.delta.multimodal_content[0]["completion_token_ids"] = list(output["token_ids"])
else:
choice.delta.completion_token_ids = list(output["token_ids"])
choice.delta.raw_prediction = output.get("raw_prediction")
choice.delta.completion_tokens = output.get("raw_prediction")
if include_continuous_usage:
@@ -395,6 +430,11 @@ class OpenAIServingChat:
current_waiting_time = 0
logprob_contents = []
completion_token_ids = []
response_processor = ChatResponseProcessor(
data_processor=self.engine_client.data_processor,
enable_mm_output=self.enable_mm_output,
decoder_base_url=self.tokenizer_base_url,
)
while True:
try:
response = await asyncio.wait_for(response_queue.get(), timeout=10)
@@ -411,15 +451,16 @@ class OpenAIServingChat:
continue
task_is_finished = False
for data in response:
generator = response_processor.process_response_chat(
response,
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
async for data in generator:
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
data = self.engine_client.data_processor.process_response_dict(
data,
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
completion_token_ids.extend(data["outputs"]["token_ids"])
@@ -447,7 +488,6 @@ class OpenAIServingChat:
output = final_res["outputs"]
message = ChatMessage(
role="assistant",
content=output["text"],
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call"),
prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
@@ -457,6 +497,12 @@ class OpenAIServingChat:
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
)
if response_processor.enable_multimodal_content():
message.multimodal_content = output.get("multipart")
else:
message.content = output["text"]
logprobs_full_res = None
if logprob_contents:
logprobs_full_res = LogProbs(content=logprob_contents)

View File

@@ -0,0 +1,163 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
from typing import Any, Optional, Union
import httpx
from pydantic import BaseModel, HttpUrl
from fastdeploy.utils import data_processor_logger
class BaseEncodeRequest(BaseModel):
version: str
req_id: str
is_gen: bool
resolution: int
class ImageEncodeRequest(BaseEncodeRequest):
image_url: Union[str, HttpUrl]
class VideoEncodeRequest(BaseEncodeRequest):
video_url: Union[str, HttpUrl]
start_ts: int
end_ts: int
frames: int
class ImageDecodeRequest(BaseModel):
req_id: str
data: list[Any]
class AsyncTokenizerClient:
def __init__(
self,
base_url: Optional[str] = None,
timeout: float = 5.0,
poll_interval: float = 0.5,
max_wait: float = 60.0,
):
"""
:param base_url: Base URL of the remote tokenizer service
:param timeout: Timeout of a single HTTP request (seconds)
:param poll_interval: Polling interval when waiting for results (seconds)
:param max_wait: Maximum wait time (seconds)
"""
self.base_url = base_url
self.timeout = timeout
self.poll_interval = poll_interval
self.max_wait = max_wait
async def encode_image(self, request: ImageEncodeRequest):
return await self._async_encode_request("image", request.__dict__)
async def encode_video(self, request: VideoEncodeRequest):
return await self._async_encode_request("video", request.__dict__)
async def decode_image(self, request: ImageDecodeRequest):
return await self._async_decode_request("image", request.__dict__)
async def log_request(self, request):
data_processor_logger.debug(f">>> Request: {request.method} {request.url}")
data_processor_logger.debug(f">>> Headers: {request.headers}")
if request.content:
data_processor_logger.debug(f">>> Content: {request.content.decode('utf-8')}")
async def log_response(self, response):
data_processor_logger.debug(f"<<< Response status: {response.status_code}")
data_processor_logger.debug(f"<<< Headers: {response.headers}")
async def _async_encode_request(self, type: str, request: dict):
if not self.base_url:
raise ValueError("Missing base_url")
async with httpx.AsyncClient(
timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
) as client:
req_id = request.get("req_id")
try:
url = None
if type == "image":
url = f"{self.base_url}/image/encode"
elif type == "video":
url = f"{self.base_url}/video/encode"
else:
raise ValueError("Invalid type")
resp = await client.post(url, json=request)
resp.raise_for_status()
except httpx.RequestError as e:
raise RuntimeError(f"Failed to create tokenize task: {e}") from e
task_info = resp.json()
if task_info.get("code") != 0:
raise RuntimeError(f"Tokenize task creation failed, {task_info.get('message')}")
task_tag = task_info.get("task_tag")
if not task_tag:
raise RuntimeError("No task_tag returned from server")
# 2. Poll for the result
start_time = asyncio.get_event_loop().time()
while True:
try:
r = await client.get(
f"{self.base_url}/encode/get", params={"task_tag": task_tag, "req_id": req_id}
)
r.raise_for_status()
data = r.json()
# Current state of the async encode task: Processing, Finished, Error
if data.get("state") == "Finished":
return data.get("result")
elif data.get("state") == "Error":
raise RuntimeError(f"Tokenize task failed: {data.get('message')}")
except httpx.RequestError:
# Keep polling on transient network errors
pass
# Timeout check
if asyncio.get_event_loop().time() - start_time > self.max_wait:
raise TimeoutError(f"Tokenize task {task_tag} timed out after {self.max_wait}s")
await asyncio.sleep(self.poll_interval)
async def _async_decode_request(self, type: str, request: dict):
if not self.base_url:
raise ValueError("Missing base_url")
async with httpx.AsyncClient(
timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
) as client:
try:
url = None
if type == "image":
url = f"{self.base_url}/image/decode"
else:
raise ValueError("Invalid type")
resp = await client.post(url, json=request)
resp.raise_for_status()
if resp.json().get("code") != 0:
raise RuntimeError(f"Tokenize task creation failed, {resp.json().get('message')}")
return resp.json().get("result")
except httpx.RequestError as e:
raise RuntimeError(f"Failed to decode: {e}") from e

View File

@@ -0,0 +1,84 @@
import httpx
import pytest
import respx
from fastdeploy.input.tokenzier_client import (
AsyncTokenizerClient,
ImageEncodeRequest,
VideoEncodeRequest,
)
@pytest.mark.asyncio
@respx.mock
async def test_encode_image_success():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url)
# Mock the task-creation endpoint
respx.post(f"{base_url}/image/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task123"})
)
# Mock the polling endpoint to return a finished state
mock_get_ret = {
"state": "Finished",
"result": {"feature_url": "bos://host:port/key", "feature_shape": [80, 45, 1563]},
}
respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json=mock_get_ret))
request = ImageEncodeRequest(
version="v1", req_id="req_img_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
)
result = await client.encode_image(request)
assert result["feature_url"] == "bos://host:port/key"
assert result["feature_shape"] == [80, 45, 1563]
@pytest.mark.asyncio
@respx.mock
async def test_encode_video_failure():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url, max_wait=1)
respx.post(f"{base_url}/video/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_vid_001"})
)
# Simulate the polling endpoint returning an error state
respx.get(f"{base_url}/encode/get").mock(
return_value=httpx.Response(200, json={"state": "Error", "message": "Encode failed"})
)
request = VideoEncodeRequest(
version="v1",
req_id="req_vid_001",
is_gen=True,
resolution=720,
video_url="http://example.com/video.mp4",
start_ts=0.0,
end_ts=10.0,
frames=30,
)
with pytest.raises(RuntimeError, match="Encode failed"):
await client.encode_video(request)
@pytest.mark.asyncio
@respx.mock
async def test_encode_timeout():
base_url = "http://testserver"
client = AsyncTokenizerClient(base_url=base_url, max_wait=1, poll_interval=0.1)
respx.post(f"{base_url}/image/encode").mock(
return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_timeout"})
)
# Simulate the polling endpoint never finishing, causing a timeout
respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json={"status": "processing"}))
request = ImageEncodeRequest(
version="v1", req_id="req_img_timeout", is_gen=False, resolution=256, image_url="http://example.com/image.jpg"
)
with pytest.raises(TimeoutError):
await client.encode_image(request)

View File

@@ -1,3 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock, patch

View File

@@ -1,3 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock, patch

View File

@@ -0,0 +1,134 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_data_processor = MagicMock()
self.mock_data_processor.process_response_dict = MagicMock(
side_effect=lambda response_dict, **_: {"processed": True, "raw": response_dict}
)
async def asyncSetUp(self):
self.processor_mm = ChatResponseProcessor(
data_processor=self.mock_data_processor,
enable_mm_output=True,
eoi_token_id=101032,
eos_token_id=2,
decoder_base_url="http://fake-decoder",
)
self.processor_mm.decoder_client.decode_image = AsyncMock(
return_value={"http_url": "http://image.url/test.png"}
)
async def test_text_only_mode(self):
"""不开启 multimodal 时,直接走 data_processor"""
processor = ChatResponseProcessor(self.mock_data_processor)
request_outputs = [{"outputs": {"text": "hello"}}]
results = [
r
async for r in processor.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
)
]
self.mock_data_processor.process_response_dict.assert_called_once()
self.assertEqual(results[0]["processed"], True)
self.assertEqual(results[0]["raw"]["outputs"]["text"], "hello")
async def test_streaming_text_and_image(self):
"""流式模式下text → image → text"""
request_outputs = [
{"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [1], "text": "hi"}},
{"request_id": "req1", "outputs": {"decode_type": 1, "token_ids": [[11, 22]]}},
{"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [101032], "text": "done"}},
]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
)
]
# First yield is text
text_part = results[0]["outputs"]["multipart"][0]
self.assertEqual(text_part["type"], "text")
self.assertEqual(text_part["text"], "hi")
# Second yield is the image (its token_ids were accumulated and concatenated)
image_part = results[1]["outputs"]["multipart"][0]
self.assertEqual(image_part["type"], "image")
self.assertEqual(image_part["url"], "http://image.url/test.png")
self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]])
# Third yield is text
text_part = results[2]["outputs"]["multipart"][0]
self.assertEqual(text_part["type"], "text")
self.assertEqual(text_part["text"], "done")
async def test_streaming_buffer_accumulation(self):
"""流式模式decode_type=1 只累积 buffer不 yield"""
request_outputs = [{"request_id": "req2", "outputs": {"decode_type": 1, "token_ids": [[33, 44]]}}]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
)
]
self.assertEqual(results, [])
self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]])
async def test_non_streaming_accumulate_and_emit(self):
"""非流式模式:等 eos_token_id 才输出 multiparttext+image"""
request_outputs = [
{"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [10], "text": "hello"}},
{"request_id": "req3", "outputs": {"decode_type": 1, "token_ids": [[55, 66]]}},
{"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [2], "text": "bye"}}, # eos_token_id
]
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
)
]
# Only the final output yields a result
self.assertEqual(len(results), 1)
multipart = results[0]["outputs"]["multipart"]
self.assertEqual(multipart[0]["type"], "text")
self.assertEqual(multipart[0]["text"], "hello")
self.assertEqual(multipart[1]["type"], "image")
self.assertEqual(multipart[1]["url"], "http://image.url/test.png")
self.assertEqual(multipart[2]["type"], "text")
self.assertEqual(multipart[2]["text"], "bye")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,3 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from typing import List
from unittest.mock import Mock

View File

@@ -1,3 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import unittest