mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] add custom chat template (#3251)
* add custom chat_template
* add custom chat_template
* add unittest
* fix
* add docs
* fix comment
* add offline chat
* fix unit test
* fix unit test
* fix
* fix pre commit
* fix unit test
* add unit test
* add unit test
* add unit test
* fix pre_commit
* fix enable_thinking
* fix pre commit
* fix pre commit
* fix unit test
* add requirements
This commit is contained in:
@@ -28,6 +28,7 @@ from tqdm import tqdm
|
||||
from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.engine import LLMEngine
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from fastdeploy.plugins.model_register import load_model_register_plugins
|
||||
from fastdeploy.utils import (
|
||||
@@ -74,6 +75,7 @@ class LLM:
|
||||
revision: Optional[str] = "master",
|
||||
tokenizer: Optional[str] = None,
|
||||
enable_logprob: Optional[bool] = False,
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecated_kwargs_warning(**kwargs)
|
||||
@@ -102,6 +104,7 @@ class LLM:
|
||||
self.master_node_ip = self.llm_engine.cfg.master_ip
|
||||
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
|
||||
self._receive_output_thread.start()
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
|
||||
def _check_master(self):
|
||||
"""
|
||||
@@ -196,6 +199,7 @@ class LLM:
|
||||
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -229,6 +233,9 @@ class LLM:
|
||||
if sampling_params_len != 1 and len(messages) != sampling_params_len:
|
||||
raise ValueError("messages and sampling_params must be the same length.")
|
||||
|
||||
if chat_template is None:
|
||||
chat_template = self.chat_template
|
||||
|
||||
messages_len = len(messages)
|
||||
for i in range(messages_len):
|
||||
messages[i] = {"messages": messages[i]}
|
||||
@@ -236,6 +243,7 @@ class LLM:
|
||||
prompts=messages,
|
||||
sampling_params=sampling_params,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
|
||||
topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
|
||||
|
Reference in New Issue
Block a user