[Feature] add custom chat template (#3251)
* add custom chat_template
* add custom chat_template
* add unittest
* fix
* add docs
* fix comment
* add offline chat
* fix unit test
* fix unit test
* fix
* fix pre commit
* fix unit test
* add unit test
* add unit test
* add unit test
* fix pre_commit
* fix enable_thinking
* fix pre commit
* fix pre commit
* fix unit test
* add requirements
@@ -161,6 +161,9 @@ The following extra parameters are supported:
 chat_template_kwargs: Optional[dict] = None
 # Additional parameters passed to the chat template, used for customizing dialogue formats (default None).
 
+chat_template: Optional[str] = None
+# Custom chat template will override the model's default chat template (default None).
+
 reasoning_max_tokens: Optional[int] = None
 # Maximum number of tokens to generate during reasoning (e.g., CoT, chain of thought) (default None means using global max_tokens).
 
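For context, a minimal client-side sketch of how these request-level extras combine; the endpoint, port, model name, and template string are assumptions, not part of this commit:

```python
# Hedged sketch: assumes a FastDeploy OpenAI-compatible server on localhost:8180
# serving a model registered as "default"; adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8180/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={
        # Overrides the model's default chat template for this request only.
        "chat_template": "{% for m in messages %}{{ m['content'] }}\n{% endfor %}",
        # Extra keyword arguments forwarded to the chat template.
        "chat_template_kwargs": {"enable_thinking": False},
        "reasoning_max_tokens": 256,
    },
)
print(response.choices[0].message.content)
```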
@@ -46,6 +46,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probability of each output token is returned in the message content. If logprob is not used, this parameter can be omitted at startup. |
+| ```chat_template``` | `str` | Specifies the template used for model prompt concatenation. Accepts either a template string or a file path; default: None. If not specified, the model's default template is used. |
 
 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
 
@@ -160,6 +160,9 @@ repetition_penalty: Optional[float] = None
 chat_template_kwargs: Optional[dict] = None
 # Additional parameters passed to the chat template, used for customizing dialogue formats (default None).
 
+chat_template: Optional[str] = None
+# Custom chat template; overrides the model's default chat template (default None).
+
 reasoning_max_tokens: Optional[int] = None
 # Maximum number of tokens generated during reasoning (e.g., CoT, chain of thought) (default None means the global max_tokens is used).
 
@@ -44,6 +44,7 @@
 | ```dynamic_load_weight``` | `int` | Whether to load weights dynamically, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to return logprobs for output tokens. If logprob is not used, this parameter can be omitted at startup. |
+| ```chat_template``` | `str` | Specifies the template used for model prompt concatenation. Accepts either a string or a file path; default: None. If not specified, the model's default template is used. |
 
 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
 
@@ -94,6 +94,10 @@ class EngineArgs:
     """
     specifies the reasoning parser to use for extracting reasoning content from the model output
     """
+    chat_template: str = None
+    """
+    chat template or chat template file path
+    """
     tool_call_parser: str = None
     """
     specifies the tool call parser to use for extracting tool call from the model output
@@ -442,6 +446,12 @@ class EngineArgs:
             help="Flag specifies the reasoning parser to use for extracting "
             "reasoning content from the model output",
         )
+        model_group.add_argument(
+            "--chat-template",
+            type=str,
+            default=EngineArgs.chat_template,
+            help="chat template or chat template file path",
+        )
         model_group.add_argument(
             "--tool-call-parser",
             type=str,
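A small sketch of how the new flag is consumed through the parser that `EngineArgs.add_cli_args` builds; the template path is a placeholder and other engine flags are omitted:

```python
# Sketch only: parses the --chat-template flag registered above.
import argparse

from fastdeploy.engine.args_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)

# Accepts either a Jinja template file path or a literal template string.
args = parser.parse_args(["--chat-template", "./my_chat_template.jinja"])
print(args.chat_template)  # ./my_chat_template.jinja
```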
@@ -72,6 +72,7 @@ class Request:
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
+        chat_template: Optional[str] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -111,6 +112,8 @@ class Request:
         self.enable_thinking = enable_thinking
         self.trace_carrier = trace_carrier
 
+        self.chat_template = chat_template
+
         # token num
         self.block_tables = []
         self.output_token_ids = []
@@ -152,6 +155,7 @@ class Request:
             guided_json_object=d.get("guided_json_object", None),
             enable_thinking=d.get("enable_thinking", True),
             trace_carrier=d.get("trace_carrier", {}),
+            chat_template=d.get("chat_template", None),
         )
 
     @property
@@ -191,6 +195,7 @@ class Request:
             "draft_token_ids": self.draft_token_ids,
             "enable_thinking": self.enable_thinking,
             "trace_carrier": self.trace_carrier,
+            "chat_template": self.chat_template,
        }
        add_params = [
            "guided_json",
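The field simply travels with the request dict; a tiny round-trip sketch, mirroring the unit tests added at the end of this commit, with illustrative values:

```python
# Illustrative only; from_dict fills unspecified optional fields with defaults.
from fastdeploy.engine.request import Request

req = Request.from_dict(
    {
        "request_id": "demo-1",
        "prompt": "hi",
        "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
    }
)
assert req.to_dict()["chat_template"] == req.chat_template
```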
@@ -16,7 +16,8 @@
 
 import uuid
 from copy import deepcopy
-from typing import List, Literal, Union
+from pathlib import Path
+from typing import List, Literal, Optional, Union
 from urllib.parse import urlparse
 
 import requests
@@ -159,5 +160,37 @@ def parse_chat_messages(messages):
     return conversation
 
 
+def load_chat_template(
+    chat_template: Union[Path, str],
+    is_literal: bool = False,
+) -> Optional[str]:
+    if chat_template is None:
+        return None
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError("chat_template is expected to be read directly " "from its value")
+
+        return chat_template
+
+    try:
+        with open(chat_template) as f:
+            return f.read()
+    except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (
+                f"The supplied chat template ({chat_template}) "
+                f"looks like a file path, but it failed to be "
+                f"opened. Reason: {e}"
+            )
+            raise ValueError(msg) from e
+
+        # If opening a file fails, set chat template to be args to
+        # ensure we decode so our escape are interpreted correctly
+        return load_chat_template(chat_template, is_literal=True)
+
+
 def random_tool_call_id() -> str:
     return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
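For reference, a short sketch of how the helper above behaves; the file name and template string are illustrative:

```python
# Illustrative only; exercises the load_chat_template helper added above.
from pathlib import Path

from fastdeploy.entrypoints.chat_utils import load_chat_template

# A literal Jinja template string is returned unchanged: opening it as a file
# fails, but it contains "{" / "}" so it is not mistaken for a path.
template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
assert load_chat_template(template) == template

# A file path (str or Path) is read and its contents returned.
Path("demo_template.jinja").write_text(template, encoding="utf-8")
assert load_chat_template("demo_template.jinja") == template
```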
@@ -28,6 +28,7 @@ from tqdm import tqdm
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.plugins.model_register import load_model_register_plugins
 from fastdeploy.utils import (
@@ -74,6 +75,7 @@ class LLM:
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         enable_logprob: Optional[bool] = False,
+        chat_template: Optional[str] = None,
         **kwargs,
     ):
         deprecated_kwargs_warning(**kwargs)
@@ -102,6 +104,7 @@ class LLM:
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
+        self.chat_template = load_chat_template(chat_template)
 
     def _check_master(self):
         """
@@ -196,6 +199,7 @@ class LLM:
         sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
+        chat_template: Optional[str] = None,
     ):
         """
         Args:
@@ -229,6 +233,9 @@ class LLM:
         if sampling_params_len != 1 and len(messages) != sampling_params_len:
             raise ValueError("messages and sampling_params must be the same length.")
 
+        if chat_template is None:
+            chat_template = self.chat_template
+
         messages_len = len(messages)
         for i in range(messages_len):
             messages[i] = {"messages": messages[i]}
@@ -236,6 +243,7 @@ class LLM:
             prompts=messages,
             sampling_params=sampling_params,
             chat_template_kwargs=chat_template_kwargs,
+            chat_template=chat_template,
         )
 
         topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
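Putting the offline pieces together, a hedged usage sketch; the model path, template file, message layout, and sampling settings are placeholders:

```python
# Offline-inference sketch; a chat template may be set once on the LLM
# constructor and overridden per llm.chat() call.
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="./ERNIE-4.5-0.3B-Paddle", chat_template="./my_chat_template.jinja")

outputs = llm.chat(
    messages=[[{"role": "user", "content": "Hello"}]],
    sampling_params=SamplingParams(temperature=0.8),
    # A per-call template overrides the one passed to the constructor.
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
)
```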
@@ -30,6 +30,7 @@ from prometheus_client import CONTENT_TYPE_LATEST
 
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -77,6 +78,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
+chat_template = load_chat_template(args.chat_template)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None
@@ -141,7 +143,7 @@ async def lifespan(app: FastAPI):
         args.tool_call_parser,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time, chat_template)
     completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
@@ -524,6 +524,7 @@ class ChatCompletionRequest(BaseModel):
 
     # doc: start-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
+    chat_template: Optional[str] = None
     reasoning_max_tokens: Optional[int] = None
     structural_tag: Optional[str] = None
     guided_json: Optional[Union[str, dict, BaseModel]] = None
@@ -49,12 +49,13 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """
 
-    def __init__(self, engine_client, pid, ips, max_waiting_time):
+    def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
+        self.chat_template = chat_template
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -86,6 +87,8 @@ class OpenAIServingChat:
         text_after_process = None
         try:
             current_req_dict = request.to_dict_for_infer(request_id)
+            if "chat_template" not in current_req_dict:
+                current_req_dict["chat_template"] = self.chat_template
             current_req_dict["arrival_time"] = time.time()
             prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
             text_after_process = current_req_dict.get("text_after_process")
@@ -87,6 +87,7 @@ class ErnieProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -342,6 +343,7 @@ class ErnieProcessor(BaseDataProcessor):
             tokenize=False,
             split_special_tokens=False,
             add_special_tokens=False,
+            chat_template=request_or_messages.get("chat_template", None),
         )
         request_or_messages["text_after_process"] = spliced_message
         req_id = None
@@ -109,6 +109,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
 
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
+        request.chat_template = kwargs.get("chat_template")
         task = request.to_dict()
         task["enable_thinking"] = kwargs.get("enable_thinking", True)
         self.process_request_dict(task, max_model_len)
@@ -494,10 +494,12 @@ class DataProcessor:
         """
         if self.tokenizer.chat_template is None:
             raise ValueError("This model does not support chat_template.")
+
         prompt_token_template = self.tokenizer.apply_chat_template(
             request,
             tokenize=False,
             add_generation_prompt=request.get("add_generation_prompt", True),
+            chat_template=request.get("chat_template", None),
         )
         prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
             "<|video@placeholder|>", ""
@@ -204,6 +204,7 @@ class DataProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -486,6 +487,7 @@ class DataProcessor(BaseDataProcessor):
             split_special_tokens=False,
             add_special_tokens=False,
             return_tensors="pd",
+            chat_template=request.get("chat_template", None),
         )
         request["text_after_process"] = spliced_message
         req_id = None
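For context on the override semantics the processor hunks rely on: when `apply_chat_template` is given an explicit `chat_template` argument, it renders with that template instead of the tokenizer's built-in one. A hedged sketch, using transformers' `AutoTokenizer` purely for illustration (the model name and template are placeholders, not what FastDeploy loads internally):

```python
# Hedged sketch of per-request template override behavior.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
custom_template = "{% for m in messages %}<{{ m['role'] }}> {{ m['content'] }}\n{% endfor %}"

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    chat_template=custom_template,  # the supplied template wins over the default
)
print(text)  # -> "<user> Hello" followed by a newline
```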
@@ -35,3 +35,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser
@@ -36,3 +36,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser
@@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser
new file: test/utils/test_custom_chat_template.py (205 lines)
@@ -0,0 +1,205 @@
+import os
+import unittest
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, mock_open, patch
+
+from fastdeploy.engine.request import Request
+from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.chat_utils import load_chat_template
+from fastdeploy.entrypoints.llm import LLM
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
+from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
+from fastdeploy.input.ernie_processor import ErnieProcessor
+from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
+from fastdeploy.input.text_processor import DataProcessor
+
+
+class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
+
+    def setUp(self):
+        """
+        Set up the test environment by creating an instance of the LLM class using Mock.
+        """
+        self.input_chat_template = "unit test \n"
+        self.mock_engine = MagicMock()
+        self.tokenizer = MagicMock()
+
+    def test_load_chat_template_non(self):
+        result = load_chat_template(None)
+        self.assertEqual(None, result)
+
+    def test_load_chat_template_str(self):
+        result = load_chat_template(self.input_chat_template)
+        self.assertEqual(self.input_chat_template, result)
+
+    def test_load_chat_template_path(self):
+        with open("chat_template", "w", encoding="utf-8") as file:
+            file.write(self.input_chat_template)
+        file_path = os.path.join(os.getcwd(), "chat_template")
+        result = load_chat_template(file_path)
+        os.remove(file_path)
+        self.assertEqual(self.input_chat_template, result)
+
+    def test_load_chat_template_non_str_and_path(self):
+        with self.assertRaises(ValueError):
+            load_chat_template("unit test")
+
+    def test_path_with_literal_true(self):
+        with self.assertRaises(TypeError):
+            load_chat_template(Path("./chat_template"), is_literal=True)
+
+    def test_path_object_file_error(self):
+        with patch("builtins.open", mock_open()) as mock_file:
+            mock_file.side_effect = OSError("File error")
+            with self.assertRaises(OSError):
+                load_chat_template(Path("./chat_template"))
+
+    async def test_serving_chat(self):
+        request = ChatCompletionRequest(messages=[{"role": "user", "content": "你好"}])
+        self.chat_completion_handler = OpenAIServingChat(
+            self.mock_engine, pid=123, ips=None, max_waiting_time=-1, chat_template=self.input_chat_template
+        )
+
+        async def mock_chat_completion_full_generator(
+            request, request_id, model_name, prompt_token_ids, text_after_process
+        ):
+            return prompt_token_ids
+
+        def mock_format_and_add_data(current_req_dict):
+            return current_req_dict
+
+        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
+        self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
+        self.chat_completion_handler.engine_client.semaphore = AsyncMock()
+        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
+        self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
+        chat_completiom = await self.chat_completion_handler.create_chat_completion(request)
+        self.assertEqual(self.input_chat_template, chat_completiom["chat_template"])
+
+    async def test_serving_chat_cus(self):
+        request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], chat_template="hello")
+        self.chat_completion_handler = OpenAIServingChat(
+            self.mock_engine, pid=123, ips=None, max_waiting_time=10, chat_template=self.input_chat_template
+        )
+
+        async def mock_chat_completion_full_generator(
+            request, request_id, model_name, prompt_token_ids, text_after_process
+        ):
+            return prompt_token_ids
+
+        def mock_format_and_add_data(current_req_dict):
+            return current_req_dict
+
+        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
+        self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
+        self.chat_completion_handler.engine_client.semaphore = AsyncMock()
+        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
+        self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
+        chat_completion = await self.chat_completion_handler.create_chat_completion(request)
+        self.assertEqual("hello", chat_completion["chat_template"])
+
+    @patch("fastdeploy.input.ernie_vl_processor.ErnieMoEVLProcessor.__init__")
+    def test_vl_processor(self, mock_class):
+        mock_class.return_value = None
+        vl_processor = ErnieMoEVLProcessor()
+        mock_request = Request.from_dict({"request_id": "123"})
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        def mock_process_request(request, max_model_len):
+            return request
+
+        vl_processor._apply_default_parameters = mock_apply_default_parameters
+        vl_processor.process_request_dict = mock_process_request
+        result = vl_processor.process_request(mock_request, chat_template="hello")
+        self.assertEqual("hello", result.chat_template)
+
+    @patch("fastdeploy.input.text_processor.DataProcessor.__init__")
+    def test_text_processor_process_request(self, mock_class):
+        mock_class.return_value = None
+        text_processor = DataProcessor()
+        mock_request = Request.from_dict(
+            {"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1}
+        )
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        def mock_process_request(request, max_model_len):
+            return request
+
+        def mock_text2ids(text, max_model_len):
+            return [1]
+
+        text_processor._apply_default_parameters = mock_apply_default_parameters
+        text_processor.process_request_dict = mock_process_request
+        text_processor.text2ids = mock_text2ids
+        text_processor.eos_token_ids = [1]
+        result = text_processor.process_request(mock_request, chat_template="hello")
+        self.assertEqual("hello", result.chat_template)
+
+    @patch("fastdeploy.input.ernie_processor.ErnieProcessor.__init__")
+    def test_ernie_processor_process(self, mock_class):
+        mock_class.return_value = None
+        ernie_processor = ErnieProcessor()
+        mock_request = Request.from_dict(
+            {"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1}
+        )
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        def mock_process_request(request, max_model_len):
+            return request
+
+        def mock_messages2ids(text):
+            return [1]
+
+        ernie_processor._apply_default_parameters = mock_apply_default_parameters
+        ernie_processor.process_request_dict = mock_process_request
+        ernie_processor.messages2ids = mock_messages2ids
+        ernie_processor.eos_token_ids = [1]
+        result = ernie_processor.process_request(mock_request, chat_template="hello")
+        self.assertEqual("hello", result.chat_template)
+
+    @patch("fastdeploy.entrypoints.llm.LLM.__init__")
+    def test_llm_load(self, mock_class):
+        mock_class.return_value = None
+        llm = LLM()
+        llm.llm_engine = MagicMock()
+        llm.default_sampling_params = MagicMock()
+        llm.chat_template = "hello"
+
+        def mock_run_engine(req_ids, **kwargs):
+            return req_ids
+
+        def mock_add_request(**kwargs):
+            return kwargs.get("chat_template")
+
+        llm._run_engine = mock_run_engine
+        llm._add_request = mock_add_request
+        result = llm.chat(["hello"], sampling_params=SamplingParams(1))
+        self.assertEqual("hello", result)
+
+    @patch("fastdeploy.entrypoints.llm.LLM.__init__")
+    def test_llm(self, mock_class):
+        mock_class.return_value = None
+        llm = LLM()
+        llm.llm_engine = MagicMock()
+        llm.default_sampling_params = MagicMock()
+
+        def mock_run_engine(req_ids, **kwargs):
+            return req_ids
+
+        def mock_add_request(**kwargs):
+            return kwargs.get("chat_template")
+
+        llm._run_engine = mock_run_engine
+        llm._add_request = mock_add_request
+        result = llm.chat(["hello"], sampling_params=SamplingParams(1), chat_template="hello")
+        self.assertEqual("hello", result)
+
+
+if __name__ == "__main__":
+    unittest.main()