Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-25 09:31:38 +08:00
[Feature] Support disabling prefix cache for multimodal models (#4502)
* Support disabling prefix cache for multimodal models, with follow-up fixes

Co-authored-by: ltd0924 <luotingdan@baidu.com>
@@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger

logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")

DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
    "Ernie5ForCausalLM",
}


def is_mm_model_disable_prefix_cache(model_arch):
    """
    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
    """
    return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL


class CacheStatus(Enum):
    """
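For illustration, a minimal usage sketch (not part of the diff) of the helper added above; the second architecture name is a hypothetical placeholder:

from fastdeploy.cache_manager.cache_data import is_mm_model_disable_prefix_cache

# Architectures listed in DISABLE_PREFIX_CACHE_MM_MODEL must not combine
# prefix caching with multimodal inputs; any other architecture is unaffected.
print(is_mm_model_disable_prefix_cache("Ernie5ForCausalLM"))  # True
print(is_mm_model_disable_prefix_cache("SomeOtherMMArch"))    # False (hypothetical name)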
@@ -55,15 +55,24 @@ class EngineClient:

        enable_logprob=False,
        workers=1,
        tool_parser=None,
        enable_prefix_caching=None,
    ):
        import fastdeploy.model_executor.models  # noqa: F401

        architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
        self.enable_prefix_caching = enable_prefix_caching
        if MultimodalRegistry.contains_model(architectures):
            self.enable_mm = True
        else:
            self.enable_mm = False

        if self.enable_mm and self.enable_prefix_caching:
            from fastdeploy.cache_manager.cache_data import (
                is_mm_model_disable_prefix_cache,
            )

            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures)

        input_processor = InputPreprocessor(
            tokenizer,
            reasoning_parser,
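One detail worth noting in the constructor above: self.disable_prefix_mm is only assigned when both switches are on, and the same condition guards its only use in add_requests, so the attribute is never read while undefined. A minimal sketch with hypothetical values:

from fastdeploy.cache_manager.cache_data import is_mm_model_disable_prefix_cache

# Hypothetical flag values; the names mirror the EngineClient attributes above.
enable_prefix_caching = True   # forwarded from the server arguments
enable_mm = True               # MultimodalRegistry reported a multimodal architecture
if enable_mm and enable_prefix_caching:
    # True only for architectures in DISABLE_PREFIX_CACHE_MM_MODEL.
    disable_prefix_mm = is_mm_model_disable_prefix_cache("Ernie5ForCausalLM")  # True here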
@@ -127,6 +136,16 @@ class EngineClient:

        await self.add_requests(prompts)
        return prompts["prompt_token_ids"]

    def _check_mm_disable_prefix_cache(self, task):
        is_multimodal_data = False
        if self.disable_prefix_mm:
            multimodal_inputs = task.get("multimodal_inputs", [])
            if multimodal_inputs:
                token_type_ids = multimodal_inputs.get("token_type_ids", [])
                if token_type_ids:
                    is_multimodal_data = np.sum(token_type_ids) > 0
        return is_multimodal_data

    async def add_requests(self, task):
        """
        Add a new request to the queue.
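To illustrate the detection logic in _check_mm_disable_prefix_cache, a self-contained sketch with made-up token_type_ids: non-zero entries mark multimodal token positions, so a positive sum means the prompt actually carries image/video content.

import numpy as np

def contains_mm_data(task):
    # Same shape of check as _check_mm_disable_prefix_cache: read token_type_ids
    # from the request's multimodal_inputs and treat a positive sum as
    # "contains multimodal data".
    multimodal_inputs = task.get("multimodal_inputs", {})
    token_type_ids = multimodal_inputs.get("token_type_ids", []) if multimodal_inputs else []
    return bool(token_type_ids) and bool(np.sum(token_type_ids) > 0)

# Hypothetical request dicts.
text_only = {"multimodal_inputs": {"token_type_ids": [0, 0, 0, 0]}}
with_image = {"multimodal_inputs": {"token_type_ids": [0, 1, 1, 0]}}
print(contains_mm_data(text_only))   # False
print(contains_mm_data(with_image))  # True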
@@ -146,6 +165,16 @@ class EngineClient:

        else:
            self.data_processor.process_request_dict(task, self.max_model_len)

        if self.enable_mm and self.enable_prefix_caching:
            if self._check_mm_disable_prefix_cache(task):
                api_server_logger.error(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
                )
                raise EngineError(
                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
                    error_code=400,
                )

        task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
        input_ids_len = task["prompt_token_ids_len"]
        task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
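The net effect of the guard added to add_requests, shown as a simplified stand-in (plain Python, not the real EngineClient or EngineError): on a listed model with prefix caching on, a request that carries multimodal tokens is rejected up front with a 400-style error, while text-only requests go on to the usual max_tokens bookkeeping.

class MMPrefixCacheGuard:
    """Simplified stand-in for the check performed in EngineClient.add_requests."""

    def __init__(self, enable_mm, enable_prefix_caching, disable_prefix_mm):
        self.enable_mm = enable_mm
        self.enable_prefix_caching = enable_prefix_caching
        self.disable_prefix_mm = disable_prefix_mm

    def validate(self, request_has_mm_data):
        if self.enable_mm and self.enable_prefix_caching and self.disable_prefix_mm and request_has_mm_data:
            # The real code logs via api_server_logger and raises
            # EngineError(..., error_code=400) rather than ValueError.
            raise ValueError("multimodal data is not supported while prefix cache is enabled (400)")

guard = MMPrefixCacheGuard(enable_mm=True, enable_prefix_caching=True, disable_prefix_mm=True)
guard.validate(request_has_mm_data=False)   # text-only request: accepted
# guard.validate(request_has_mm_data=True)  # would raise: request contains image/video tokens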
@@ -170,6 +170,7 @@ async def lifespan(app: FastAPI):

        enable_logprob=args.enable_logprob,
        workers=args.workers,
        tool_parser=args.tool_call_parser,
        enable_prefix_caching=args.enable_prefix_caching,
    )
    await engine_client.connection_manager.initialize()
    app.state.dynamic_load_weight = args.dynamic_load_weight