diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index aeb99f33f..fa23aaaee 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import inspect
 import os
 import time
 import traceback
@@ -112,7 +113,7 @@ class EngineClient:
         self.zmq_client = ZmqClient(model, mode)
         self.zmq_client.connect()
 
-    def format_and_add_data(self, prompts: dict):
+    async def format_and_add_data(self, prompts: dict):
         """
         Format the request data and send the request to the server.
         """
@@ -123,10 +124,10 @@ class EngineClient:
         if "max_tokens" not in prompts:
             prompts["max_tokens"] = self.max_model_len - 1
 
-        self.add_requests(prompts)
+        await self.add_requests(prompts)
         return prompts["prompt_token_ids"]
 
-    def add_requests(self, task):
+    async def add_requests(self, task):
         """
         Add a new request to the queue.
 
@@ -140,7 +141,10 @@ class EngineClient:
 
         task["preprocess_start_time"] = time.time()
         try:
-            self.data_processor.process_request_dict(task, self.max_model_len)
+            if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
+                await self.data_processor.process_request_dict(task, self.max_model_len)
+            else:
+                self.data_processor.process_request_dict(task, self.max_model_len)
 
             task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
             input_ids_len = task["prompt_token_ids_len"]
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 16f5f78a0..cddfef634 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -119,7 +119,7 @@ class OpenAIServingChat:
             if "chat_template" not in current_req_dict:
                 current_req_dict["chat_template"] = self.chat_template
             current_req_dict["arrival_time"] = time.time()
-            prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+            prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)
             text_after_process = current_req_dict.get("text_after_process")
             if isinstance(prompt_token_ids, np.ndarray):
                 prompt_token_ids = prompt_token_ids.tolist()
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index aa5d5f3c5..3df22de9c 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -146,7 +146,7 @@ class OpenAIServingCompletion:
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
                 current_req_dict["arrival_time"] = time.time()
-                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
+                prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                 if isinstance(prompt_token_ids, np.ndarray):
                     prompt_token_ids = prompt_token_ids.tolist()
                 text_after_process_list.append(current_req_dict.get("text_after_process"))
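Note on the `engine_client.py` change: the dispatch on `inspect.iscoroutinefunction` is what lets the now-async `add_requests` drive both synchronous and coroutine-based data processors from a single call site. A minimal runnable sketch of that pattern; the processor classes and token values below are illustrative stand-ins, not FastDeploy types:

```python
import asyncio
import inspect


class SyncProcessor:
    """Stand-in for a processor with a blocking process_request_dict."""

    def process_request_dict(self, task: dict, max_model_len: int) -> None:
        task["prompt_token_ids"] = [1, 2, 3]


class AsyncProcessor:
    """Stand-in for a processor that must await I/O (e.g. a remote tokenizer)."""

    async def process_request_dict(self, task: dict, max_model_len: int) -> None:
        await asyncio.sleep(0)
        task["prompt_token_ids"] = [1, 2, 3]


async def add_requests(processor, task: dict, max_model_len: int) -> None:
    # iscoroutinefunction is True for `async def` functions, including bound
    # methods, so one code path can serve both processor variants.
    if inspect.iscoroutinefunction(processor.process_request_dict):
        await processor.process_request_dict(task, max_model_len)
    else:
        processor.process_request_dict(task, max_model_len)


async def main() -> None:
    for processor in (SyncProcessor(), AsyncProcessor()):
        task: dict = {}
        await add_requests(processor, task, max_model_len=8192)
        print(type(processor).__name__, "->", task["prompt_token_ids"])


asyncio.run(main())
```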
diff --git a/tests/utils/test_custom_chat_template.py b/tests/utils/test_custom_chat_template.py
index acb6be960..71a617044 100644
--- a/tests/utils/test_custom_chat_template.py
+++ b/tests/utils/test_custom_chat_template.py
@@ -70,7 +70,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
         ):
             return prompt_token_ids
 
-        def mock_format_and_add_data(current_req_dict):
+        async def mock_format_and_add_data(current_req_dict):
             return current_req_dict
 
         self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
@@ -97,7 +97,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
         ):
             return prompt_token_ids
 
-        def mock_format_and_add_data(current_req_dict):
+        async def mock_format_and_add_data(current_req_dict):
             return current_req_dict
 
         self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
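Note on the test change: once `format_and_add_data` is awaited by its callers, a plain `def` mock would return the dict directly and the `await` would raise `TypeError: object dict can't be used in 'await' expression`, so the mocks must become `async def`. A self-contained sketch of the same idea under `unittest.IsolatedAsyncioTestCase`; the handler and test names here are hypothetical, not the ones in this test file:

```python
import unittest


class DummyHandler:
    """Hypothetical handler whose collaborator gets monkey-patched in tests."""

    async def create_chat_completion(self, current_req_dict: dict) -> dict:
        # Mirrors the call sites above: the collaborator is awaited.
        return await self.format_and_add_data(current_req_dict)


class TestAsyncMock(unittest.IsolatedAsyncioTestCase):
    async def test_mock_must_be_a_coroutine_function(self):
        async def mock_format_and_add_data(current_req_dict):
            return current_req_dict

        handler = DummyHandler()
        # Set on the instance, so the mock is called unbound, as in the
        # monkey-patching above.
        handler.format_and_add_data = mock_format_and_add_data
        result = await handler.create_chat_completion({"chat_template": "custom"})
        self.assertEqual(result["chat_template"], "custom")


if __name__ == "__main__":
    unittest.main()
```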