From b9af95cf1c267c21499abbfa81a1a14682be7e23 Mon Sep 17 00:00:00 2001 From: SunLei Date: Sat, 30 Aug 2025 17:06:26 +0800 Subject: [PATCH] [Feature] Add AsyncTokenizerClient&ChatResponseProcessor with remote encode&decode support. (#3674) * [Feature] add AsyncTokenizerClient * add decode_image * Add response_processors with remote decode support. * [Feature] add tokenizer_base_url startup argument * Revert comment removal and restore original content. * [Feature] Non-streaming requests now support remote image decoding. * Fix parameter type issue in decode_image call. * Keep completion_token_ids when return_token_ids = False. * add copyright --- fastdeploy/demo/tokenzier_client_demo.py | 74 ++++++++ fastdeploy/engine/args_utils.py | 10 ++ fastdeploy/entrypoints/openai/api_server.py | 12 +- fastdeploy/entrypoints/openai/protocol.py | 6 +- .../entrypoints/openai/response_processors.py | 145 ++++++++++++++++ fastdeploy/entrypoints/openai/serving_chat.py | 90 +++++++--- fastdeploy/input/tokenzier_client.py | 163 ++++++++++++++++++ test/input/test_tokenizer_client.py | 84 +++++++++ .../openai/test_build_sample_logprobs.py | 16 ++ .../openai/test_completion_echo.py | 16 ++ .../openai/test_response_processors.py | 134 ++++++++++++++ .../openai/test_serving_completion.py | 16 ++ .../entrypoints/openai/test_serving_models.py | 16 ++ 13 files changed, 757 insertions(+), 25 deletions(-) create mode 100644 fastdeploy/demo/tokenzier_client_demo.py create mode 100644 fastdeploy/entrypoints/openai/response_processors.py create mode 100644 fastdeploy/input/tokenzier_client.py create mode 100644 test/input/test_tokenizer_client.py create mode 100644 tests/entrypoints/openai/test_response_processors.py diff --git a/fastdeploy/demo/tokenzier_client_demo.py b/fastdeploy/demo/tokenzier_client_demo.py new file mode 100644 index 000000000..0f4ba36c8 --- /dev/null +++ b/fastdeploy/demo/tokenzier_client_demo.py @@ -0,0 +1,74 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import asyncio + +from fastdeploy.input.tokenzier_client import ( + AsyncTokenizerClient, + ImageDecodeRequest, + ImageEncodeRequest, + VideoEncodeRequest, +) + + +async def main(): + """ + 测试AsyncTokenizerClient类 + """ + base_url = "http://example.com/" + + client = AsyncTokenizerClient(base_url=base_url) + + # # 测试图片编码请求 + image_encode_request = ImageEncodeRequest( + version="v1", req_id="req_image_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg" + ) + + image_encode_ret = await client.encode_image(image_encode_request) + print(f"Image encode result:{image_encode_ret}") + + # 测试视频编码请求 + video_encode_req = VideoEncodeRequest( + version="v1", + req_id="req_video_001", + video_url="http://example.com/video.mp4", + is_gen=False, + resolution=1024, + start_ts=0, + end_ts=5, + frames=1, + ) + video_encode_result = await client.encode_video(video_encode_req) + print(f"Video Encode Result:{video_encode_result}") + # 测试图片解码请求 + with open("./image_decode_demo.json", "r", encoding="utf-8") as file: + import json + import time + + start_time = time.time() + start_process_time = time.process_time() # 记录开始时间 + json_data = json.load(file) + image_decoding_request = ImageDecodeRequest(req_id="req_image_001", data=json_data.get("data")) + # import pdb; pdb.set_trace() + image_decode_result = await client.decode_image(image_decoding_request) + print(f"Image decode result:{image_decode_result}") + elapsed_time = time.time() - start_time + elapsed_process_time = time.process_time() - start_process_time + print(f"decode elapsed_time: {elapsed_time:.6f}s, elapsed_process_time: {elapsed_process_time:.6f}s") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 3cd8bda97..24d3f5284 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -71,6 +71,10 @@ class EngineArgs: """ The name or path of the tokenizer (defaults to model path if not provided). """ + tokenizer_base_url: str = None + """ + The base URL of the remote tokenizer service (used instead of local tokenizer if provided). + """ max_model_len: int = 2048 """ Maximum context length supported by the model. @@ -426,6 +430,12 @@ class EngineArgs: default=EngineArgs.tokenizer, help="Tokenizer name or path (defaults to model path if not specified).", ) + model_group.add_argument( + "--tokenizer-base-url", + type=nullable_str, + default=EngineArgs.tokenizer_base_url, + help="The base URL of the remote tokenizer service (used instead of local tokenizer if provided).", + ) model_group.add_argument( "--max-model-len", type=int, diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index ede64ad11..4764ad9c7 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -77,6 +77,9 @@ parser.add_argument( help="max waiting time for connection, if set value -1 means no waiting time limit", ) parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") +parser.add_argument( + "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. 
" +) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() args.model = retrive_model_from_server(args.model, args.revision) @@ -176,7 +179,14 @@ async def lifespan(app: FastAPI): ) app.state.model_handler = model_handler chat_handler = OpenAIServingChat( - engine_client, app.state.model_handler, pid, args.ips, args.max_waiting_time, chat_template + engine_client, + app.state.model_handler, + pid, + args.ips, + args.max_waiting_time, + chat_template, + args.enable_mm_output, + args.tokenizer_base_url, ) completion_handler = OpenAIServingCompletion( engine_client, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index a2ca85ddf..b74e0ffb4 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -163,8 +163,9 @@ class ChatMessage(BaseModel): Chat message. """ - role: str - content: str + role: Optional[str] = None + content: Optional[str] = None + multimodal_content: Optional[List[Any]] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None prompt_token_ids: Optional[List[int]] = None @@ -226,6 +227,7 @@ class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + multimodal_content: Optional[List[Any]] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None reasoning_content: Optional[str] = None diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py new file mode 100644 index 000000000..742268eba --- /dev/null +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -0,0 +1,145 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Any, List, Optional + +from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest + + +class ChatResponseProcessor: + """ + A decoder class to build multimodal content (text/image) from token_ids. + + Attributes: + eoi_token_id: Token ID indicating the end of an image (). 
+    """
+
+    def __init__(
+        self,
+        data_processor,
+        enable_mm_output: Optional[bool] = False,
+        eoi_token_id: Optional[int] = 101032,
+        eos_token_id: Optional[int] = 2,
+        decoder_base_url: Optional[str] = None,
+    ):
+        self.data_processor = data_processor
+        self.enable_mm_output = enable_mm_output
+        self.eoi_token_id = eoi_token_id
+        self.eos_token_id = eos_token_id
+        self.decoder_client = None  # default so mm output without a remote decoder does not raise AttributeError
+        if decoder_base_url is not None:
+            self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
+        self._mm_buffer: List[Any] = []  # Buffer for accumulating image token_ids
+        self._end_image_code_request_output: Optional[Any] = None
+        self._multipart_buffer = []
+
+    def enable_multimodal_content(self):
+        return self.enable_mm_output
+
+    def accumulate_token_ids(self, request_output):
+        decode_type = request_output["outputs"].get("decode_type", 0)
+
+        if not self._multipart_buffer:
+            self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
+        else:
+            last_part = self._multipart_buffer[-1]
+
+            if last_part["decode_type"] == decode_type:
+                last_token_ids = last_part["request_output"]["outputs"]["token_ids"]
+                last_token_ids.extend(request_output["outputs"]["token_ids"])
+                request_output["outputs"]["token_ids"] = last_token_ids
+                last_part["request_output"] = request_output
+            else:
+                self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
+
+    async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
+        """
+        Process a list of responses into a generator that yields each processed response as it's generated.
+        Args:
+            request_outputs: The list of outputs to be processed.
+            stream: Whether or not to stream the output.
+            enable_thinking: Whether or not to show thinking messages.
+            include_stop_str_in_output: Whether or not to include stop strings in the output.
+ """ + for request_output in request_outputs: + if not self.enable_mm_output: + yield self.data_processor.process_response_dict( + response_dict=request_output, + stream=stream, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + elif stream: + decode_type = request_output["outputs"].get("decode_type", 0) + token_ids = request_output["outputs"]["token_ids"] + if decode_type == 0: + if self.eoi_token_id and self.eoi_token_id in token_ids: + if self._mm_buffer: + all_tokens = self._mm_buffer + self._mm_buffer = [] + image = {"type": "image"} + if self.decoder_client: + req_id = request_output["request_id"] + image_ret = await self.decoder_client.decode_image( + request=ImageDecodeRequest(req_id=req_id, data=all_tokens) + ) + image["url"] = image_ret["http_url"] + image_output = self._end_image_code_request_output + image_output["outputs"]["multipart"] = [image] + image_output["outputs"]["token_ids"] = all_tokens + yield image_output + + self.data_processor.process_response_dict( + response_dict=request_output, + stream=stream, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + text = {"type": "text", "text": request_output["outputs"]["text"]} + request_output["outputs"]["multipart"] = [text] + yield request_output + + elif decode_type == 1: + self._mm_buffer.extend(token_ids) + self._end_image_code_request_output = request_output + else: + self.accumulate_token_ids(request_output) + token_ids = request_output["outputs"]["token_ids"] + if token_ids[-1] == self.eos_token_id: + multipart = [] + for part in self._multipart_buffer: + if part["decode_type"] == 0: + self.data_processor.process_response_dict( + response_dict=part["request_output"], + stream=False, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + text = {"type": "text", "text": part["request_output"]["outputs"]["text"]} + multipart.append(text) + elif part["decode_type"] == 1: + image = {"type": "image"} + if self.decoder_client: + req_id = part["request_output"]["request_id"] + all_tokens = part["request_output"]["outputs"]["token_ids"] + image_ret = await self.decoder_client.decode_image( + request=ImageDecodeRequest(req_id=req_id, data=all_tokens) + ) + image["url"] = image_ret["http_url"] + multipart.append(image) + + lasrt_request_output = self._multipart_buffer[-1]["request_output"] + lasrt_request_output["outputs"]["multipart"] = multipart + yield lasrt_request_output diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index c157bd0e7..df074771c 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -36,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import ( PromptTokenUsageInfo, UsageInfo, ) +from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.utils import api_server_logger from fastdeploy.worker.output import LogprobsLists @@ -46,12 +47,24 @@ class OpenAIServingChat: OpenAI-style chat completions serving """ - def __init__(self, engine_client, models, pid, ips, max_waiting_time, chat_template): + def __init__( + self, + engine_client, + models, + pid, + ips, + max_waiting_time, + chat_template, + enable_mm_output: Optional[bool] = False, + tokenizer_base_url: Optional[str] = None, + ): self.engine_client = engine_client self.models = models self.pid = 
pid self.max_waiting_time = max_waiting_time self.chat_template = chat_template + self.enable_mm_output = enable_mm_output + self.tokenizer_base_url = tokenizer_base_url if ips is not None: if isinstance(ips, list): self.master_ip = ips[0] @@ -198,6 +211,11 @@ class OpenAIServingChat: dealer.write([b"", request_id.encode("utf-8")]) choices = [] current_waiting_time = 0 + response_processor = ChatResponseProcessor( + data_processor=self.engine_client.data_processor, + enable_mm_output=self.enable_mm_output, + decoder_base_url=self.tokenizer_base_url, + ) while num_choices > 0: try: response = await asyncio.wait_for(response_queue.get(), timeout=10) @@ -215,17 +233,18 @@ class OpenAIServingChat: current_waiting_time = 0 await asyncio.sleep(0.01) continue - for res in response: + + generator = response_processor.process_response_chat( + response, + stream=True, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + + async for res in generator: if res.get("error_code", 200) != 200: raise ValueError("{}".format(res["error_msg"])) - self.engine_client.data_processor.process_response_dict( - res, - stream=True, - enable_thinking=enable_thinking, - include_stop_str_in_output=include_stop_str_in_output, - ) - if res["metrics"]["first_token_time"] is not None: arrival_time = res["metrics"]["first_token_time"] inference_start_time = res["metrics"]["inference_start_time"] @@ -239,13 +258,22 @@ class OpenAIServingChat: index=i, delta=DeltaMessage( role="assistant", - content="", reasoning_content="", tool_calls=None, prompt_token_ids=None, completion_token_ids=None, ), ) + if response_processor.enable_multimodal_content(): + choice.delta.multimodal_content = [ + { + "type": "text", + "text": "", + } + ] + else: + choice.delta.content = "" + if request.return_token_ids: choice.delta.prompt_token_ids = list(prompt_token_ids) choice.delta.text_after_process = text_after_process @@ -269,7 +297,6 @@ class OpenAIServingChat: first_iteration = False output = res["outputs"] - delta_text = output["text"] output_top_logprobs = output["top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None @@ -279,12 +306,17 @@ class OpenAIServingChat: ) delta_message = DeltaMessage( - content=delta_text, reasoning_content="", prompt_token_ids=None, - completion_token_ids=None, tool_calls=None, + completion_token_ids=None, ) + + if response_processor.enable_multimodal_content(): + delta_message.multimodal_content = output["multipart"] + else: + delta_message.content = output["text"] + if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] if delta_message_output is None: @@ -317,7 +349,10 @@ class OpenAIServingChat: choice.finish_reason = "recover_stop" if request.return_token_ids: - choice.delta.completion_token_ids = list(output["token_ids"]) + if response_processor.enable_multimodal_content(): + choice.delta.multimodal_content[0]["completion_token_ids"] = list(output["token_ids"]) + else: + choice.delta.completion_token_ids = list(output["token_ids"]) choice.delta.raw_prediction = output.get("raw_prediction") choice.delta.completion_tokens = output.get("raw_prediction") if include_continuous_usage: @@ -395,6 +430,11 @@ class OpenAIServingChat: current_waiting_time = 0 logprob_contents = [] completion_token_ids = [] + response_processor = ChatResponseProcessor( + data_processor=self.engine_client.data_processor, + enable_mm_output=self.enable_mm_output, + 
decoder_base_url=self.tokenizer_base_url, + ) while True: try: response = await asyncio.wait_for(response_queue.get(), timeout=10) @@ -411,15 +451,16 @@ class OpenAIServingChat: continue task_is_finished = False - for data in response: + + generator = response_processor.process_response_chat( + response, + stream=False, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + async for data in generator: if data.get("error_code", 200) != 200: raise ValueError("{}".format(data["error_msg"])) - data = self.engine_client.data_processor.process_response_dict( - data, - stream=False, - enable_thinking=enable_thinking, - include_stop_str_in_output=include_stop_str_in_output, - ) # api_server_logger.debug(f"Client {request_id} received: {data}") previous_num_tokens += len(data["outputs"]["token_ids"]) completion_token_ids.extend(data["outputs"]["token_ids"]) @@ -447,7 +488,6 @@ class OpenAIServingChat: output = final_res["outputs"] message = ChatMessage( role="assistant", - content=output["text"], reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), prompt_token_ids=prompt_token_ids if request.return_token_ids else None, @@ -457,6 +497,12 @@ class OpenAIServingChat: raw_prediction=output.get("raw_prediction") if request.return_token_ids else None, completion_tokens=output.get("raw_prediction") if request.return_token_ids else None, ) + + if response_processor.enable_multimodal_content(): + message.multimodal_content = output.get("multipart") + else: + message.content = output["text"] + logprobs_full_res = None if logprob_contents: logprobs_full_res = LogProbs(content=logprob_contents) diff --git a/fastdeploy/input/tokenzier_client.py b/fastdeploy/input/tokenzier_client.py new file mode 100644 index 000000000..d1bdeb761 --- /dev/null +++ b/fastdeploy/input/tokenzier_client.py @@ -0,0 +1,163 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import asyncio +from typing import Any, Optional, Union + +import httpx +from pydantic import BaseModel, HttpUrl + +from fastdeploy.utils import data_processor_logger + + +class BaseEncodeRequest(BaseModel): + version: str + req_id: str + is_gen: bool + resolution: int + + +class ImageEncodeRequest(BaseEncodeRequest): + image_url: Union[str, HttpUrl] + + +class VideoEncodeRequest(BaseEncodeRequest): + video_url: Union[str, HttpUrl] + start_ts: int + end_ts: int + frames: int + + +class ImageDecodeRequest(BaseModel): + req_id: str + data: list[Any] + + +class AsyncTokenizerClient: + def __init__( + self, + base_url: Optional[str] = None, + timeout: float = 5.0, + poll_interval: float = 0.5, + max_wait: float = 60.0, + ): + """ + :param mode: 'local' 或 'remote' + :param base_url: 远程服务地址 + :param timeout: 单次 HTTP 请求超时(秒) + :param poll_interval: 查询结果的轮询间隔(秒) + :param max_wait: 最大等待时间(秒) + """ + self.base_url = base_url + self.timeout = timeout + self.poll_interval = poll_interval + self.max_wait = max_wait + + async def encode_image(self, request: ImageEncodeRequest): + return await self._async_encode_request("image", request.__dict__) + + async def encode_video(self, request: VideoEncodeRequest): + return await self._async_encode_request("video", request.__dict__) + + async def decode_image(self, request: ImageDecodeRequest): + return await self._async_decode_request("image", request.__dict__) + + async def log_request(self, request): + data_processor_logger.debug(f">>> Request: {request.method} {request.url}") + data_processor_logger.debug(f">>> Headers: {request.headers}") + if request.content: + data_processor_logger.debug(f">>> Content: {request.content.decode('utf-8')}") + + async def log_response(self, response): + data_processor_logger.debug(f"<<< Response status: {response.status_code}") + data_processor_logger.debug(f"<<< Headers: {response.headers}") + + async def _async_encode_request(self, type: str, request: dict): + if not self.base_url: + raise ValueError("Missing base_url") + + async with httpx.AsyncClient( + timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]} + ) as client: + req_id = request.get("req_id") + try: + url = None + if type == "image": + url = f"{self.base_url}/image/encode" + elif type == "video": + url = f"{self.base_url}/video/encode" + else: + raise ValueError("Invalid type") + + resp = await client.post(url, json=request) + resp.raise_for_status() + except httpx.RequestError as e: + raise RuntimeError(f"Failed to create tokenize task: {e}") from e + + task_info = resp.json() + if task_info.get("code") != 0: + raise RuntimeError(f"Tokenize task creation failed, {task_info.get('message')}") + + task_tag = task_info.get("task_tag") + if not task_tag: + raise RuntimeError("No task_tag returned from server") + + # 2. 
+            start_time = asyncio.get_event_loop().time()
+            while True:
+                try:
+                    r = await client.get(
+                        f"{self.base_url}/encode/get", params={"task_tag": task_tag, "req_id": req_id}
+                    )
+                    r.raise_for_status()
+                    data = r.json()
+
+                    # Current state of the async encode task: Processing, Finished, Error
+                    if data.get("state") == "Finished":
+                        return data.get("result")
+                    elif data.get("state") == "Error":
+                        raise RuntimeError(f"Tokenize task failed: {data.get('message')}")
+
+                except httpx.RequestError:
+                    # Keep polling on transient network errors
+                    pass
+
+                # Timeout check
+                if asyncio.get_event_loop().time() - start_time > self.max_wait:
+                    raise TimeoutError(f"Tokenize task {task_tag} timed out after {self.max_wait}s")
+
+                await asyncio.sleep(self.poll_interval)
+
+    async def _async_decode_request(self, type: str, request: dict):
+        if not self.base_url:
+            raise ValueError("Missing base_url")
+
+        async with httpx.AsyncClient(
+            timeout=self.timeout, event_hooks={"request": [self.log_request], "response": [self.log_response]}
+        ) as client:
+            try:
+                url = None
+                if type == "image":
+                    url = f"{self.base_url}/image/decode"
+                else:
+                    raise ValueError("Invalid type")
+                resp = await client.post(url, json=request)
+                resp.raise_for_status()
+                if resp.json().get("code") != 0:
+                    raise RuntimeError(f"Image decode request failed, {resp.json().get('message')}")
+                return resp.json().get("result")
+            except httpx.RequestError as e:
+                raise RuntimeError(f"Failed to decode: {e}") from e
diff --git a/test/input/test_tokenizer_client.py b/test/input/test_tokenizer_client.py
new file mode 100644
index 000000000..64c50e929
--- /dev/null
+++ b/test/input/test_tokenizer_client.py
@@ -0,0 +1,84 @@
+import httpx
+import pytest
+import respx
+
+from fastdeploy.input.tokenzier_client import (
+    AsyncTokenizerClient,
+    ImageEncodeRequest,
+    VideoEncodeRequest,
+)
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_encode_image_success():
+    base_url = "http://testserver"
+    client = AsyncTokenizerClient(base_url=base_url)
+
+    # Mock the task-creation endpoint
+    respx.post(f"{base_url}/image/encode").mock(
+        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task123"})
+    )
+    # Mock the polling endpoint to return the finished state
+    mock_get_ret = {
+        "state": "Finished",
+        "result": {"feature_url": "bos://host:port/key", "feature_shape": [80, 45, 1563]},
+    }
+    respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json=mock_get_ret))
+
+    request = ImageEncodeRequest(
+        version="v1", req_id="req_img_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
+    )
+
+    result = await client.encode_image(request)
+    assert result["feature_url"] == "bos://host:port/key"
+    assert result["feature_shape"] == [80, 45, 1563]
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_encode_video_failure():
+    base_url = "http://testserver"
+    client = AsyncTokenizerClient(base_url=base_url, max_wait=1)
+
+    respx.post(f"{base_url}/video/encode").mock(
+        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_vid_001"})
+    )
+    # Simulate a failure state from the polling endpoint
+    respx.get(f"{base_url}/encode/get").mock(
+        return_value=httpx.Response(200, json={"state": "Error", "message": "Encode failed"})
+    )
+
+    request = VideoEncodeRequest(
+        version="v1",
+        req_id="req_vid_001",
+        is_gen=True,
+        resolution=720,
+        video_url="http://example.com/video.mp4",
+        start_ts=0.0,
+        end_ts=10.0,
+        frames=30,
+    )
+
+    with pytest.raises(RuntimeError, match="Encode failed"):
+        await client.encode_video(request)
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_encode_timeout():
+    base_url = "http://testserver"
+    client = 
AsyncTokenizerClient(base_url=base_url, max_wait=1, poll_interval=0.1)
+
+    respx.post(f"{base_url}/image/encode").mock(
+        return_value=httpx.Response(200, json={"code": 0, "task_tag": "task_timeout"})
+    )
+    # Simulate a polling endpoint that never finishes, forcing a timeout
+    respx.get(f"{base_url}/encode/get").mock(return_value=httpx.Response(200, json={"status": "processing"}))
+
+    request = ImageEncodeRequest(
+        version="v1", req_id="req_img_timeout", is_gen=False, resolution=256, image_url="http://example.com/image.jpg"
+    )
+
+    with pytest.raises(TimeoutError):
+        await client.encode_image(request)
diff --git a/tests/entrypoints/openai/test_build_sample_logprobs.py b/tests/entrypoints/openai/test_build_sample_logprobs.py
index 76ff8e87b..74a00fcda 100644
--- a/tests/entrypoints/openai/test_build_sample_logprobs.py
+++ b/tests/entrypoints/openai/test_build_sample_logprobs.py
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import unittest
 from unittest.mock import MagicMock, patch
diff --git a/tests/entrypoints/openai/test_completion_echo.py b/tests/entrypoints/openai/test_completion_echo.py
index 565e5ad93..52a223070 100644
--- a/tests/entrypoints/openai/test_completion_echo.py
+++ b/tests/entrypoints/openai/test_completion_echo.py
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import unittest
 from unittest.mock import MagicMock, patch
diff --git a/tests/entrypoints/openai/test_response_processors.py b/tests/entrypoints/openai/test_response_processors.py
new file mode 100644
index 000000000..9aabfbeff
--- /dev/null
+++ b/tests/entrypoints/openai/test_response_processors.py
@@ -0,0 +1,134 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import unittest +from unittest.mock import AsyncMock, MagicMock + +from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor + + +class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.mock_data_processor = MagicMock() + self.mock_data_processor.process_response_dict = MagicMock( + side_effect=lambda response_dict, **_: {"processed": True, "raw": response_dict} + ) + + async def asyncSetUp(self): + self.processor_mm = ChatResponseProcessor( + data_processor=self.mock_data_processor, + enable_mm_output=True, + eoi_token_id=101032, + eos_token_id=2, + decoder_base_url="http://fake-decoder", + ) + self.processor_mm.decoder_client.decode_image = AsyncMock( + return_value={"http_url": "http://image.url/test.png"} + ) + + async def test_text_only_mode(self): + """不开启 multimodal 时,直接走 data_processor""" + processor = ChatResponseProcessor(self.mock_data_processor) + request_outputs = [{"outputs": {"text": "hello"}}] + + results = [ + r + async for r in processor.process_response_chat( + request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False + ) + ] + + self.mock_data_processor.process_response_dict.assert_called_once() + self.assertEqual(results[0]["processed"], True) + self.assertEqual(results[0]["raw"]["outputs"]["text"], "hello") + + async def test_streaming_text_and_image(self): + """流式模式下:text → image → text""" + request_outputs = [ + {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [1], "text": "hi"}}, + {"request_id": "req1", "outputs": {"decode_type": 1, "token_ids": [[11, 22]]}}, + {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [101032], "text": "done"}}, + ] + + results = [ + r + async for r in self.processor_mm.process_response_chat( + request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False + ) + ] + + # 第一个 yield:text + text_part = results[0]["outputs"]["multipart"][0] + self.assertEqual(text_part["type"], "text") + self.assertEqual(text_part["text"], "hi") + + # 第二个 yield:image(token_ids 被拼起来了) + image_part = results[1]["outputs"]["multipart"][0] + self.assertEqual(image_part["type"], "image") + self.assertEqual(image_part["url"], "http://image.url/test.png") + self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]]) + + # 第三个 yield:text + text_part = results[2]["outputs"]["multipart"][0] + self.assertEqual(text_part["type"], "text") + self.assertEqual(text_part["text"], "done") + + async def test_streaming_buffer_accumulation(self): + """流式模式:decode_type=1 只累积 buffer,不 yield""" + request_outputs = [{"request_id": "req2", "outputs": {"decode_type": 1, "token_ids": [[33, 44]]}}] + + results = [ + r + async for r in self.processor_mm.process_response_chat( + request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False + ) + ] + + self.assertEqual(results, []) + self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]]) + + async def test_non_streaming_accumulate_and_emit(self): + """非流式模式:等 eos_token_id 才输出 multipart(text+image)""" + request_outputs = [ + {"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [10], "text": "hello"}}, + {"request_id": "req3", "outputs": {"decode_type": 1, "token_ids": [[55, 66]]}}, + {"request_id": "req3", "outputs": {"decode_type": 0, "token_ids": [2], "text": "bye"}}, # eos_token_id + ] + + results = [ + r + async for r in self.processor_mm.process_response_chat( + request_outputs, stream=False, enable_thinking=False, 
include_stop_str_in_output=False
+            )
+        ]
+
+        # Only one result is yielded, at the final output
+        self.assertEqual(len(results), 1)
+        multipart = results[0]["outputs"]["multipart"]
+
+        self.assertEqual(multipart[0]["type"], "text")
+        self.assertEqual(multipart[0]["text"], "hello")
+
+        self.assertEqual(multipart[1]["type"], "image")
+        self.assertEqual(multipart[1]["url"], "http://image.url/test.png")
+
+        self.assertEqual(multipart[2]["type"], "text")
+        self.assertEqual(multipart[2]["text"], "bye")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/entrypoints/openai/test_serving_completion.py b/tests/entrypoints/openai/test_serving_completion.py
index 82370ca0b..f3941c80b 100644
--- a/tests/entrypoints/openai/test_serving_completion.py
+++ b/tests/entrypoints/openai/test_serving_completion.py
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import unittest
 from typing import List
 from unittest.mock import Mock
diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py
index a6b804508..1ef2fff7e 100644
--- a/tests/entrypoints/openai/test_serving_models.py
+++ b/tests/entrypoints/openai/test_serving_models.py
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import asyncio
 import unittest
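
Reviewer note (not part of the patch): a minimal client-side sketch of how the new output shape could be consumed once the server is launched with --enable-mm-output and, for remote image decoding, --tokenizer-base-url. The host/port, model name, prompt, and the OpenAI-compatible /v1/chat/completions route are assumptions for illustration; only the multimodal_content field and its text/image part layout come from this patch.

# usage_sketch.py -- illustrative only; endpoint, port, and model name are placeholders.
import httpx

BASE_URL = "http://localhost:8188/v1"  # assumed api_server address

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Draw a small cat and describe it."}],
    "stream": False,
}

with httpx.Client(timeout=60.0) as client:
    resp = client.post(f"{BASE_URL}/chat/completions", json=payload)
    resp.raise_for_status()
    message = resp.json()["choices"][0]["message"]

# With --enable-mm-output the text/image parts arrive in `multimodal_content`;
# otherwise the usual `content` string is populated.
if message.get("multimodal_content"):
    for part in message["multimodal_content"]:
        if part["type"] == "text":
            print("text:", part["text"])
        elif part["type"] == "image":
            print("image url:", part.get("url"))
else:
    print(message.get("content"))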