diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index 766feea47..2e60ef325 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -109,7 +109,13 @@ class ErnieProcessor(BaseDataProcessor):
                 request.prompt_token_ids = token_ids
                 data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
             else:
-                request.prompt_token_ids = self.messages2ids(request.to_dict())
+                task = request.to_dict()
+                chat_template_kwargs = request.get("chat_template_kwargs")
+                if chat_template_kwargs:
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                request.prompt_token_ids = self.messages2ids(task)
 
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index 04f29fa19..296b07b75 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -110,7 +110,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
         task = request.to_dict()
-        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs", {})
+        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)
         request = Request.from_dict(task)
         request = self._apply_default_parameters(request)
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index f67c922ef..a3cc0c035 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -207,11 +207,6 @@ class DataProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
-        chat_template_kwargs = request.get("chat_template_kwargs")
-        if chat_template_kwargs:
-            for k, v in chat_template_kwargs.items():
-                if k not in request:
-                    request[k] = v
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -225,7 +220,11 @@ class DataProcessor(BaseDataProcessor):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
                 task = request.to_dict()
-                task["enable_thinking"] = kwargs.get("enable_thinking", True)
+                chat_template_kwargs = kwargs.get("chat_template_kwargs")
+                if chat_template_kwargs:
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
                 request.prompt_token_ids = self.messages2ids(task)
             else:
                 raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")