[Feature] Support disabling prefix cache for multimodal (mm) models (#4459)

* [Feature] support prefix cache in DP

* fix

* Update common_engine.py

* Update common_engine.py

* Update common_engine.py

* Update common_engine.py

* [BugFix] fix workers more than 1

* fix

* Update api_server.py

* fix

* Update api_server.py

* fix

* [Feature] Support disabling prefix cache for multimodal (mm) models

* Update api_server.py

* Update engine_client.py

* Update engine_client.py

* add test

* Update test_chat.py

* fix

* fix

* Update test_chat.py

* Update test_chat.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Authored by ltd0924 on 2025-10-21 15:37:59 +08:00; committed by GitHub
parent 2b53c4d684
commit fb76cdfb4f
4 changed files with 47 additions and 4 deletions

View File

@@ -21,6 +21,18 @@ from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")

DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
    "Ernie5ForCausalLM",
}


def is_mm_model_disable_prefix_cache(model_config):
    """
    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
    """
    return model_config._architecture in DISABLE_PREFIX_CACHE_MM_MODEL


class CacheStatus(Enum):
    """
    cache status enum class
View File

@@ -86,6 +86,13 @@ class EngineClient:
        self.enable_splitwise = splitwise_role != "mixed"
        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8

        if self.enable_mm and self.enable_prefix_caching:
            from fastdeploy.cache_manager.cache_data import (
                is_mm_model_disable_prefix_cache,
            )

            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(model_config)

        if tensor_parallel_size <= max_chips_per_node:
            self.is_master = True
        else:
@@ -152,6 +159,16 @@ class EngineClient:
        await self.add_requests(prompts)
        return prompts["prompt_token_ids"]

    def _check_mm_disable_prefix_cache(self, task):
        is_multimodal_data = False
        if self.disable_prefix_mm:
            multimodal_inputs = task.get("multimodal_inputs", [])
            if multimodal_inputs:
                token_type_ids = multimodal_inputs.get("token_type_ids", [])
                if token_type_ids:
                    is_multimodal_data = np.sum(token_type_ids) > 0
        return is_multimodal_data

    async def add_requests(self, task):
        """
        Add a new request to the queue.
@@ -174,6 +191,16 @@ class EngineClient:
        else:
            self.data_processor.process_request_dict(task, self.max_model_len)

        if self.enable_mm and self.enable_prefix_caching:
            if self._check_mm_disable_prefix_cache(task):
                api_server_logger.error(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
                )
                raise EngineError(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
                    error_code=400,
                )

        task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
        input_ids_len = task["prompt_token_ids_len"]
        task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))

View File

@@ -269,7 +269,7 @@ class TokenProcessor:
            if self.tokens_counter[task_id] == 0:
                if task.messages is not None:
                    result.prompt = task.messages
                result.num_cached_tokens = task.num_cached_tokens
            result.num_cached_tokens = task.num_cached_tokens
            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
            result = self._process_per_token(task, i, token_ids, result, is_prefill)
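The repeated line above is the removed/added pair of this one-line hunk: the assignment appears to move out of the first-token branch so that every per-token result, not only the first, carries the prefix-cache hit count. A hedged sketch of how that surfaces through the offline API, mirroring the test below (an existing `llm` instance and the `PROMPTS` list are assumed; 64 is the value the new assertions expect for the two prompts sharing `COMMON_PREFIX`):

```python
# Assumes an `llm` instance and PROMPTS as defined in test_chat.py below.
outputs = llm.chat(messages=PROMPTS, sampling_params=None)

# The last two prompts share the long COMMON_PREFIX, so both report the same
# prefix-cache hit count (64 tokens in the test configuration).
print(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens)
```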

View File

@@ -27,11 +27,13 @@ MODEL_NAME = os.getenv("MODEL_PATH") + "/ERNIE-4.5-0.3B-Paddle"
class TestChat(unittest.TestCase):
    """Test case for chat functionality"""

    COMMON_PREFIX = "I am a highly capable, compassionate, and trustworthy AI assistant dedicated to providing you with exceptional support. Whatever questions or challenges you may have, I will utilize my full capabilities to offer thoughtful and comprehensive assistance. As your intelligent companion, I consistently maintain honesty, transparency, and patience to ensure our interactions are both productive and enjoyable."

    PROMPTS = [
        [{"content": "The color of tomato is ", "role": "user"}],
        [{"content": "The equation 2+3= ", "role": "user"}],
        [{"content": "The equation 4-1= ", "role": "user"}],
        [{"content": "PaddlePaddle is ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The color of tomato is ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The equation 2+3= ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The equation 4-1= ", "role": "user"}],
    ]

    @classmethod
@@ -58,6 +60,8 @@ class TestChat(unittest.TestCase):
    def test_chat(self):
        outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None)
        self.assertEqual(len(self.PROMPTS), len(outputs))
        self.assertEqual(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens)
        self.assertEqual(outputs[-1].num_cached_tokens, 64)

    def test_chat_with_tools(self):
        """Test chat with tools: