diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md
index 761e79720..6cdf1be92 100644
--- a/docs/online_serving/README.md
+++ b/docs/online_serving/README.md
@@ -161,6 +161,9 @@ The following extra parameters are supported:
 chat_template_kwargs: Optional[dict] = None
 # Additional parameters passed to the chat template, used for customizing dialogue formats (default None).
+chat_template: Optional[str] = None
+# Custom chat template; if provided, it overrides the model's default chat template (default None).
+
 reasoning_max_tokens: Optional[int] = None
 # Maximum number of tokens to generate during reasoning (e.g., CoT, chain of thought) (default None means using global max_tokens).
diff --git a/docs/parameters.md b/docs/parameters.md
index f302fbe42..81f582335 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -46,6 +46,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to enable return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message.If logrpob is not used, this parameter can be omitted when starting |
+| ```chat_template``` | `str` | Specifies the chat template used to build the model prompt. It accepts either a template string or a file path. The default value is None; if not specified, the model's default template is used. |
 
 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md
index a68eedbdb..d55daffc3 100644
--- a/docs/zh/online_serving/README.md
+++ b/docs/zh/online_serving/README.md
@@ -160,6 +160,9 @@ repetition_penalty: Optional[float] = None
 chat_template_kwargs: Optional[dict] = None
 # 传递给聊天模板(chat template)的额外参数,用于自定义对话格式(默认 None)。
+chat_template: Optional[str] = None
+# 自定义聊天模板,会覆盖模型默认的聊天模板(默认 None)。
+
 reasoning_max_tokens: Optional[int] = None
 # 推理(如 CoT, 思维链)过程中生成的最大 token 数(默认 None 表示使用全局 max_tokens)。
diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md
index 09ba05d60..e68d342f3 100644
--- a/docs/zh/parameters.md
+++ b/docs/zh/parameters.md
@@ -44,6 +44,7 @@
 | ```dynamic_load_weight``` | `int` | 是否动态加载权重,默认0 |
 | ```enable_expert_parallel``` | `bool` | 是否启用专家并行 |
 | ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob,则在启动时可以省略此参数。 |
+| ```chat_template``` | `str` | 指定模型拼接使用的模板,支持字符串与文件路径,默认为None,如未指定,则使用模型默认模板 |
 
 ## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
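
Taken together, the documentation above describes a server-level `--chat-template` flag and a per-request `chat_template` field. The snippet below is an illustrative sketch, not part of this patch: it assumes a FastDeploy OpenAI-compatible server is already running (host, port, model name, and template text are placeholders) and uses the standard `openai` Python client, passing the per-request override through `extra_body`, which takes precedence over the server-level template for that request only.

```python
# Minimal sketch (not part of this patch). Assumes a server was launched
# roughly like:
#   python -m fastdeploy.entrypoints.openai.api_server \
#       --model <your-model> --port 8000 --chat-template ./my_template.jinja
# Host, port, model name, and the template string below are placeholders.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

# A per-request chat_template sent via extra_body overrides the
# server-level --chat-template for this request only.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"chat_template": "{% for m in messages %}{{ m['content'] }}\n{% endfor %}"},
)
print(response.choices[0].message.content)
```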
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 15e5e3cb7..af7b3ffb0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -94,6 +94,10 @@ class EngineArgs: """ specifies the reasoning parser to use for extracting reasoning content from the model output """ + chat_template: str = None + """ + chat template or chat template file path + """ tool_call_parser: str = None """ specifies the tool call parser to use for extracting tool call from the model output @@ -442,6 +446,12 @@ class EngineArgs: help="Flag specifies the reasoning parser to use for extracting " "reasoning content from the model output", ) + model_group.add_argument( + "--chat-template", + type=str, + default=EngineArgs.chat_template, + help="chat template or chat template file path", + ) model_group.add_argument( "--tool-call-parser", type=str, diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index b9fa895e6..67c0caa08 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -72,6 +72,7 @@ class Request: guided_json_object: Optional[bool] = None, enable_thinking: Optional[bool] = True, trace_carrier: dict = dict(), + chat_template: Optional[str] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -111,6 +112,8 @@ class Request: self.enable_thinking = enable_thinking self.trace_carrier = trace_carrier + self.chat_template = chat_template + # token num self.block_tables = [] self.output_token_ids = [] @@ -152,6 +155,7 @@ class Request: guided_json_object=d.get("guided_json_object", None), enable_thinking=d.get("enable_thinking", True), trace_carrier=d.get("trace_carrier", {}), + chat_template=d.get("chat_template", None), ) @property @@ -191,6 +195,7 @@ class Request: "draft_token_ids": self.draft_token_ids, "enable_thinking": self.enable_thinking, "trace_carrier": self.trace_carrier, + "chat_template": self.chat_template, } add_params = [ "guided_json", diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py index fbe69c413..c90df1529 100644 --- a/fastdeploy/entrypoints/chat_utils.py +++ b/fastdeploy/entrypoints/chat_utils.py @@ -16,7 +16,8 @@ import uuid from copy import deepcopy -from typing import List, Literal, Union +from pathlib import Path +from typing import List, Literal, Optional, Union from urllib.parse import urlparse import requests @@ -159,5 +160,37 @@ def parse_chat_messages(messages): return conversation +def load_chat_template( + chat_template: Union[Path, str], + is_literal: bool = False, +) -> Optional[str]: + if chat_template is None: + return None + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " "from its value") + + return chat_template + + try: + with open(chat_template) as f: + return f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. 
Reason: {e}"
+            )
+            raise ValueError(msg) from e
+
+        # If opening the file fails, fall back to treating chat_template as a
+        # literal template string so that its escape sequences are interpreted correctly.
+        return load_chat_template(chat_template, is_literal=True)
+
+
 def random_tool_call_id() -> str:
     return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 001cfad3e..5bfe46495 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -28,6 +28,7 @@ from tqdm import tqdm
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.plugins.model_register import load_model_register_plugins
 from fastdeploy.utils import (
@@ -74,6 +75,7 @@ class LLM:
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         enable_logprob: Optional[bool] = False,
+        chat_template: Optional[str] = None,
         **kwargs,
     ):
         deprecated_kwargs_warning(**kwargs)
@@ -102,6 +104,7 @@
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
+        self.chat_template = load_chat_template(chat_template)
     def _check_master(self):
         """
@@ -196,6 +199,7 @@
         sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
+        chat_template: Optional[str] = None,
     ):
         """
         Args:
@@ -229,6 +233,9 @@
         if sampling_params_len != 1 and len(messages) != sampling_params_len:
             raise ValueError("messages and sampling_params must be the same length.")
+        if chat_template is None:
+            chat_template = self.chat_template
+
         messages_len = len(messages)
         for i in range(messages_len):
             messages[i] = {"messages": messages[i]}
@@ -236,6 +243,7 @@
             prompts=messages,
             sampling_params=sampling_params,
             chat_template_kwargs=chat_template_kwargs,
+            chat_template=chat_template,
         )
         topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 3f92ced34..ee5f0d62a 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -30,6 +30,7 @@ from prometheus_client import CONTENT_TYPE_LATEST
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -77,6 +78,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
+chat_template = load_chat_template(args.chat_template)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None
@@ -141,7 +143,7 @@ async def lifespan(app: FastAPI):
         args.tool_call_parser,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    chat_handler = 
OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time, chat_template) completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time) engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) engine_client.pid = pid diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 2049fb971..508c27f06 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -524,6 +524,7 @@ class ChatCompletionRequest(BaseModel): # doc: start-completion-extra-params chat_template_kwargs: Optional[dict] = None + chat_template: Optional[str] = None reasoning_max_tokens: Optional[int] = None structural_tag: Optional[str] = None guided_json: Optional[Union[str, dict, BaseModel]] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index b14f28e62..d52433c0d 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -49,12 +49,13 @@ class OpenAIServingChat: OpenAI-style chat completions serving """ - def __init__(self, engine_client, pid, ips, max_waiting_time): + def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template): self.engine_client = engine_client self.pid = pid self.master_ip = ips self.max_waiting_time = max_waiting_time self.host_ip = get_host_ip() + self.chat_template = chat_template if self.master_ip is not None: if isinstance(self.master_ip, list): self.master_ip = self.master_ip[0] @@ -86,6 +87,8 @@ class OpenAIServingChat: text_after_process = None try: current_req_dict = request.to_dict_for_infer(request_id) + if "chat_template" not in current_req_dict: + current_req_dict["chat_template"] = self.chat_template current_req_dict["arrival_time"] = time.time() prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict) text_after_process = current_req_dict.get("text_after_process") diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index e4424a0b8..3401803c4 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -87,6 +87,7 @@ class ErnieProcessor(BaseDataProcessor): bool: Whether preprocessing is successful str: error message """ + request.chat_template = kwargs.get("chat_template") request = self._apply_default_parameters(request) if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids @@ -342,6 +343,7 @@ class ErnieProcessor(BaseDataProcessor): tokenize=False, split_special_tokens=False, add_special_tokens=False, + chat_template=request_or_messages.get("chat_template", None), ) request_or_messages["text_after_process"] = spliced_message req_id = None diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py index e8239f7ad..adbf990e8 100644 --- a/fastdeploy/input/ernie_vl_processor.py +++ b/fastdeploy/input/ernie_vl_processor.py @@ -109,6 +109,7 @@ class ErnieMoEVLProcessor(ErnieProcessor): def process_request(self, request, max_model_len=None, **kwargs): """process the input data""" + request.chat_template = kwargs.get("chat_template") task = request.to_dict() task["enable_thinking"] = kwargs.get("enable_thinking", True) self.process_request_dict(task, max_model_len) diff --git a/fastdeploy/input/mm_processor/process.py b/fastdeploy/input/mm_processor/process.py index 65fad4dbd..9df979cc0 100644 --- a/fastdeploy/input/mm_processor/process.py 
+++ b/fastdeploy/input/mm_processor/process.py @@ -494,10 +494,12 @@ class DataProcessor: """ if self.tokenizer.chat_template is None: raise ValueError("This model does not support chat_template.") + prompt_token_template = self.tokenizer.apply_chat_template( request, tokenize=False, add_generation_prompt=request.get("add_generation_prompt", True), + chat_template=request.get("chat_template", None), ) prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace( "<|video@placeholder|>", "" diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index e842e964b..225fe4fbb 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -204,6 +204,7 @@ class DataProcessor(BaseDataProcessor): bool: Whether preprocessing is successful str: error message """ + request.chat_template = kwargs.get("chat_template") request = self._apply_default_parameters(request) if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids @@ -486,6 +487,7 @@ class DataProcessor(BaseDataProcessor): split_special_tokens=False, add_special_tokens=False, return_tensors="pd", + chat_template=request.get("chat_template", None), ) request["text_after_process"] = spliced_message req_id = None diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 24098bc98..79bac3a62 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -35,3 +35,4 @@ opentelemetry-instrumentation-mysql opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi +partial_json_parser diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index 46bf217bb..d481e3feb 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -36,3 +36,4 @@ opentelemetry-instrumentation-mysql opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi +partial_json_parser diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index 305f9825f..7aa310fa2 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi +partial_json_parser diff --git a/test/utils/test_custom_chat_template.py b/test/utils/test_custom_chat_template.py new file mode 100644 index 000000000..27a66c4e9 --- /dev/null +++ b/test/utils/test_custom_chat_template.py @@ -0,0 +1,205 @@ +import os +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, mock_open, patch + +from fastdeploy.engine.request import Request +from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.entrypoints.chat_utils import load_chat_template +from fastdeploy.entrypoints.llm import LLM +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest +from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat +from fastdeploy.input.ernie_processor import ErnieProcessor +from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor +from fastdeploy.input.text_processor import DataProcessor + + +class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + """ + Set up the test environment by creating an instance of the LLM class using Mock. 
+ """ + self.input_chat_template = "unit test \n" + self.mock_engine = MagicMock() + self.tokenizer = MagicMock() + + def test_load_chat_template_non(self): + result = load_chat_template(None) + self.assertEqual(None, result) + + def test_load_chat_template_str(self): + result = load_chat_template(self.input_chat_template) + self.assertEqual(self.input_chat_template, result) + + def test_load_chat_template_path(self): + with open("chat_template", "w", encoding="utf-8") as file: + file.write(self.input_chat_template) + file_path = os.path.join(os.getcwd(), "chat_template") + result = load_chat_template(file_path) + os.remove(file_path) + self.assertEqual(self.input_chat_template, result) + + def test_load_chat_template_non_str_and_path(self): + with self.assertRaises(ValueError): + load_chat_template("unit test") + + def test_path_with_literal_true(self): + with self.assertRaises(TypeError): + load_chat_template(Path("./chat_template"), is_literal=True) + + def test_path_object_file_error(self): + with patch("builtins.open", mock_open()) as mock_file: + mock_file.side_effect = OSError("File error") + with self.assertRaises(OSError): + load_chat_template(Path("./chat_template")) + + async def test_serving_chat(self): + request = ChatCompletionRequest(messages=[{"role": "user", "content": "你好"}]) + self.chat_completion_handler = OpenAIServingChat( + self.mock_engine, pid=123, ips=None, max_waiting_time=-1, chat_template=self.input_chat_template + ) + + async def mock_chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, text_after_process + ): + return prompt_token_ids + + def mock_format_and_add_data(current_req_dict): + return current_req_dict + + self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator + self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data + self.chat_completion_handler.engine_client.semaphore = AsyncMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None) + self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status") + chat_completiom = await self.chat_completion_handler.create_chat_completion(request) + self.assertEqual(self.input_chat_template, chat_completiom["chat_template"]) + + async def test_serving_chat_cus(self): + request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], chat_template="hello") + self.chat_completion_handler = OpenAIServingChat( + self.mock_engine, pid=123, ips=None, max_waiting_time=10, chat_template=self.input_chat_template + ) + + async def mock_chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, text_after_process + ): + return prompt_token_ids + + def mock_format_and_add_data(current_req_dict): + return current_req_dict + + self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator + self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data + self.chat_completion_handler.engine_client.semaphore = AsyncMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None) + self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status") + chat_completion = await self.chat_completion_handler.create_chat_completion(request) + self.assertEqual("hello", chat_completion["chat_template"]) + + @patch("fastdeploy.input.ernie_vl_processor.ErnieMoEVLProcessor.__init__") + def 
test_vl_processor(self, mock_class): + mock_class.return_value = None + vl_processor = ErnieMoEVLProcessor() + mock_request = Request.from_dict({"request_id": "123"}) + + def mock_apply_default_parameters(request): + return request + + def mock_process_request(request, max_model_len): + return request + + vl_processor._apply_default_parameters = mock_apply_default_parameters + vl_processor.process_request_dict = mock_process_request + result = vl_processor.process_request(mock_request, chat_template="hello") + self.assertEqual("hello", result.chat_template) + + @patch("fastdeploy.input.text_processor.DataProcessor.__init__") + def test_text_processor_process_request(self, mock_class): + mock_class.return_value = None + text_processor = DataProcessor() + mock_request = Request.from_dict( + {"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1} + ) + + def mock_apply_default_parameters(request): + return request + + def mock_process_request(request, max_model_len): + return request + + def mock_text2ids(text, max_model_len): + return [1] + + text_processor._apply_default_parameters = mock_apply_default_parameters + text_processor.process_request_dict = mock_process_request + text_processor.text2ids = mock_text2ids + text_processor.eos_token_ids = [1] + result = text_processor.process_request(mock_request, chat_template="hello") + self.assertEqual("hello", result.chat_template) + + @patch("fastdeploy.input.ernie_processor.ErnieProcessor.__init__") + def test_ernie_processor_process(self, mock_class): + mock_class.return_value = None + ernie_processor = ErnieProcessor() + mock_request = Request.from_dict( + {"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1} + ) + + def mock_apply_default_parameters(request): + return request + + def mock_process_request(request, max_model_len): + return request + + def mock_messages2ids(text): + return [1] + + ernie_processor._apply_default_parameters = mock_apply_default_parameters + ernie_processor.process_request_dict = mock_process_request + ernie_processor.messages2ids = mock_messages2ids + ernie_processor.eos_token_ids = [1] + result = ernie_processor.process_request(mock_request, chat_template="hello") + self.assertEqual("hello", result.chat_template) + + @patch("fastdeploy.entrypoints.llm.LLM.__init__") + def test_llm_load(self, mock_class): + mock_class.return_value = None + llm = LLM() + llm.llm_engine = MagicMock() + llm.default_sampling_params = MagicMock() + llm.chat_template = "hello" + + def mock_run_engine(req_ids, **kwargs): + return req_ids + + def mock_add_request(**kwargs): + return kwargs.get("chat_template") + + llm._run_engine = mock_run_engine + llm._add_request = mock_add_request + result = llm.chat(["hello"], sampling_params=SamplingParams(1)) + self.assertEqual("hello", result) + + @patch("fastdeploy.entrypoints.llm.LLM.__init__") + def test_llm(self, mock_class): + mock_class.return_value = None + llm = LLM() + llm.llm_engine = MagicMock() + llm.default_sampling_params = MagicMock() + + def mock_run_engine(req_ids, **kwargs): + return req_ids + + def mock_add_request(**kwargs): + return kwargs.get("chat_template") + + llm._run_engine = mock_run_engine + llm._add_request = mock_add_request + result = llm.chat(["hello"], sampling_params=SamplingParams(1), chat_template="hello") + self.assertEqual("hello", result) + + +if __name__ == "__main__": + unittest.main()
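
For the offline path covered by `test_llm` and `test_llm_load` above, the flow can be sketched as follows. This is an illustrative sketch only, not code from the patch: the model path and the template file are placeholders, and it simply mirrors how `load_chat_template`, the `LLM` constructor, and `LLM.chat` are wired together in this change.

```python
# Offline usage sketch mirroring the LLM tests above (illustrative only;
# the model path and template file are placeholders).
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.llm import LLM

# load_chat_template accepts either a literal Jinja template string or a
# path to a template file; None falls back to the model's built-in template.
template = load_chat_template("./my_template.jinja")

llm = LLM(model="/path/to/your/model", chat_template=template)

# A chat_template passed per call overrides the one given at construction time.
outputs = llm.chat(
    [[{"role": "user", "content": "Hello"}]],
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
)
```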