[Feature] Pass through the chat_template_kwargs to the data processing module (#3421)

* fix chat_template_args

* fix args

* add offline

* add offline

* fix

* fix

* fix default enable_thinking value

* fix default enable_thinking value

* modify condition

* Revert "modify condition"

This reverts commit 26430bdeb1.

* fix unit test
Author: luukunn
Date: 2025-08-19 10:50:01 +08:00
Committed by: GitHub
Parent: a053ab889b
Commit: 3a7a20d191
6 changed files with 50 additions and 13 deletions

@@ -108,7 +108,16 @@ class ErnieProcessor(BaseDataProcessor):
             request.prompt_token_ids = token_ids
             data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
-            request.prompt_token_ids = self.messages2ids(request.to_dict())
+            task = request.to_dict()
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.prompt_token_ids = self.messages2ids(task)
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
@@ -163,6 +172,14 @@ class ErnieProcessor(BaseDataProcessor):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request["prompt_token_ids"] = self.messages2ids(request)
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")