diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index 9f54b7284..95e64711f 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -108,11 +108,14 @@ class ErnieProcessor(BaseDataProcessor): data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}") else: task = request.to_dict() - chat_template_kwargs = request.get("chat_template_kwargs") + chat_template_kwargs = kwargs.get("chat_template_kwargs") if chat_template_kwargs: - for k, v in chat_template_kwargs.items(): - if k not in task: - task[k] = v + if isinstance(chat_template_kwargs, dict): + for k, v in chat_template_kwargs.items(): + if k not in task: + task[k] = v + else: + raise ValueError("Invalid input: chat_template_kwargs must be a dict") request.prompt_token_ids = self.messages2ids(task) if len(request.prompt_token_ids) == 0: @@ -146,11 +149,7 @@ class ErnieProcessor(BaseDataProcessor): request = self._apply_default_parameters(request) if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - chat_template_kwargs = request.get("chat_template_kwargs") - if chat_template_kwargs: - for k, v in chat_template_kwargs.items(): - if k not in request: - request[k] = v + # processing stop_sequences stop_sequences = request.get("stop", []) if stop_sequences: @@ -172,6 +171,14 @@ class ErnieProcessor(BaseDataProcessor): req_id = request.get("request_id", None) data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}") else: + chat_template_kwargs = request.get("chat_template_kwargs") + if chat_template_kwargs: + if isinstance(chat_template_kwargs, dict): + for k, v in chat_template_kwargs.items(): + if k not in request: + request[k] = v + else: + raise ValueError("Invalid input: chat_template_kwargs must be a dict") request["prompt_token_ids"] = self.messages2ids(request) if len(request["prompt_token_ids"]) == 0: raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py index 296b07b75..45af57617 100644 --- a/fastdeploy/input/ernie_vl_processor.py +++ b/fastdeploy/input/ernie_vl_processor.py @@ -198,11 +198,6 @@ class ErnieMoEVLProcessor(ErnieProcessor): request = self._apply_default_parameters(request) if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - chat_template_kwargs = request.get("chat_template_kwargs") - if chat_template_kwargs: - for k, v in chat_template_kwargs.items(): - if k not in request: - request[k] = v stop_sequences = request.get("stop", []) if stop_sequences: @@ -222,6 +217,14 @@ class ErnieMoEVLProcessor(ErnieProcessor): elif request.get("messages"): messages = request["messages"] self._check_mm_limits(messages) + chat_template_kwargs = request.get("chat_template_kwargs") + if chat_template_kwargs: + if isinstance(chat_template_kwargs, dict): + for k, v in chat_template_kwargs.items(): + if k not in request: + request[k] = v + else: + raise ValueError("Invalid input: chat_template_kwargs must be a dict") outputs = self.ernie_processor.request2ids(request) else: raise ValueError(f"Request must contain 'prompt', or 'messages': {request}") diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index c4235265e..123bf5b43 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -222,9 +222,12 @@ class DataProcessor(BaseDataProcessor): task = request.to_dict() chat_template_kwargs = kwargs.get("chat_template_kwargs") if chat_template_kwargs: - for k, v in chat_template_kwargs.items(): - if k not in task: - task[k] = v + if isinstance(chat_template_kwargs, dict): + for k, v in chat_template_kwargs.items(): + if k not in task: + task[k] = v + else: + raise ValueError("Invalid input: chat_template_kwargs must be a dict") request.prompt_token_ids = self.messages2ids(task) else: raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.") @@ -276,9 +279,12 @@ class DataProcessor(BaseDataProcessor): raise ValueError("This model does not support chat_template.") chat_template_kwargs = request.get("chat_template_kwargs") if chat_template_kwargs: - for k, v in chat_template_kwargs.items(): - if k not in request: - request[k] = v + if isinstance(chat_template_kwargs, dict): + for k, v in chat_template_kwargs.items(): + if k not in request: + request[k] = v + else: + raise ValueError("Invalid input: chat_template_kwargs must be a dict") request["prompt_token_ids"] = self.messages2ids(request) else: raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")