add offline

This commit is contained in:
luukunn
2025-08-15 15:09:24 +08:00
parent 3c823d13b9
commit 0446cc72f1
3 changed files with 13 additions and 8 deletions

View File

@@ -109,7 +109,13 @@ class ErnieProcessor(BaseDataProcessor):
request.prompt_token_ids = token_ids
data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
request.prompt_token_ids = self.messages2ids(request.to_dict())
task = request.to_dict()
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
request.prompt_token_ids = self.messages2ids(task)
if len(request.prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

View File

@@ -110,7 +110,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
def process_request(self, request, max_model_len=None, **kwargs):
"""process the input data"""
task = request.to_dict()
task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs", {})
task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
request = self._apply_default_parameters(request)

View File

@@ -207,11 +207,6 @@ class DataProcessor(BaseDataProcessor):
request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -225,7 +220,11 @@ class DataProcessor(BaseDataProcessor):
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", True)
chat_template_kwargs = kwargs.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
request.prompt_token_ids = self.messages2ids(task)
else:
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")