[Feature] Support disabling prefix cache for multimodal models (#4502)

* support disabling prefix cache for multimodal models

* add

* fix

* fix

* fix

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
ltd0924
2025-10-21 09:56:47 +08:00
committed by GitHub
parent 9558912475
commit 3cd9d3060a
3 changed files with 41 additions and 0 deletions


@@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
    "Ernie5ForCausalLM",
}


def is_mm_model_disable_prefix_cache(model_arch):
    """
    Check whether the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL.
    """
    return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL
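For reference, a minimal sketch of how this helper behaves; `SomeTextOnlyModel` is a hypothetical architecture name used only for illustration:

from fastdeploy.cache_manager.cache_data import is_mm_model_disable_prefix_cache

# Ernie5ForCausalLM is listed in DISABLE_PREFIX_CACHE_MM_MODEL above.
assert is_mm_model_disable_prefix_cache("Ernie5ForCausalLM") is True
# Hypothetical architecture name, not in the disable set.
assert is_mm_model_disable_prefix_cache("SomeTextOnlyModel") is False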
class CacheStatus(Enum):
    """


@@ -55,15 +55,24 @@ class EngineClient:
        enable_logprob=False,
        workers=1,
        tool_parser=None,
        enable_prefix_caching=None,
    ):
        import fastdeploy.model_executor.models  # noqa: F401

        architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
        self.enable_prefix_caching = enable_prefix_caching
        if MultimodalRegistry.contains_model(architectures):
            self.enable_mm = True
        else:
            self.enable_mm = False

        if self.enable_mm and self.enable_prefix_caching:
            from fastdeploy.cache_manager.cache_data import (
                is_mm_model_disable_prefix_cache,
            )

            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures)

        input_processor = InputPreprocessor(
            tokenizer,
            reasoning_parser,
@@ -127,6 +136,16 @@ class EngineClient:
            await self.add_requests(prompts)
            return prompts["prompt_token_ids"]

    def _check_mm_disable_prefix_cache(self, task):
        """
        Return True if prefix caching is disabled for this multimodal model
        and the task actually contains multimodal data.
        """
        is_multimodal_data = False
        if self.disable_prefix_mm:
            multimodal_inputs = task.get("multimodal_inputs", [])
            if multimodal_inputs:
                token_type_ids = multimodal_inputs.get("token_type_ids", [])
                if token_type_ids:
                    is_multimodal_data = np.sum(token_type_ids) > 0
        return is_multimodal_data
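To illustrate the heuristic, assuming the usual convention that `token_type_ids` marks text positions with 0 and multimodal (e.g. image placeholder) positions with nonzero values: a positive sum means the request carries multimodal data. A minimal sketch with hypothetical payloads:

import numpy as np

# Hypothetical payloads: 0 marks a text token, nonzero marks a
# multimodal (e.g. image placeholder) token.
text_only = {"multimodal_inputs": {"token_type_ids": [0, 0, 0, 0]}}
with_image = {"multimodal_inputs": {"token_type_ids": [0, 1, 1, 0]}}

print(np.sum(text_only["multimodal_inputs"]["token_type_ids"]) > 0)   # False
print(np.sum(with_image["multimodal_inputs"]["token_type_ids"]) > 0)  # True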
    async def add_requests(self, task):
        """
        Add a new request to the queue.
@@ -146,6 +165,16 @@ class EngineClient:
            else:
                self.data_processor.process_request_dict(task, self.max_model_len)

            if self.enable_mm and self.enable_prefix_caching:
                if self._check_mm_disable_prefix_cache(task):
                    api_server_logger.error(
                        "The current service does not support requests containing multimodal data "
                        "while the prefix cache is enabled. Please send text-only requests or disable the prefix cache."
                    )
                    raise EngineError(
                        "The current service does not support requests containing multimodal data "
                        "while the prefix cache is enabled. Please send text-only requests or disable the prefix cache.",
                        error_code=400,
                    )
task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
input_ids_len = task["prompt_token_ids_len"] input_ids_len = task["prompt_token_ids_len"]
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))


@@ -170,6 +170,7 @@ async def lifespan(app: FastAPI):
        enable_logprob=args.enable_logprob,
        workers=args.workers,
        tool_parser=args.tool_call_parser,
        enable_prefix_caching=args.enable_prefix_caching,
    )
    await engine_client.connection_manager.initialize()
    app.state.dynamic_load_weight = args.dynamic_load_weight