Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] Support disabling prefix cache for mm models (#4459)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* [Feature] support prefix cache in DP
* fix
* Update common_engine.py
* Update common_engine.py
* Update common_engine.py
* Update common_engine.py
* [BugFix] fix workers more than 1
* fix
* Update api_server.py
* fix
* Update api_server.py
* fix
* [Fearture] Support mm model close prefix cache
* Update api_server.py
* Update engine_client.py
* Update engine_client.py
* add test
* Update test_chat.py
* fix
* fix
* Update test_chat.py
* Update test_chat.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -21,6 +21,18 @@ from fastdeploy.utils import get_logger

logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")

DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
    "Ernie5ForCausalLM",
}


def is_mm_model_disable_prefix_cache(model_config):
    """
    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
    """
    return model_config._architecture in DISABLE_PREFIX_CACHE_MM_MODEL


class CacheStatus(Enum):
    """
    cache status enum class
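For reference, a minimal sketch of how this helper behaves, using SimpleNamespace as a stand-in for the real model config (the stand-in object and the second architecture name are illustrative assumptions, not part of the diff):

    from types import SimpleNamespace

    from fastdeploy.cache_manager.cache_data import is_mm_model_disable_prefix_cache

    # Architecture listed in DISABLE_PREFIX_CACHE_MM_MODEL: prefix cache must be turned off for it.
    ernie5 = SimpleNamespace(_architecture="Ernie5ForCausalLM")
    print(is_mm_model_disable_prefix_cache(ernie5))   # True

    # Any architecture not in the set keeps prefix caching available.
    other = SimpleNamespace(_architecture="SomeOtherArchitecture")
    print(is_mm_model_disable_prefix_cache(other))    # False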
@@ -86,6 +86,13 @@ class EngineClient:

        self.enable_splitwise = splitwise_role != "mixed"
        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8

        if self.enable_mm and self.enable_prefix_caching:
            from fastdeploy.cache_manager.cache_data import (
                is_mm_model_disable_prefix_cache,
            )

            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(model_config)

        if tensor_parallel_size <= max_chips_per_node:
            self.is_master = True
        else:
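One detail worth noting: the import and the assignment both live inside the branch, so self.disable_prefix_mm only exists when multimodal support and prefix caching are enabled together, and the only reader of the flag (_check_mm_disable_prefix_cache, called from add_requests) is guarded by the same condition. A small sketch of that deferred-import pattern in isolation (the class name and parameters are illustrative, not the real EngineClient constructor):

    class ClientSketch:
        def __init__(self, enable_mm, enable_prefix_caching, model_config):
            if enable_mm and enable_prefix_caching:
                # Deferred import: the cache-data helper is only loaded on the
                # code path that actually needs it.
                from fastdeploy.cache_manager.cache_data import (
                    is_mm_model_disable_prefix_cache,
                )

                self.disable_prefix_mm = is_mm_model_disable_prefix_cache(model_config)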
@@ -152,6 +159,16 @@ class EngineClient:

        await self.add_requests(prompts)
        return prompts["prompt_token_ids"]

    def _check_mm_disable_prefix_cache(self, task):
        """
        Return True if the task carries multimodal data and prefix cache must be disabled for this model.
        """
        is_multimodal_data = False
        if self.disable_prefix_mm:
            multimodal_inputs = task.get("multimodal_inputs", [])
            if multimodal_inputs:
                token_type_ids = multimodal_inputs.get("token_type_ids", [])
                if token_type_ids:
                    is_multimodal_data = np.sum(token_type_ids) > 0
        return is_multimodal_data

    async def add_requests(self, task):
        """
        Add a new request to the queue.
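The rejection decision boils down to summing token_type_ids: any non-zero entry is treated as a non-text (multimodal) token, so a pure-text request passes even when the model is in the disable list. A rough standalone sketch of that check (the hand-built task dicts are assumptions for illustration):

    import numpy as np

    text_only_task = {"multimodal_inputs": {"token_type_ids": [0, 0, 0, 0]}}
    mixed_task = {"multimodal_inputs": {"token_type_ids": [0, 0, 1, 1, 0]}}

    def has_mm_data(task):
        # Same test as _check_mm_disable_prefix_cache: a non-zero sum means at
        # least one non-text token is present in the prompt.
        token_type_ids = task.get("multimodal_inputs", {}).get("token_type_ids", [])
        return bool(token_type_ids) and bool(np.sum(token_type_ids) > 0)

    print(has_mm_data(text_only_task))  # False: text-only, allowed through
    print(has_mm_data(mixed_task))      # True: rejected later with error_code=400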
@@ -174,6 +191,16 @@ class EngineClient:

        else:
            self.data_processor.process_request_dict(task, self.max_model_len)

        if self.enable_mm and self.enable_prefix_caching:
            if self._check_mm_disable_prefix_cache(task):
                api_server_logger.error(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
                )
                raise EngineError(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
                    error_code=400,
                )

        task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
        input_ids_len = task["prompt_token_ids_len"]
        task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
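After the multimodal guard, the remaining lines clamp max_tokens so that prompt length plus generation never exceeds the model context window. A quick worked example with made-up numbers (the context size, prompt length, and requested budget are all illustrative):

    max_model_len = 8192          # assumed context window
    input_ids_len = 3000          # assumed tokenized prompt length
    requested_max_tokens = 6000   # what the client asked for

    # Same clamp as in add_requests: generation is capped by the remaining context.
    max_tokens = min(max_model_len - input_ids_len, requested_max_tokens)
    print(max_tokens)  # 5192, not the requested 6000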
@@ -269,7 +269,7 @@ class TokenProcessor:

                if self.tokens_counter[task_id] == 0:
                    if task.messages is not None:
                        result.prompt = task.messages
                    result.num_cached_tokens = task.num_cached_tokens

                is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
                result = self._process_per_token(task, i, token_ids, result, is_prefill)
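This hunk is where a request's num_cached_tokens is copied from the task onto the result object, which is what the test below reads back. A minimal consumer-side sketch (assumes an already-constructed LLM instance configured like the test, with prefix caching enabled; the variable name llm is illustrative):

    outputs = llm.chat(
        messages=[[{"role": "user", "content": "The color of tomato is "}]],
        sampling_params=None,
    )
    # num_cached_tokens reports how many prompt tokens were served from the prefix cache.
    print(outputs[0].num_cached_tokens)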
@@ -27,11 +27,13 @@ MODEL_NAME = os.getenv("MODEL_PATH") + "/ERNIE-4.5-0.3B-Paddle"

class TestChat(unittest.TestCase):
    """Test case for chat functionality"""

    COMMON_PREFIX = "I am a highly capable, compassionate, and trustworthy AI assistant dedicated to providing you with exceptional support. Whatever questions or challenges you may have, I will utilize my full capabilities to offer thoughtful and comprehensive assistance. As your intelligent companion, I consistently maintain honesty, transparency, and patience to ensure our interactions are both productive and enjoyable."

    PROMPTS = [
        [{"content": "The color of tomato is ", "role": "user"}],
        [{"content": "The equation 2+3= ", "role": "user"}],
        [{"content": "The equation 4-1= ", "role": "user"}],
        [{"content": "PaddlePaddle is ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The color of tomato is ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The equation 2+3= ", "role": "user"}],
        [{"content": COMMON_PREFIX + "The equation 4-1= ", "role": "user"}],
    ]

    @classmethod
@@ -58,6 +60,8 @@ class TestChat(unittest.TestCase):

    def test_chat(self):
        outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None)
        self.assertEqual(len(self.PROMPTS), len(outputs))
        self.assertEqual(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens)
        self.assertEqual(outputs[-1].num_cached_tokens, 64)

    def test_chat_with_tools(self):
        """Test chat with tools:
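The last two prompts share COMMON_PREFIX with an earlier request, so by the time they run the prefix is already cached and both should report the same num_cached_tokens; the test pins that value at 64, which is consistent with a prefix cache that counts hits in whole blocks of 64 tokens (the block size and the exact tokenized length of the prefix are assumptions here, not stated in the diff). A small sketch of that relationship:

    # Illustrative numbers only; the test asserts the final value, 64.
    block_size = 64                # assumed prefix-cache block size
    shared_prefix_tokens = 70      # assumed tokenized length of COMMON_PREFIX plus chat template
    cached_tokens = (shared_prefix_tokens // block_size) * block_size
    print(cached_tokens)           # 64: cache hits are counted in whole blocks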