From 18f4977aecefe1cdf946cd54186dffb78cd7060a Mon Sep 17 00:00:00 2001 From: luukunn <83932082+luukunn@users.noreply.github.com> Date: Wed, 24 Sep 2025 18:56:32 +0800 Subject: [PATCH] [fix]update apply_chat_template (#4137) * update apply_chat_template * fix unittest * fix unittest * fix * fix * fix unit test * fix * fix unit test * add unit test --- fastdeploy/engine/engine.py | 4 +- fastdeploy/entrypoints/engine_client.py | 3 + fastdeploy/input/ernie4_5_processor.py | 13 ++- .../ernie4_5_vl_processor.py | 1 - .../input/ernie4_5_vl_processor/process.py | 8 +- fastdeploy/input/text_processor.py | 13 ++- tests/entrypoints/test_engine_client.py | 36 ++++++++ tests/input/test_ernie_processor.py | 25 ++++++ tests/input/test_text_processor.py | 63 +++++++++++++ tests/utils/test_custom_chat_template.py | 89 ------------------- 10 files changed, 146 insertions(+), 109 deletions(-) create mode 100644 tests/entrypoints/test_engine_client.py create mode 100644 tests/input/test_text_processor.py diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 8689490d8..b523adb6f 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -220,7 +220,9 @@ class LLMEngine: if sampling_params is not None: request.sampling_params = sampling_params request.preprocess_start_time = time.time() - + chat_template_kwargs = kwargs.get("chat_template_kwargs") or {} + chat_template_kwargs["chat_template"] = kwargs.get("chat_template") + kwargs["chat_template_kwargs"] = chat_template_kwargs request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs) request.prompt_token_ids_len = len(request.prompt_token_ids) request.need_prefill_tokens = request.prompt_token_ids_len diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 9f533b90f..49b3656fd 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -144,6 +144,9 @@ class EngineClient: task["preprocess_start_time"] = time.time() try: + chat_template_kwargs = task.get("chat_template_kwargs", {}) + chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")}) + task["chat_template_kwargs"] = chat_template_kwargs if inspect.iscoroutinefunction(self.data_processor.process_request_dict): await self.data_processor.process_request_dict(task, self.max_model_len) else: diff --git a/fastdeploy/input/ernie4_5_processor.py b/fastdeploy/input/ernie4_5_processor.py index f364ecba1..8d2463a08 100644 --- a/fastdeploy/input/ernie4_5_processor.py +++ b/fastdeploy/input/ernie4_5_processor.py @@ -88,7 +88,6 @@ class Ernie4_5Processor(BaseDataProcessor): str: error message """ data_processor_logger.info(f"Start processing request: {request}") - request.chat_template = kwargs.get("chat_template") request = self._apply_default_parameters(request) if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids @@ -127,7 +126,7 @@ class Ernie4_5Processor(BaseDataProcessor): ) elif request.messages is not None: task = request.to_dict() - chat_template_kwargs = kwargs.get("chat_template_kwargs") + chat_template_kwargs = kwargs.get("chat_template_kwargs", {}) if chat_template_kwargs: if isinstance(chat_template_kwargs, dict): for k, v in chat_template_kwargs.items(): @@ -135,7 +134,7 @@ class Ernie4_5Processor(BaseDataProcessor): task[k] = v else: raise ValueError("Invalid input: chat_template_kwargs must be a dict") - request.prompt_token_ids = self.messages2ids(task) + request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs) else: raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.") @@ -205,7 +204,7 @@ class Ernie4_5Processor(BaseDataProcessor): req_id = request.get("request_id", None) data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}") elif request.get("messages"): - chat_template_kwargs = request.get("chat_template_kwargs") + chat_template_kwargs = request.get("chat_template_kwargs", {}) if chat_template_kwargs: if isinstance(chat_template_kwargs, dict): for k, v in chat_template_kwargs.items(): @@ -213,7 +212,7 @@ class Ernie4_5Processor(BaseDataProcessor): request[k] = v else: raise ValueError("Invalid input: chat_template_kwargs must be a dict") - request["prompt_token_ids"] = self.messages2ids(request) + request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs) else: raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}") @@ -379,7 +378,7 @@ class Ernie4_5Processor(BaseDataProcessor): del self.tool_parser_dict[req_id] return response_dict - def messages2ids(self, request_or_messages): + def messages2ids(self, request_or_messages, **kwargs): """ Convert multi-turn messages into ID sequences. @@ -397,7 +396,7 @@ class Ernie4_5Processor(BaseDataProcessor): tokenize=False, split_special_tokens=False, add_special_tokens=False, - chat_template=request_or_messages.get("chat_template", None), + **kwargs, ) request_or_messages["text_after_process"] = spliced_message req_id = None diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 77690b920..deec2f2af 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -113,7 +113,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor): def process_request(self, request, max_model_len=None, **kwargs): """process the input data""" - request.chat_template = kwargs.get("chat_template") task = request.to_dict() task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs") self.process_request_dict(task, max_model_len) diff --git a/fastdeploy/input/ernie4_5_vl_processor/process.py b/fastdeploy/input/ernie4_5_vl_processor/process.py index 0616dd5b1..ea22850dd 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/process.py +++ b/fastdeploy/input/ernie4_5_vl_processor/process.py @@ -250,8 +250,8 @@ class DataProcessor: "video", ]: image_message_list.append(item) - - prompt_token_ids = self.apply_chat_template(request) + chat_template_kwargs = request.get("chat_template_kwargs", {}) + prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs) if len(prompt_token_ids) == 0: raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") image_start_index = 0 @@ -480,7 +480,7 @@ class DataProcessor: break self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path) - def apply_chat_template(self, request): + def apply_chat_template(self, request, **kwargs): """ Convert multi-turn messages into ID sequences. @@ -498,7 +498,7 @@ class DataProcessor: request, tokenize=False, add_generation_prompt=request.get("add_generation_prompt", True), - chat_template=request.get("chat_template", None), + **kwargs, ) prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace( "<|video@placeholder|>", "" diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index a1baf8e46..a29e1b260 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -208,7 +208,6 @@ class DataProcessor(BaseDataProcessor): str: error message """ data_processor_logger.info(f"Start processing request: {request}") - request.chat_template = kwargs.get("chat_template") request = self._apply_default_parameters(request) if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids @@ -242,7 +241,7 @@ class DataProcessor(BaseDataProcessor): if self.tokenizer.chat_template is None: raise ValueError("This model does not support chat_template.") task = request.to_dict() - chat_template_kwargs = kwargs.get("chat_template_kwargs") + chat_template_kwargs = kwargs.get("chat_template_kwargs", {}) if chat_template_kwargs: if isinstance(chat_template_kwargs, dict): for k, v in chat_template_kwargs.items(): @@ -251,7 +250,7 @@ class DataProcessor(BaseDataProcessor): else: raise ValueError("Invalid input: chat_template_kwargs must be a dict") task.setdefault("enable_thinking", True) - request.prompt_token_ids = self.messages2ids(task) + request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs) else: raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.") @@ -316,7 +315,7 @@ class DataProcessor(BaseDataProcessor): elif request.get("messages"): if self.tokenizer.chat_template is None: raise ValueError("This model does not support chat_template.") - chat_template_kwargs = request.get("chat_template_kwargs") + chat_template_kwargs = request.get("chat_template_kwargs", {}) if chat_template_kwargs: if isinstance(chat_template_kwargs, dict): for k, v in chat_template_kwargs.items(): @@ -325,7 +324,7 @@ class DataProcessor(BaseDataProcessor): else: raise ValueError("Invalid input: chat_template_kwargs must be a dict") request.setdefault("enable_thinking", True) - request["prompt_token_ids"] = self.messages2ids(request) + request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs) else: raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}") @@ -530,7 +529,7 @@ class DataProcessor(BaseDataProcessor): return tokens["input_ids"][0] - def messages2ids(self, request): + def messages2ids(self, request, **kwargs): """ Convert multi-turn messages into ID sequences. @@ -547,7 +546,7 @@ class DataProcessor(BaseDataProcessor): split_special_tokens=False, add_special_tokens=False, return_tensors="pd", - chat_template=request.get("chat_template", None), + **kwargs, ) request["text_after_process"] = spliced_message req_id = None diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py new file mode 100644 index 000000000..e11fa5493 --- /dev/null +++ b/tests/entrypoints/test_engine_client.py @@ -0,0 +1,36 @@ +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.entrypoints.engine_client import EngineClient + + +class TestEngineClient(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + # 创建 EngineClient 实例的模拟对象 + with patch.object(EngineClient, "__init__", return_value=None) as mock_init: + self.engine_client = EngineClient("model_path") + mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}") + + self.engine_client.data_processor = MagicMock() + self.engine_client.zmq_client = MagicMock() + self.engine_client.max_model_len = 1024 + self.engine_client.enable_mm = False + + async def test_add_request(self): + request = { + "chat_template_kwargs": {"enable_thinking": True}, + "prompt_token_ids": [1], + "chat_template": "Hello", + "max_tokens": 20, + "tools": [1], + } + + await self.engine_client.add_requests(request) + assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs" + assert "tools" in request["chat_template_kwargs"], "'tools' not found in 'chat_template_kwargs'" + assert request["chat_template_kwargs"]["chat_template"] == "Hello" + assert request["chat_template_kwargs"]["tools"] == [1] + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/input/test_ernie_processor.py b/tests/input/test_ernie_processor.py index c87604bbc..b2357eeaa 100644 --- a/tests/input/test_ernie_processor.py +++ b/tests/input/test_ernie_processor.py @@ -17,6 +17,8 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase): self.processor.decode_status = {} self.processor.reasoning_end_dict = {} self.processor.tool_parser_dict = {} + self.processor.generation_config = MagicMock() + self.processor.eos_token_ids = [1] # 模拟 ids2tokens 方法 def mock_ids2tokens(token_ids, task_id): @@ -24,6 +26,18 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase): self.processor.ids2tokens = mock_ids2tokens + def mock_messages2ids(request, **kwargs): + if "chat_template" in kwargs: + return [1] + else: + return [0] + + def mock_apply_default_parameters(request): + return request + + self.processor.messages2ids = mock_messages2ids + self.processor._apply_default_parameters = mock_apply_default_parameters + # 模拟推理解析器 self.mock_reasoning_parser = MagicMock() self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser" @@ -49,6 +63,17 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase): # 验证结果 self.assertEqual(result["outputs"]["raw_prediction"], "delta_text") + def test_process_request_dict(self): + request_dict = { + "messages": [{"role": "user", "content": "Hello!"}], + "chat_template_kwargs": {"chat_template": "Hello!"}, + "eos_token_ids": [1], + "temperature": 1, + "top_p": 1, + } + result = self.processor.process_request_dict(request_dict, 100) + self.assertEqual(result["prompt_token_ids"], [1]) + if __name__ == "__main__": unittest.main() diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py new file mode 100644 index 000000000..6ca0178fe --- /dev/null +++ b/tests/input/test_text_processor.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.engine.request import Request +from fastdeploy.input.text_processor import DataProcessor + + +class TestDataProcessorProcess(unittest.TestCase): + def setUp(self): + # 创建 DataProcessor 实例的模拟对象 + with patch.object(DataProcessor, "__init__", return_value=None) as mock_init: + self.processor = DataProcessor("model_path") + mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}") + + # 设置必要的属性 + self.processor.tokenizer = MagicMock() + self.processor.tokenizer.eos_token_id = 1 + self.processor.decode_status = {} + self.processor.reasoning_end_dict = {} + self.processor.tool_parser_dict = {} + self.processor.generation_config = MagicMock() + self.processor.eos_token_ids = [1] + + def mock_messages2ids(request, **kwargs): + if "chat_template" in kwargs: + return [1] + else: + return [0] + + def mock_apply_default_parameters(request): + return request + + self.processor.messages2ids = mock_messages2ids + self.processor._apply_default_parameters = mock_apply_default_parameters + + def test_process_request(self): + request = Request.from_dict( + { + "request_id": "123", + "messages": [{"role": "user", "content": "Hello!"}], + "eos_token_ids": [1], + "temperature": 1, + "top_p": 1, + } + ) + chat_template_kwargs = {"chat_template": "Hello!"} + result = self.processor.process_request(request, 100, chat_template_kwargs=chat_template_kwargs) + self.assertEqual(result.prompt_token_ids, [1]) + + def test_process_request_dict(self): + request_dict = { + "messages": [{"role": "user", "content": "Hello!"}], + "chat_template_kwargs": {"chat_template": "Hello!"}, + "eos_token_ids": [1], + "temperature": 1, + "top_p": 1, + } + result = self.processor.process_request_dict(request_dict, 100) + self.assertEqual(result["prompt_token_ids"], [1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_custom_chat_template.py b/tests/utils/test_custom_chat_template.py index 71a617044..1311289ff 100644 --- a/tests/utils/test_custom_chat_template.py +++ b/tests/utils/test_custom_chat_template.py @@ -3,15 +3,11 @@ import unittest from pathlib import Path from unittest.mock import AsyncMock, MagicMock, mock_open, patch -from fastdeploy.engine.request import Request from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.chat_utils import load_chat_template from fastdeploy.entrypoints.llm import LLM from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat -from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor -from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor -from fastdeploy.input.text_processor import DataProcessor class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase): @@ -108,91 +104,6 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase): chat_completion = await self.chat_completion_handler.create_chat_completion(request) self.assertEqual("hello", chat_completion["chat_template"]) - @patch("fastdeploy.input.ernie4_5_vl_processor.Ernie4_5_VLProcessor.__init__") - def test_ernie4_5_vl_processor(self, mock_class): - mock_class.return_value = None - ernie4_5_vl_processor = Ernie4_5_VLProcessor() - mock_request = Request.from_dict({"request_id": "123"}) - - def mock_apply_default_parameters(request): - return request - - def mock_process_request(request, max_model_len): - return request - - ernie4_5_vl_processor._apply_default_parameters = mock_apply_default_parameters - ernie4_5_vl_processor.process_request_dict = mock_process_request - result = ernie4_5_vl_processor.process_request(mock_request, chat_template="hello") - self.assertEqual("hello", result.chat_template) - - @patch("fastdeploy.input.text_processor.DataProcessor.__init__") - def test_text_processor_process_request(self, mock_class): - mock_class.return_value = None - text_processor = DataProcessor() - mock_request = Request.from_dict( - {"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1} - ) - - def mock_apply_default_parameters(request): - return request - - def mock_process_request(request, max_model_len): - return request - - def mock_text2ids(text, max_model_len): - return [1] - - text_processor._apply_default_parameters = mock_apply_default_parameters - text_processor.process_request_dict = mock_process_request - text_processor.text2ids = mock_text2ids - text_processor.eos_token_ids = [1] - result = text_processor.process_request(mock_request, chat_template="hello") - self.assertEqual("hello", result.chat_template) - - @patch("fastdeploy.input.ernie4_5_processor.Ernie4_5Processor.__init__") - def test_ernie4_5_processor_process(self, mock_class): - mock_class.return_value = None - ernie4_5_processor = Ernie4_5Processor() - mock_request = Request.from_dict( - {"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1} - ) - - def mock_apply_default_parameters(request): - return request - - def mock_process_request(request, max_model_len): - return request - - def mock_messages2ids(text): - return [1] - - ernie4_5_processor._apply_default_parameters = mock_apply_default_parameters - ernie4_5_processor.process_request_dict = mock_process_request - ernie4_5_processor.messages2ids = mock_messages2ids - ernie4_5_processor.eos_token_ids = [1] - ernie4_5_processor.reasoning_parser = MagicMock() - result = ernie4_5_processor.process_request(mock_request, chat_template="hello") - self.assertEqual("hello", result.chat_template) - - @patch("fastdeploy.entrypoints.llm.LLM.__init__") - def test_llm_load(self, mock_class): - mock_class.return_value = None - llm = LLM() - llm.llm_engine = MagicMock() - llm.default_sampling_params = MagicMock() - llm.chat_template = "hello" - - def mock_run_engine(req_ids, **kwargs): - return req_ids - - def mock_add_request(**kwargs): - return kwargs.get("chat_template") - - llm._run_engine = mock_run_engine - llm._add_request = mock_add_request - result = llm.chat(["hello"], sampling_params=SamplingParams(1)) - self.assertEqual("hello", result) - @patch("fastdeploy.entrypoints.llm.LLM.__init__") def test_llm(self, mock_class): mock_class.return_value = None