diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py
index 638da70bc..9a39b5211 100644
--- a/fastdeploy/cache_manager/cache_data.py
+++ b/fastdeploy/cache_manager/cache_data.py
@@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger
 
 logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
 
+DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
+    "Ernie5ForCausalLM",
+}
+
+
+def is_mm_model_disable_prefix_cache(model_arch):
+    """
+    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
+    """
+    return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL
+
 
 class CacheStatus(Enum):
     """
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 4a2efc3f6..13138dd8b 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -55,15 +55,24 @@ class EngineClient:
         enable_logprob=False,
         workers=1,
         tool_parser=None,
+        enable_prefix_caching=None,
     ):
         import fastdeploy.model_executor.models  # noqa: F401
 
         architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
+        self.enable_prefix_caching = enable_prefix_caching
         if MultimodalRegistry.contains_model(architectures):
             self.enable_mm = True
         else:
             self.enable_mm = False
 
+        if self.enable_mm and self.enable_prefix_caching:
+            from fastdeploy.cache_manager.cache_data import (
+                is_mm_model_disable_prefix_cache,
+            )
+
+            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures)
+
         input_processor = InputPreprocessor(
             tokenizer,
             reasoning_parser,
@@ -127,6 +136,16 @@ class EngineClient:
         await self.add_requests(prompts)
         return prompts["prompt_token_ids"]
 
+    def _check_mm_disable_prefix_cache(self, task):
+        is_multimodal_data = False
+        if self.disable_prefix_mm:
+            multimodal_inputs = task.get("multimodal_inputs", [])
+            if multimodal_inputs:
+                token_type_ids = multimodal_inputs.get("token_type_ids", [])
+                if token_type_ids:
+                    is_multimodal_data = np.sum(token_type_ids) > 0
+        return is_multimodal_data
+
     async def add_requests(self, task):
         """
         Add a new request to the queue.
@@ -146,6 +165,16 @@ class EngineClient:
         else:
             self.data_processor.process_request_dict(task, self.max_model_len)
 
+        if self.enable_mm and self.enable_prefix_caching:
+            if self._check_mm_disable_prefix_cache(task):
+                api_server_logger.error(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
+                )
+                raise EngineError(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
+                    error_code=400,
+                )
+
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
         task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 48e02ab46..1d6dc65af 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -170,6 +170,7 @@ async def lifespan(app: FastAPI):
         enable_logprob=args.enable_logprob,
         workers=args.workers,
         tool_parser=args.tool_call_parser,
+        enable_prefix_caching=args.enable_prefix_caching,
     )
     await engine_client.connection_manager.initialize()
     app.state.dynamic_load_weight = args.dynamic_load_weight
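
The first hunk introduces a module-level denylist and a membership helper. A minimal sketch of the lookup it adds (the second assertion uses a made-up architecture name purely for contrast):

```python
# Mirrors the helper added in fastdeploy/cache_manager/cache_data.py:
# prefix caching is refused for multimodal requests on listed architectures.
DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
    "Ernie5ForCausalLM",
}


def is_mm_model_disable_prefix_cache(model_arch: str) -> bool:
    """Return True if model_arch is in DISABLE_PREFIX_CACHE_MM_MODEL."""
    return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL


assert is_mm_model_disable_prefix_cache("Ernie5ForCausalLM")
assert not is_mm_model_disable_prefix_cache("SomeOtherVLModel")  # hypothetical name
```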
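In `_check_mm_disable_prefix_cache`, a request counts as multimodal when the sum of its `token_type_ids` is positive; the working assumption (not spelled out in the patch) is that non-text tokens carry nonzero type ids, so any positive sum means at least one image/video token is present. A self-contained sketch of that check, with `has_multimodal_tokens` as a stand-in name:

```python
import numpy as np


def has_multimodal_tokens(task: dict) -> bool:
    # Assumption: nonzero entries in token_type_ids mark non-text tokens.
    # The patched method guards each lookup the same way before summing.
    multimodal_inputs = task.get("multimodal_inputs") or {}
    token_type_ids = multimodal_inputs.get("token_type_ids") or []
    return bool(np.sum(token_type_ids) > 0)


text_only = {"multimodal_inputs": {"token_type_ids": [0, 0, 0, 0]}}
with_image = {"multimodal_inputs": {"token_type_ids": [0, 1, 1, 0]}}
assert has_multimodal_tokens(with_image) and not has_multimodal_tokens(text_only)
```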
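Taken together, the three files wire `enable_prefix_caching` from the API server into `EngineClient`, compute `disable_prefix_mm` once at construction, and reject offending requests with a 400 before token budgeting. A rough, simplified sketch of that flow (reusing the two helpers above; `EngineError` stubbed in, and the real `add_requests` is async):

```python
class EngineError(Exception):
    """Stand-in for FastDeploy's EngineError, which also carries an error code."""

    def __init__(self, message: str, error_code: int = 500):
        super().__init__(message)
        self.error_code = error_code


class EngineClientSketch:
    def __init__(self, architectures: str, enable_mm: bool, enable_prefix_caching: bool):
        self.enable_mm = enable_mm
        self.enable_prefix_caching = enable_prefix_caching
        self.disable_prefix_mm = False
        if self.enable_mm and self.enable_prefix_caching:
            # Same gate as the EngineClient constructor in the patch.
            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures)

    def add_requests(self, task: dict) -> None:
        if self.enable_mm and self.enable_prefix_caching and self.disable_prefix_mm:
            if has_multimodal_tokens(task):
                # The patch logs the same message and raises with code 400.
                raise EngineError(
                    "multimodal input not supported with prefix cache enabled",
                    error_code=400,
                )
        # ...normal request scheduling would continue here...


client = EngineClientSketch("Ernie5ForCausalLM", enable_mm=True, enable_prefix_caching=True)
try:
    client.add_requests({"multimodal_inputs": {"token_type_ids": [0, 1]}})
except EngineError as e:
    assert e.error_code == 400
```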