Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-26 01:50:33 +08:00)
[Feature] Support mm model close prefix cache (#4502)

* support mm prefix cache close
* add
* fix
* fix
* fix

Co-authored-by: ltd0924 <luotingdan@baidu.com>
@@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger
 
 logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
 
+DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
+    "Ernie5ForCausalLM",
+}
+
+
+def is_mm_model_disable_prefix_cache(model_arch):
+    """
+    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
+    """
+    return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL
+
 
 class CacheStatus(Enum):
     """
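For quick reference, a standalone sketch of the helper added above; it reproduces DISABLE_PREFIX_CACHE_MM_MODEL and is_mm_model_disable_prefix_cache outside the package, with one hypothetical architecture name used only to show the negative case:

    # Standalone sketch of the denylist check introduced in this hunk.
    DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
        "Ernie5ForCausalLM",
    }

    def is_mm_model_disable_prefix_cache(model_arch):
        # True only for architectures whose multimodal inputs cannot reuse the prefix cache.
        return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL

    print(is_mm_model_disable_prefix_cache("Ernie5ForCausalLM"))  # True
    print(is_mm_model_disable_prefix_cache("SomeOtherVLModel"))   # False (hypothetical name)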
@@ -55,15 +55,24 @@ class EngineClient:
         enable_logprob=False,
         workers=1,
         tool_parser=None,
+        enable_prefix_caching=None,
     ):
         import fastdeploy.model_executor.models  # noqa: F401
 
         architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
+        self.enable_prefix_caching = enable_prefix_caching
         if MultimodalRegistry.contains_model(architectures):
             self.enable_mm = True
         else:
             self.enable_mm = False
 
+        if self.enable_mm and self.enable_prefix_caching:
+            from fastdeploy.cache_manager.cache_data import (
+                is_mm_model_disable_prefix_cache,
+            )
+
+            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures)
+
         input_processor = InputPreprocessor(
             tokenizer,
             reasoning_parser,
@@ -127,6 +136,16 @@ class EngineClient:
         await self.add_requests(prompts)
         return prompts["prompt_token_ids"]
 
+    def _check_mm_disable_prefix_cache(self, task):
+        is_multimodal_data = False
+        if self.disable_prefix_mm:
+            multimodal_inputs = task.get("multimodal_inputs", [])
+            if multimodal_inputs:
+                token_type_ids = multimodal_inputs.get("token_type_ids", [])
+                if token_type_ids:
+                    is_multimodal_data = np.sum(token_type_ids) > 0
+        return is_multimodal_data
+
     async def add_requests(self, task):
         """
         Add a new request to the queue.
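A rough, self-contained sketch of what _check_mm_disable_prefix_cache evaluates at request time. It assumes, as the hunk above appears to, that non-zero entries in token_type_ids mark multimodal (image/video) token positions; the task dicts and the free-standing function name are illustrative only:

    import numpy as np

    def check_mm_disable_prefix_cache(task, disable_prefix_mm):
        # Free-standing version of EngineClient._check_mm_disable_prefix_cache above.
        is_multimodal_data = False
        if disable_prefix_mm:
            multimodal_inputs = task.get("multimodal_inputs", [])
            if multimodal_inputs:
                token_type_ids = multimodal_inputs.get("token_type_ids", [])
                if token_type_ids:
                    is_multimodal_data = np.sum(token_type_ids) > 0
        return is_multimodal_data

    text_only = {"multimodal_inputs": {"token_type_ids": [0, 0, 0, 0]}}   # all text tokens
    with_image = {"multimodal_inputs": {"token_type_ids": [0, 1, 1, 0]}}  # non-zero = mm tokens (assumed)

    print(check_mm_disable_prefix_cache(text_only, disable_prefix_mm=True))   # False
    print(check_mm_disable_prefix_cache(with_image, disable_prefix_mm=True))  # True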
@@ -146,6 +165,16 @@ class EngineClient:
         else:
             self.data_processor.process_request_dict(task, self.max_model_len)
 
+        if self.enable_mm and self.enable_prefix_caching:
+            if self._check_mm_disable_prefix_cache(task):
+                api_server_logger.error(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
+                )
+                raise EngineError(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
+                    error_code=400,
+                )
+
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
         task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
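The effect for callers, sketched with a stand-in exception class; the real EngineError is FastDeploy's own, and only the message wording and error_code=400 usage are taken from the hunk above. The add_request function name and flat flags here are hypothetical simplifications:

    class EngineError(Exception):
        # Stand-in for FastDeploy's EngineError, reduced to the fields used above.
        def __init__(self, message, error_code):
            super().__init__(message)
            self.error_code = error_code

    def add_request(task, enable_mm, enable_prefix_caching, has_mm_data):
        # Mirrors the guard added in this hunk: reject multimodal requests when
        # the model cannot combine multimodal inputs with the prefix cache.
        if enable_mm and enable_prefix_caching and has_mm_data:
            raise EngineError(
                "The current service does not support processing requests containing "
                "multimodal data when prefix cache is enabled. Please send only "
                "text-based requests or disable prefix cache",
                error_code=400,
            )

    try:
        add_request({}, enable_mm=True, enable_prefix_caching=True, has_mm_data=True)
    except EngineError as e:
        print(e.error_code)  # 400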
@@ -170,6 +170,7 @@ async def lifespan(app: FastAPI):
         enable_logprob=args.enable_logprob,
         workers=args.workers,
         tool_parser=args.tool_call_parser,
+        enable_prefix_caching=args.enable_prefix_caching,
     )
     await engine_client.connection_manager.initialize()
     app.state.dynamic_load_weight = args.dynamic_load_weight
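A minimal sketch of how the new constructor argument is expected to flow from the command line into EngineClient, assuming the API server's argument parser maps enable_prefix_caching to an --enable-prefix-caching flag; that flag spelling and the store_true behavior are assumptions, only args.enable_prefix_caching appears in the hunk above:

    import argparse

    parser = argparse.ArgumentParser()
    # Assumed flag; absent -> None, present -> True.
    parser.add_argument("--enable-prefix-caching", action="store_true", default=None)
    args = parser.parse_args(["--enable-prefix-caching"])

    # The lifespan hook then forwards the parsed value into the client:
    #     EngineClient(..., enable_prefix_caching=args.enable_prefix_caching)
    print(args.enable_prefix_caching)  # True when the flag is passed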