[Feature] add custom chat template (#3251)

* add custom chat_template

* add custom chat_template

* add unittest

* fix

* add docs

* fix comment

* add offline chat

* fix unit test

* fix unit test

* fix

* fix pre commit

* fix unit test

* add unit test

* add unit test

* add unit test

* fix pre_commit

* fix enable_thinking

* fix pre commit

* fix pre commit

* fix unit test

* add requirements
This commit is contained in:
luukunn
2025-08-18 16:34:08 +08:00
committed by GitHub
parent 70ee910cd5
commit 9c129813f9
19 changed files with 288 additions and 3 deletions

View File

@@ -161,6 +161,9 @@ The following extra parameters are supported:
 chat_template_kwargs: Optional[dict] = None
 # Additional parameters passed to the chat template, used for customizing dialogue formats (default None).
+chat_template: Optional[str] = None
+# Custom chat template will override the model's default chat template (default None).
 reasoning_max_tokens: Optional[int] = None
 # Maximum number of tokens to generate during reasoning (e.g., CoT, chain of thought) (default None means using global max_tokens).
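For reference, the new extra parameter can be exercised end to end through the OpenAI-compatible endpoint. A minimal sketch, assuming a FastDeploy server is already running; the base URL, API key, and model name below are placeholders and not part of this change:

```python
# Minimal sketch: send a per-request custom chat template through the OpenAI-compatible API.
# Assumptions (not part of this diff): server URL, API key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# A tiny Jinja template that just flattens the conversation.
custom_template = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"

resp = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"chat_template": custom_template},  # the extra parameter documented above
)
print(resp.choices[0].message.content)
```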

View File

@@ -46,6 +46,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probabilities of each output token are returned in the content of message. If logprob is not used, this parameter can be omitted when starting |
+| ```chat_template``` | `str` | Specifies the template used for prompt concatenation. Supports both a template string and a file path. The default value is None; if not specified, the model's default template is used. |
 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -160,6 +160,9 @@ repetition_penalty: Optional[float] = None
 chat_template_kwargs: Optional[dict] = None
 # Additional parameters passed to the chat template, used for customizing dialogue formats (default None).
+chat_template: Optional[str] = None
+# Custom chat template that overrides the model's default chat template (default None).
 reasoning_max_tokens: Optional[int] = None
 # Maximum number of tokens to generate during reasoning (e.g., CoT, chain of thought) (default None means the global max_tokens is used).

View File

@@ -44,6 +44,7 @@
 | ```dynamic_load_weight``` | `int` | Whether to load weights dynamically, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallelism |
 | ```enable_logprob``` | `bool` | Whether to return logprob for output tokens. If logprob is not used, this parameter can be omitted when starting. |
+| ```chat_template``` | `str` | Specifies the template used for prompt concatenation. Supports both a string and a file path. Default is None; if not specified, the model's default template is used. |
 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -94,6 +94,10 @@ class EngineArgs:
""" """
specifies the reasoning parser to use for extracting reasoning content from the model output specifies the reasoning parser to use for extracting reasoning content from the model output
""" """
chat_template: str = None
"""
chat template or chat template file path
"""
tool_call_parser: str = None tool_call_parser: str = None
""" """
specifies the tool call parser to use for extracting tool call from the model output specifies the tool call parser to use for extracting tool call from the model output
@@ -442,6 +446,12 @@ class EngineArgs:
help="Flag specifies the reasoning parser to use for extracting " help="Flag specifies the reasoning parser to use for extracting "
"reasoning content from the model output", "reasoning content from the model output",
) )
model_group.add_argument(
"--chat-template",
type=str,
default=EngineArgs.chat_template,
help="chat template or chat template file path",
)
model_group.add_argument( model_group.add_argument(
"--tool-call-parser", "--tool-call-parser",
type=str, type=str,
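The new flag is registered on the same argument group as the other model options, so it is parsed by the existing `EngineArgs.add_cli_args` flow. A small sketch; the model path and template path are placeholders, and it is assumed that `--model` is also one of the registered engine arguments:

```python
# Sketch: --chat-template is parsed by the existing EngineArgs CLI plumbing.
# Assumptions: --model is also registered by add_cli_args; both paths are placeholders.
import argparse

from fastdeploy.engine.args_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "<model-path>", "--chat-template", "./my_chat_template.jinja"])
print(args.chat_template)  # "./my_chat_template.jinja"
```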

View File

@@ -72,6 +72,7 @@ class Request:
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
+        chat_template: Optional[str] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -111,6 +112,8 @@ class Request:
         self.enable_thinking = enable_thinking
         self.trace_carrier = trace_carrier
+        self.chat_template = chat_template
+
         # token num
         self.block_tables = []
         self.output_token_ids = []
@@ -152,6 +155,7 @@ class Request:
             guided_json_object=d.get("guided_json_object", None),
             enable_thinking=d.get("enable_thinking", True),
             trace_carrier=d.get("trace_carrier", {}),
+            chat_template=d.get("chat_template", None),
         )

     @property
@@ -191,6 +195,7 @@ class Request:
"draft_token_ids": self.draft_token_ids, "draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking, "enable_thinking": self.enable_thinking,
"trace_carrier": self.trace_carrier, "trace_carrier": self.trace_carrier,
"chat_template": self.chat_template,
} }
add_params = [ add_params = [
"guided_json", "guided_json",

View File

@@ -16,7 +16,8 @@
 import uuid
 from copy import deepcopy
-from typing import List, Literal, Union
+from pathlib import Path
+from typing import List, Literal, Optional, Union
 from urllib.parse import urlparse

 import requests
@@ -159,5 +160,37 @@ def parse_chat_messages(messages):
     return conversation


+def load_chat_template(
+    chat_template: Union[Path, str],
+    is_literal: bool = False,
+) -> Optional[str]:
+    if chat_template is None:
+        return None
+
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError("chat_template is expected to be read directly from its value")
+        return chat_template
+
+    try:
+        with open(chat_template) as f:
+            return f.read()
+    except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (
+                f"The supplied chat template ({chat_template}) "
+                f"looks like a file path, but it failed to be "
+                f"opened. Reason: {e}"
+            )
+            raise ValueError(msg) from e
+
+        # If opening the file fails, treat chat_template as a literal template string
+        # so that escape sequences are decoded and interpreted correctly.
+        return load_chat_template(chat_template, is_literal=True)
+
+
 def random_tool_call_id() -> str:
     return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"

View File

@@ -28,6 +28,7 @@ from tqdm import tqdm
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.plugins.model_register import load_model_register_plugins
 from fastdeploy.utils import (
@@ -74,6 +75,7 @@ class LLM:
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         enable_logprob: Optional[bool] = False,
+        chat_template: Optional[str] = None,
         **kwargs,
     ):
         deprecated_kwargs_warning(**kwargs)
@@ -102,6 +104,7 @@ class LLM:
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
+        self.chat_template = load_chat_template(chat_template)

     def _check_master(self):
         """
@@ -196,6 +199,7 @@ class LLM:
         sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
+        chat_template: Optional[str] = None,
     ):
         """
         Args:
@@ -229,6 +233,9 @@ class LLM:
         if sampling_params_len != 1 and len(messages) != sampling_params_len:
             raise ValueError("messages and sampling_params must be the same length.")

+        if chat_template is None:
+            chat_template = self.chat_template
+
         messages_len = len(messages)
         for i in range(messages_len):
             messages[i] = {"messages": messages[i]}
@@ -236,6 +243,7 @@ class LLM:
             prompts=messages,
             sampling_params=sampling_params,
             chat_template_kwargs=chat_template_kwargs,
+            chat_template=chat_template,
         )

         topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
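For offline use, the template can be pinned at construction time or overridden per `chat()` call, with the per-call value taking precedence as shown in the hunk above. A sketch; the model path is a placeholder, `model=` is assumed to be the constructor's model argument, and the other keywords follow the signatures in this diff:

```python
# Sketch of the offline path; "<model-path>" and model= are assumptions, not part of this diff.
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

template = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"

# Instance-level default, loaded once via load_chat_template().
llm = LLM(model="<model-path>", chat_template=template)

# Per-call override takes precedence over the instance-level template.
outputs = llm.chat(
    [[{"role": "user", "content": "Hello"}]],
    sampling_params=SamplingParams(temperature=0.8),
    chat_template="{{ messages[-1]['content'] }}",
)
```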

View File

@@ -30,6 +30,7 @@ from prometheus_client import CONTENT_TYPE_LATEST
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
+from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -77,6 +78,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
+chat_template = load_chat_template(args.chat_template)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None
@@ -141,7 +143,7 @@ async def lifespan(app: FastAPI):
         args.tool_call_parser,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time, chat_template)
     completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid

View File

@@ -524,6 +524,7 @@ class ChatCompletionRequest(BaseModel):
     # doc: start-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
+    chat_template: Optional[str] = None
     reasoning_max_tokens: Optional[int] = None
     structural_tag: Optional[str] = None
     guided_json: Optional[Union[str, dict, BaseModel]] = None
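On the protocol side the new field is a plain optional string, so it can be set directly when building a request object, mirroring what the unit tests below do:

```python
# The request-level field is a plain optional string, as the unit tests below exercise.
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    messages=[{"role": "user", "content": "hi"}],
    chat_template="{{ messages[-1]['content'] }}",
)
print(request.chat_template)
```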

View File

@@ -49,12 +49,13 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips, max_waiting_time):
+    def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
+        self.chat_template = chat_template
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -86,6 +87,8 @@ class OpenAIServingChat:
         text_after_process = None
         try:
             current_req_dict = request.to_dict_for_infer(request_id)
+            if "chat_template" not in current_req_dict:
+                current_req_dict["chat_template"] = self.chat_template
             current_req_dict["arrival_time"] = time.time()
             prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
             text_after_process = current_req_dict.get("text_after_process")

View File

@@ -87,6 +87,7 @@ class ErnieProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -342,6 +343,7 @@ class ErnieProcessor(BaseDataProcessor):
             tokenize=False,
             split_special_tokens=False,
             add_special_tokens=False,
+            chat_template=request_or_messages.get("chat_template", None),
         )
         request_or_messages["text_after_process"] = spliced_message
         req_id = None

View File

@@ -109,6 +109,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
+        request.chat_template = kwargs.get("chat_template")
         task = request.to_dict()
         task["enable_thinking"] = kwargs.get("enable_thinking", True)
         self.process_request_dict(task, max_model_len)

View File

@@ -494,10 +494,12 @@ class DataProcessor:
""" """
if self.tokenizer.chat_template is None: if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.") raise ValueError("This model does not support chat_template.")
prompt_token_template = self.tokenizer.apply_chat_template( prompt_token_template = self.tokenizer.apply_chat_template(
request, request,
tokenize=False, tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True), add_generation_prompt=request.get("add_generation_prompt", True),
chat_template=request.get("chat_template", None),
) )
prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace( prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
"<|video@placeholder|>", "" "<|video@placeholder|>", ""

View File

@@ -204,6 +204,7 @@ class DataProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -486,6 +487,7 @@ class DataProcessor(BaseDataProcessor):
             split_special_tokens=False,
             add_special_tokens=False,
             return_tensors="pd",
+            chat_template=request.get("chat_template", None),
         )
         request["text_after_process"] = spliced_message
         req_id = None

View File

@@ -35,3 +35,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser

View File

@@ -36,3 +36,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser

View File

@@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser

View File

@@ -0,0 +1,205 @@
import os
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

from fastdeploy.engine.request import Request
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
from fastdeploy.input.text_processor import DataProcessor


class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):

    def setUp(self):
        """
        Set up the test environment by creating an instance of the LLM class using Mock.
        """
        self.input_chat_template = "unit test \n"
        self.mock_engine = MagicMock()
        self.tokenizer = MagicMock()

    def test_load_chat_template_non(self):
        result = load_chat_template(None)
        self.assertEqual(None, result)

    def test_load_chat_template_str(self):
        result = load_chat_template(self.input_chat_template)
        self.assertEqual(self.input_chat_template, result)

    def test_load_chat_template_path(self):
        with open("chat_template", "w", encoding="utf-8") as file:
            file.write(self.input_chat_template)
        file_path = os.path.join(os.getcwd(), "chat_template")
        result = load_chat_template(file_path)
        os.remove(file_path)
        self.assertEqual(self.input_chat_template, result)

    def test_load_chat_template_non_str_and_path(self):
        with self.assertRaises(ValueError):
            load_chat_template("unit test")

    def test_path_with_literal_true(self):
        with self.assertRaises(TypeError):
            load_chat_template(Path("./chat_template"), is_literal=True)

    def test_path_object_file_error(self):
        with patch("builtins.open", mock_open()) as mock_file:
            mock_file.side_effect = OSError("File error")
            with self.assertRaises(OSError):
                load_chat_template(Path("./chat_template"))

    async def test_serving_chat(self):
        request = ChatCompletionRequest(messages=[{"role": "user", "content": "你好"}])
        self.chat_completion_handler = OpenAIServingChat(
            self.mock_engine, pid=123, ips=None, max_waiting_time=-1, chat_template=self.input_chat_template
        )

        async def mock_chat_completion_full_generator(
            request, request_id, model_name, prompt_token_ids, text_after_process
        ):
            return prompt_token_ids

        def mock_format_and_add_data(current_req_dict):
            return current_req_dict

        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
        self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
        self.chat_completion_handler.engine_client.semaphore = AsyncMock()
        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
        self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
        chat_completion = await self.chat_completion_handler.create_chat_completion(request)
        self.assertEqual(self.input_chat_template, chat_completion["chat_template"])

    async def test_serving_chat_cus(self):
        request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], chat_template="hello")
        self.chat_completion_handler = OpenAIServingChat(
            self.mock_engine, pid=123, ips=None, max_waiting_time=10, chat_template=self.input_chat_template
        )

        async def mock_chat_completion_full_generator(
            request, request_id, model_name, prompt_token_ids, text_after_process
        ):
            return prompt_token_ids

        def mock_format_and_add_data(current_req_dict):
            return current_req_dict

        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
        self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
        self.chat_completion_handler.engine_client.semaphore = AsyncMock()
        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
        self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
        chat_completion = await self.chat_completion_handler.create_chat_completion(request)
        self.assertEqual("hello", chat_completion["chat_template"])

    @patch("fastdeploy.input.ernie_vl_processor.ErnieMoEVLProcessor.__init__")
    def test_vl_processor(self, mock_class):
        mock_class.return_value = None
        vl_processor = ErnieMoEVLProcessor()
        mock_request = Request.from_dict({"request_id": "123"})

        def mock_apply_default_parameters(request):
            return request

        def mock_process_request(request, max_model_len):
            return request

        vl_processor._apply_default_parameters = mock_apply_default_parameters
        vl_processor.process_request_dict = mock_process_request
        result = vl_processor.process_request(mock_request, chat_template="hello")
        self.assertEqual("hello", result.chat_template)

    @patch("fastdeploy.input.text_processor.DataProcessor.__init__")
    def test_text_processor_process_request(self, mock_class):
        mock_class.return_value = None
        text_processor = DataProcessor()
        mock_request = Request.from_dict(
            {"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1}
        )

        def mock_apply_default_parameters(request):
            return request

        def mock_process_request(request, max_model_len):
            return request

        def mock_text2ids(text, max_model_len):
            return [1]

        text_processor._apply_default_parameters = mock_apply_default_parameters
        text_processor.process_request_dict = mock_process_request
        text_processor.text2ids = mock_text2ids
        text_processor.eos_token_ids = [1]
        result = text_processor.process_request(mock_request, chat_template="hello")
        self.assertEqual("hello", result.chat_template)

    @patch("fastdeploy.input.ernie_processor.ErnieProcessor.__init__")
    def test_ernie_processor_process(self, mock_class):
        mock_class.return_value = None
        ernie_processor = ErnieProcessor()
        mock_request = Request.from_dict(
            {"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1}
        )

        def mock_apply_default_parameters(request):
            return request

        def mock_process_request(request, max_model_len):
            return request

        def mock_messages2ids(text):
            return [1]

        ernie_processor._apply_default_parameters = mock_apply_default_parameters
        ernie_processor.process_request_dict = mock_process_request
        ernie_processor.messages2ids = mock_messages2ids
        ernie_processor.eos_token_ids = [1]
        result = ernie_processor.process_request(mock_request, chat_template="hello")
        self.assertEqual("hello", result.chat_template)

    @patch("fastdeploy.entrypoints.llm.LLM.__init__")
    def test_llm_load(self, mock_class):
        mock_class.return_value = None
        llm = LLM()
        llm.llm_engine = MagicMock()
        llm.default_sampling_params = MagicMock()
        llm.chat_template = "hello"

        def mock_run_engine(req_ids, **kwargs):
            return req_ids

        def mock_add_request(**kwargs):
            return kwargs.get("chat_template")

        llm._run_engine = mock_run_engine
        llm._add_request = mock_add_request
        result = llm.chat(["hello"], sampling_params=SamplingParams(1))
        self.assertEqual("hello", result)

    @patch("fastdeploy.entrypoints.llm.LLM.__init__")
    def test_llm(self, mock_class):
        mock_class.return_value = None
        llm = LLM()
        llm.llm_engine = MagicMock()
        llm.default_sampling_params = MagicMock()

        def mock_run_engine(req_ids, **kwargs):
            return req_ids

        def mock_add_request(**kwargs):
            return kwargs.get("chat_template")

        llm._run_engine = mock_run_engine
        llm._add_request = mock_add_request
        result = llm.chat(["hello"], sampling_params=SamplingParams(1), chat_template="hello")
        self.assertEqual("hello", result)


if __name__ == "__main__":
    unittest.main()