[Feature] add custom chat template (#3251)

* add custom chat_template

* add custom chat_template

* add unittest

* fix

* add docs

* fix comment

* add offline chat

* fix unit test

* fix unit test

* fix

* fix pre-commit

* fix unit test

* add unit test

* add unit test

* add unit test

* fix pre-commit

* fix enable_thinking

* fix pre-commit

* fix pre-commit

* fix unit test

* add requirements
This commit is contained in:
luukunn
2025-08-18 16:34:08 +08:00
committed by GitHub
parent 70ee910cd5
commit 9c129813f9
19 changed files with 288 additions and 3 deletions

View File

@@ -28,6 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.plugins.model_register import load_model_register_plugins
from fastdeploy.utils import (
@@ -74,6 +75,7 @@ class LLM:
revision: Optional[str] = "master",
tokenizer: Optional[str] = None,
enable_logprob: Optional[bool] = False,
chat_template: Optional[str] = None,
**kwargs,
):
deprecated_kwargs_warning(**kwargs)
@@ -102,6 +104,7 @@ class LLM:
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
self._receive_output_thread.start()
self.chat_template = load_chat_template(chat_template)
def _check_master(self):
"""
@@ -196,6 +199,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None,
chat_template: Optional[str] = None,
):
"""
Args:
@@ -229,6 +233,9 @@ class LLM:
if sampling_params_len != 1 and len(messages) != sampling_params_len:
raise ValueError("messages and sampling_params must be the same length.")
if chat_template is None:
chat_template = self.chat_template
messages_len = len(messages)
for i in range(messages_len):
messages[i] = {"messages": messages[i]}
@@ -236,6 +243,7 @@ class LLM:
prompts=messages,
sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs,
chat_template=chat_template,
)
topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs