Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Pass through the chat_template_kwargs to the data processing module (#3421)
* fix chat_template_args
* fix args
* add offline
* add offline
* fix
* fix
* fix default enable_thinking value
* fix default enable_thinking value
* modify condition
* Revert "modify condition"
  This reverts commit 26430bdeb1.
* fix unit test
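
For context, a minimal usage sketch (not part of this commit; the server address and model name below are placeholders): with the pass-through in place, chat_template_kwargs reaches the data processor from the OpenAI-compatible endpoint via extra_body, exactly as exercised in the test hunk at the end of this diff, while the offline entrypoint forwards it as a keyword argument (see the _add_request hunks below).

    # Hedged sketch; base_url, api_key and model name are assumptions, not from the diff.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Hello"}],
        # Forwarded verbatim to the chat template rendering in the data processor.
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )
    # Offline path (public signature assumed; the diff only shows the private _add_request):
    # outputs = llm.chat(messages, sampling_params, chat_template_kwargs={"enable_thinking": False})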
@@ -465,10 +465,7 @@ class LLMEngine:
         request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()

-        enable_thinking = None
-        if kwargs is not None:
-            enable_thinking = kwargs.get("enable_thinking", None)
-        request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
+        request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
         input_ids_len = request.prompt_token_ids_len
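The engine no longer special-cases enable_thinking; whatever keyword arguments add_requests receives are handed to the data processor untouched. A rough sketch of the new call shape (anything beyond what the hunk above shows is an assumption):

    # Caller side (assumed signature: add_requests(self, task, sampling_params=None, **kwargs)):
    # engine.add_requests(task, sampling_params, chat_template_kwargs={"enable_thinking": False})
    #
    # Engine side, as in the hunk above -- no per-key extraction any more:
    # request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)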
@@ -256,7 +256,7 @@ class LLM:
         self,
         prompts,
         sampling_params,
-        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """
         添加一个请求到 LLM Engine,并返回该请求的 ID。
@@ -297,10 +297,7 @@ class LLM:
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
-            enable_thinking = None
-            if chat_template_kwargs is not None:
-                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
-            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
+            self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
         return req_ids

     def _decode_token(self, token_id: int) -> str:
@@ -108,7 +108,16 @@ class ErnieProcessor(BaseDataProcessor):
             request.prompt_token_ids = token_ids
             data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
-            request.prompt_token_ids = self.messages2ids(request.to_dict())
+            task = request.to_dict()
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.prompt_token_ids = self.messages2ids(task)

         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
@@ -163,6 +172,14 @@ class ErnieProcessor(BaseDataProcessor):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request["prompt_token_ids"] = self.messages2ids(request)
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
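Both ErnieProcessor paths apply the same rule: entries from chat_template_kwargs are merged into the request/task dict only when the key is not already present, and a non-dict value is rejected. A self-contained sketch of that rule (the helper name is hypothetical, not from the diff):

    def merge_chat_template_kwargs(task: dict, chat_template_kwargs) -> dict:
        """Hypothetical helper mirroring the merge rule used in the hunks above."""
        if not chat_template_kwargs:
            return task
        if not isinstance(chat_template_kwargs, dict):
            raise ValueError("Invalid input: chat_template_kwargs must be a dict")
        for k, v in chat_template_kwargs.items():
            task.setdefault(k, v)  # fields already set on the request win
        return task

    # Existing request fields are never overwritten:
    assert merge_chat_template_kwargs(
        {"enable_thinking": False}, {"enable_thinking": True}
    ) == {"enable_thinking": False}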
@@ -111,7 +111,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         """process the input data"""
         request.chat_template = kwargs.get("chat_template")
         task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", True)
+        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)
         request = Request.from_dict(task)
         request = self._apply_default_parameters(request)
@@ -218,6 +218,15 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         elif request.get("messages"):
             messages = request["messages"]
             self._check_mm_limits(messages)
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             outputs = self.ernie_processor.request2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
@@ -208,7 +208,6 @@ class DataProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
-
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -222,7 +221,15 @@ class DataProcessor(BaseDataProcessor):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
             task = request.to_dict()
-            task["enable_thinking"] = kwargs.get("enable_thinking", True)
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            task.setdefault("enable_thinking", True)
             request.prompt_token_ids = self.messages2ids(task)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
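The net effect on the thinking switch: the old code read enable_thinking from kwargs with a default of True, whereas now the default is applied only after the merge, so a user-supplied {"enable_thinking": False} inside chat_template_kwargs is preserved. A plain-dict sketch of that ordering (illustrative only, no FastDeploy imports):

    task = {"messages": [{"role": "user", "content": "hi"}]}
    chat_template_kwargs = {"enable_thinking": False}

    for k, v in chat_template_kwargs.items():   # merge step from the hunk above
        if k not in task:
            task[k] = v
    task.setdefault("enable_thinking", True)     # default only fills a gap

    assert task["enable_thinking"] is False      # the user's choice survives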
@@ -272,6 +279,15 @@ class DataProcessor(BaseDataProcessor):
         elif "messages" in request:
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             request["prompt_token_ids"] = self.messages2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
@@ -508,6 +508,7 @@ def test_chat_with_thinking(openai_client, capsys):
         extra_body={"chat_template_kwargs": {"enable_thinking": False}},
     )
     assert response.choices[0].message.reasoning_content is None
+    assert "</think>" not in response.choices[0].message.content

     # enable thinking, streaming
     reasoning_max_tokens = 3
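For symmetry with the disabled-thinking assertion above, a hedged sketch of the complementary check (not part of the diff; the fixture, model name, and exact assertion are assumptions):

    response = openai_client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Hello"}],
        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
    )
    assert response.choices[0].message.reasoning_content is not None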