mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[fix]update apply_chat_template (#4249)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [fix] Modify follow-up generation parameters and the thinking-length validation (#4086)
  * Rename the continuation parameter generated_token_ids to completion_token_ids; change how the thinking length is validated
  * add completion_token_ids
  * add logger
  * fix reasoning_max_tokens ParameterError
  * add unit tests
  * fix
* [fix] update apply_chat_template (#4137)
  * update apply_chat_template
  * fix unit tests
  * fix
  * add unit test
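In short: the engine now normalizes chat_template (and, on the serving path, tools) into a single chat_template_kwargs dict, and the processors forward that dict into tokenizer.apply_chat_template via **kwargs instead of reading a loose chat_template attribute off the request. A condensed sketch of the flow the hunks below introduce:

    # Engine side: fold the loose chat_template argument into chat_template_kwargs.
    chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
    chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
    kwargs["chat_template_kwargs"] = chat_template_kwargs

    # Processor side: splat the collected kwargs straight into the tokenizer call.
    request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)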
@@ -222,7 +222,9 @@ class LLMEngine:
         if sampling_params is not None:
             request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()
+        chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
+        chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
+        kwargs["chat_template_kwargs"] = chat_template_kwargs
         request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
@@ -172,6 +172,9 @@ class EngineClient:

         task["preprocess_start_time"] = time.time()
         try:
+            chat_template_kwargs = task.get("chat_template_kwargs", {})
+            chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")})
+            task["chat_template_kwargs"] = chat_template_kwargs
             if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
                 await self.data_processor.process_request_dict(task, self.max_model_len)
             else:
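The new tests/entrypoints/test_engine_client.py (added later in this diff) pins down this merge: a task arriving with top-level chat_template and tools keys ends up with both inside chat_template_kwargs. Schematically:

    task = {"chat_template_kwargs": {"enable_thinking": True},
            "chat_template": "Hello", "tools": [1]}
    await engine_client.add_requests(task)
    # task["chat_template_kwargs"] is now
    # {"enable_thinking": True, "chat_template": "Hello", "tools": [1]}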
@@ -88,7 +88,6 @@ class Ernie4_5Processor(BaseDataProcessor):
             str: error message
         """
         data_processor_logger.info(f"Start processing request: {request}")
-        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -127,7 +126,7 @@ class Ernie4_5Processor(BaseDataProcessor):
             )
         elif request.messages is not None:
             task = request.to_dict()
-            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -135,7 +134,7 @@ class Ernie4_5Processor(BaseDataProcessor):
                         task[k] = v
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-            request.prompt_token_ids = self.messages2ids(task)
+            request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
         else:
             raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")

@@ -205,7 +204,7 @@ class Ernie4_5Processor(BaseDataProcessor):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         elif request.get("messages"):
-            chat_template_kwargs = request.get("chat_template_kwargs")
+            chat_template_kwargs = request.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -213,7 +212,7 @@ class Ernie4_5Processor(BaseDataProcessor):
                         request[k] = v
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-            request["prompt_token_ids"] = self.messages2ids(request)
+            request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

@@ -379,7 +378,7 @@ class Ernie4_5Processor(BaseDataProcessor):
                 del self.tool_parser_dict[req_id]
         return response_dict

-    def messages2ids(self, request_or_messages):
+    def messages2ids(self, request_or_messages, **kwargs):
         """
         Convert multi-turn messages into ID sequences.

@@ -397,7 +396,7 @@ class Ernie4_5Processor(BaseDataProcessor):
             tokenize=False,
             split_special_tokens=False,
             add_special_tokens=False,
-            chat_template=request_or_messages.get("chat_template", None),
+            **kwargs,
         )
         request_or_messages["text_after_process"] = spliced_message
         req_id = None
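With messages2ids accepting **kwargs, any key placed in chat_template_kwargs (a template string, enable_thinking, and so on) now reaches tokenizer.apply_chat_template unchanged. Consolidating the two hunks above into one sketch (surrounding code abbreviated):

    def messages2ids(self, request_or_messages, **kwargs):
        # kwargs is the request's chat_template_kwargs
        spliced_message = self.tokenizer.apply_chat_template(
            request_or_messages,
            tokenize=False,
            split_special_tokens=False,
            add_special_tokens=False,
            **kwargs,  # replaces the old explicit chat_template=... argument
        )
        request_or_messages["text_after_process"] = spliced_message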
@@ -113,7 +113,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):

     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
-        request.chat_template = kwargs.get("chat_template")
         task = request.to_dict()
         task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)
@@ -250,8 +250,8 @@ class DataProcessor:
                     "video",
                 ]:
                     image_message_list.append(item)
-        prompt_token_ids = self.apply_chat_template(request)
+        chat_template_kwargs = request.get("chat_template_kwargs", {})
+        prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
         if len(prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
         image_start_index = 0
@@ -480,7 +480,7 @@ class DataProcessor:
                 break
         self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)

-    def apply_chat_template(self, request):
+    def apply_chat_template(self, request, **kwargs):
         """
         Convert multi-turn messages into ID sequences.

@@ -498,7 +498,7 @@ class DataProcessor:
             request,
             tokenize=False,
             add_generation_prompt=request.get("add_generation_prompt", True),
-            chat_template=request.get("chat_template", None),
+            **kwargs,
         )
         prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
             "<|video@placeholder|>", ""
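The multimodal DataProcessor follows the same pattern: the request dict's chat_template_kwargs is splatted into apply_chat_template, so a per-request template can be supplied alongside images. A hedged usage sketch (field names from the hunks above; the template string itself is a placeholder assumption):

    request = {
        "messages": [{"role": "user", "content": "Describe the image."}],
        "chat_template_kwargs": {"chat_template": "<custom template>"},  # placeholder value
    }
    # Inside process_request_dict this becomes:
    prompt_token_ids = self.apply_chat_template(request, **request.get("chat_template_kwargs", {}))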
@@ -208,7 +208,6 @@ class DataProcessor(BaseDataProcessor):
             str: error message
         """
         data_processor_logger.info(f"Start processing request: {request}")
-        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -242,7 +241,7 @@ class DataProcessor(BaseDataProcessor):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
             task = request.to_dict()
-            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -251,7 +250,7 @@ class DataProcessor(BaseDataProcessor):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             task.setdefault("enable_thinking", True)
-            request.prompt_token_ids = self.messages2ids(task)
+            request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")

@@ -316,7 +315,7 @@ class DataProcessor(BaseDataProcessor):
         elif request.get("messages"):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
-            chat_template_kwargs = request.get("chat_template_kwargs")
+            chat_template_kwargs = request.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -325,7 +324,7 @@ class DataProcessor(BaseDataProcessor):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request.setdefault("enable_thinking", True)
-            request["prompt_token_ids"] = self.messages2ids(request)
+            request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

@@ -530,7 +529,7 @@ class DataProcessor(BaseDataProcessor):

         return tokens["input_ids"][0]

-    def messages2ids(self, request):
+    def messages2ids(self, request, **kwargs):
         """
         Convert multi-turn messages into ID sequences.

@@ -547,7 +546,7 @@ class DataProcessor(BaseDataProcessor):
             split_special_tokens=False,
             add_special_tokens=False,
             return_tensors="pd",
-            chat_template=request.get("chat_template", None),
+            **kwargs,
         )
         request["text_after_process"] = spliced_message
         req_id = None
tests/entrypoints/test_engine_client.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.entrypoints.engine_client import EngineClient
+
+
+class TestEngineClient(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        # Create a mocked EngineClient instance
+        with patch.object(EngineClient, "__init__", return_value=None) as mock_init:
+            self.engine_client = EngineClient("model_path")
+            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")
+
+        self.engine_client.data_processor = MagicMock()
+        self.engine_client.zmq_client = MagicMock()
+        self.engine_client.max_model_len = 1024
+        self.engine_client.enable_mm = False
+
+    async def test_add_request(self):
+        request = {
+            "chat_template_kwargs": {"enable_thinking": True},
+            "prompt_token_ids": [1],
+            "chat_template": "Hello",
+            "max_tokens": 20,
+            "tools": [1],
+        }
+
+        await self.engine_client.add_requests(request)
+        assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs'"
+        assert "tools" in request["chat_template_kwargs"], "'tools' not found in 'chat_template_kwargs'"
+        assert request["chat_template_kwargs"]["chat_template"] == "Hello"
+        assert request["chat_template_kwargs"]["tools"] == [1]
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -17,6 +17,8 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
         self.processor.decode_status = {}
         self.processor.reasoning_end_dict = {}
         self.processor.tool_parser_dict = {}
+        self.processor.generation_config = MagicMock()
+        self.processor.eos_token_ids = [1]

         # Mock the ids2tokens method
         def mock_ids2tokens(token_ids, task_id):
@@ -24,6 +26,18 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):

         self.processor.ids2tokens = mock_ids2tokens

+        def mock_messages2ids(request, **kwargs):
+            if "chat_template" in kwargs:
+                return [1]
+            else:
+                return [0]
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        self.processor.messages2ids = mock_messages2ids
+        self.processor._apply_default_parameters = mock_apply_default_parameters
+
         # Mock the reasoning parser
         self.mock_reasoning_parser = MagicMock()
         self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
@@ -49,6 +63,17 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
         # Verify the result
         self.assertEqual(result["outputs"]["raw_prediction"], "delta_text")

+    def test_process_request_dict(self):
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "chat_template_kwargs": {"chat_template": "Hello!"},
+            "eos_token_ids": [1],
+            "temperature": 1,
+            "top_p": 1,
+        }
+        result = self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(result["prompt_token_ids"], [1])
+
+
 if __name__ == "__main__":
     unittest.main()
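The mock_messages2ids helper installed in setUp is the crux of this test: it returns [1] only when a chat_template key arrives via **kwargs, so the final assertEqual(result["prompt_token_ids"], [1]) passes only if process_request_dict really forwarded chat_template_kwargs into messages2ids:

    def mock_messages2ids(request, **kwargs):
        # [1] means "the kwargs made it through"; [0] means "they were dropped"
        return [1] if "chat_template" in kwargs else [0]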
tests/input/test_text_processor.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.engine.request import Request
+from fastdeploy.input.text_processor import DataProcessor
+
+
+class TestDataProcessorProcess(unittest.TestCase):
+    def setUp(self):
+        # Create a mocked DataProcessor instance
+        with patch.object(DataProcessor, "__init__", return_value=None) as mock_init:
+            self.processor = DataProcessor("model_path")
+            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")
+
+        # Set the required attributes
+        self.processor.tokenizer = MagicMock()
+        self.processor.tokenizer.eos_token_id = 1
+        self.processor.decode_status = {}
+        self.processor.reasoning_end_dict = {}
+        self.processor.tool_parser_dict = {}
+        self.processor.generation_config = MagicMock()
+        self.processor.eos_token_ids = [1]
+
+        def mock_messages2ids(request, **kwargs):
+            if "chat_template" in kwargs:
+                return [1]
+            else:
+                return [0]
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        self.processor.messages2ids = mock_messages2ids
+        self.processor._apply_default_parameters = mock_apply_default_parameters
+
+    def test_process_request(self):
+        request = Request.from_dict(
+            {
+                "request_id": "123",
+                "messages": [{"role": "user", "content": "Hello!"}],
+                "eos_token_ids": [1],
+                "temperature": 1,
+                "top_p": 1,
+            }
+        )
+        chat_template_kwargs = {"chat_template": "Hello!"}
+        result = self.processor.process_request(request, 100, chat_template_kwargs=chat_template_kwargs)
+        self.assertEqual(result.prompt_token_ids, [1])
+
+    def test_process_request_dict(self):
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "chat_template_kwargs": {"chat_template": "Hello!"},
+            "eos_token_ids": [1],
+            "temperature": 1,
+            "top_p": 1,
+        }
+        result = self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(result["prompt_token_ids"], [1])
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -3,15 +3,11 @@ import unittest
 from pathlib import Path
 from unittest.mock import AsyncMock, MagicMock, mock_open, patch

-from fastdeploy.engine.request import Request
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.llm import LLM
 from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
-from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
-from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
-from fastdeploy.input.text_processor import DataProcessor


 class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
@@ -108,91 +104,6 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
         chat_completion = await self.chat_completion_handler.create_chat_completion(request)
         self.assertEqual("hello", chat_completion["chat_template"])

-    @patch("fastdeploy.input.ernie4_5_vl_processor.Ernie4_5_VLProcessor.__init__")
-    def test_ernie4_5_vl_processor(self, mock_class):
-        mock_class.return_value = None
-        ernie4_5_vl_processor = Ernie4_5_VLProcessor()
-        mock_request = Request.from_dict({"request_id": "123"})
-
-        def mock_apply_default_parameters(request):
-            return request
-
-        def mock_process_request(request, max_model_len):
-            return request
-
-        ernie4_5_vl_processor._apply_default_parameters = mock_apply_default_parameters
-        ernie4_5_vl_processor.process_request_dict = mock_process_request
-        result = ernie4_5_vl_processor.process_request(mock_request, chat_template="hello")
-        self.assertEqual("hello", result.chat_template)
-
-    @patch("fastdeploy.input.text_processor.DataProcessor.__init__")
-    def test_text_processor_process_request(self, mock_class):
-        mock_class.return_value = None
-        text_processor = DataProcessor()
-        mock_request = Request.from_dict(
-            {"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1}
-        )
-
-        def mock_apply_default_parameters(request):
-            return request
-
-        def mock_process_request(request, max_model_len):
-            return request
-
-        def mock_text2ids(text, max_model_len):
-            return [1]
-
-        text_processor._apply_default_parameters = mock_apply_default_parameters
-        text_processor.process_request_dict = mock_process_request
-        text_processor.text2ids = mock_text2ids
-        text_processor.eos_token_ids = [1]
-        result = text_processor.process_request(mock_request, chat_template="hello")
-        self.assertEqual("hello", result.chat_template)
-
-    @patch("fastdeploy.input.ernie4_5_processor.Ernie4_5Processor.__init__")
-    def test_ernie4_5_processor_process(self, mock_class):
-        mock_class.return_value = None
-        ernie4_5_processor = Ernie4_5Processor()
-        mock_request = Request.from_dict(
-            {"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1}
-        )
-
-        def mock_apply_default_parameters(request):
-            return request
-
-        def mock_process_request(request, max_model_len):
-            return request
-
-        def mock_messages2ids(text):
-            return [1]
-
-        ernie4_5_processor._apply_default_parameters = mock_apply_default_parameters
-        ernie4_5_processor.process_request_dict = mock_process_request
-        ernie4_5_processor.messages2ids = mock_messages2ids
-        ernie4_5_processor.eos_token_ids = [1]
-        ernie4_5_processor.reasoning_parser = MagicMock()
-        result = ernie4_5_processor.process_request(mock_request, chat_template="hello")
-        self.assertEqual("hello", result.chat_template)
-
-    @patch("fastdeploy.entrypoints.llm.LLM.__init__")
-    def test_llm_load(self, mock_class):
-        mock_class.return_value = None
-        llm = LLM()
-        llm.llm_engine = MagicMock()
-        llm.default_sampling_params = MagicMock()
-        llm.chat_template = "hello"
-
-        def mock_run_engine(req_ids, **kwargs):
-            return req_ids
-
-        def mock_add_request(**kwargs):
-            return kwargs.get("chat_template")
-
-        llm._run_engine = mock_run_engine
-        llm._add_request = mock_add_request
-        result = llm.chat(["hello"], sampling_params=SamplingParams(1))
-        self.assertEqual("hello", result)
-
     @patch("fastdeploy.entrypoints.llm.LLM.__init__")
     def test_llm(self, mock_class):
         mock_class.return_value = None
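(The four tests removed above all asserted the old request.chat_template attribute, which the processors no longer set now that the assignment request.chat_template = kwargs.get("chat_template") has been deleted; their coverage moves to the chat_template_kwargs-based tests added earlier in this commit.)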