mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-27 02:20:31 +08:00 
			
		
		
		
	[Fearture] Support mm model close prefix cache (#4502)
* support mm prefix cache close * add * fix * fix * fix --------- Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
		| @@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger | ||||
|  | ||||
| logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") | ||||
|  | ||||
| DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = { | ||||
|     "Ernie5ForCausalLM", | ||||
| } | ||||
|  | ||||
|  | ||||
| def is_mm_model_disable_prefix_cache(model_arch): | ||||
|     """ | ||||
|     check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL | ||||
|     """ | ||||
|     return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL | ||||
|  | ||||
|  | ||||
| class CacheStatus(Enum): | ||||
|     """ | ||||
|   | ||||
| @@ -55,15 +55,24 @@ class EngineClient: | ||||
|         enable_logprob=False, | ||||
|         workers=1, | ||||
|         tool_parser=None, | ||||
|         enable_prefix_caching=None, | ||||
|     ): | ||||
|         import fastdeploy.model_executor.models  # noqa: F401 | ||||
|  | ||||
|         architectures = ModelConfig({"model": model_name_or_path}).architectures[0] | ||||
|         self.enable_prefix_caching = enable_prefix_caching | ||||
|         if MultimodalRegistry.contains_model(architectures): | ||||
|             self.enable_mm = True | ||||
|         else: | ||||
|             self.enable_mm = False | ||||
|  | ||||
|         if self.enable_mm and self.enable_prefix_caching: | ||||
|             from fastdeploy.cache_manager.cache_data import ( | ||||
|                 is_mm_model_disable_prefix_cache, | ||||
|             ) | ||||
|  | ||||
|             self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures) | ||||
|  | ||||
|         input_processor = InputPreprocessor( | ||||
|             tokenizer, | ||||
|             reasoning_parser, | ||||
| @@ -127,6 +136,16 @@ class EngineClient: | ||||
|         await self.add_requests(prompts) | ||||
|         return prompts["prompt_token_ids"] | ||||
|  | ||||
|     def _check_mm_disable_prefix_cache(self, task): | ||||
|         is_multimodal_data = False | ||||
|         if self.disable_prefix_mm: | ||||
|             multimodal_inputs = task.get("multimodal_inputs", []) | ||||
|             if multimodal_inputs: | ||||
|                 token_type_ids = multimodal_inputs.get("token_type_ids", []) | ||||
|                 if token_type_ids: | ||||
|                     is_multimodal_data = np.sum(token_type_ids) > 0 | ||||
|         return is_multimodal_data | ||||
|  | ||||
|     async def add_requests(self, task): | ||||
|         """ | ||||
|         Add a new request to the queue. | ||||
| @@ -146,6 +165,16 @@ class EngineClient: | ||||
|             else: | ||||
|                 self.data_processor.process_request_dict(task, self.max_model_len) | ||||
|  | ||||
|             if self.enable_mm and self.enable_prefix_caching: | ||||
|                 if self._check_mm_disable_prefix_cache(task): | ||||
|                     api_server_logger.error( | ||||
|                         "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" | ||||
|                     ) | ||||
|                     raise EngineError( | ||||
|                         "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache", | ||||
|                         error_code=400, | ||||
|                     ) | ||||
|  | ||||
|             task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) | ||||
|             input_ids_len = task["prompt_token_ids_len"] | ||||
|             task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) | ||||
|   | ||||
| @@ -170,6 +170,7 @@ async def lifespan(app: FastAPI): | ||||
|         enable_logprob=args.enable_logprob, | ||||
|         workers=args.workers, | ||||
|         tool_parser=args.tool_call_parser, | ||||
|         enable_prefix_caching=args.enable_prefix_caching, | ||||
|     ) | ||||
|     await engine_client.connection_manager.initialize() | ||||
|     app.state.dynamic_load_weight = args.dynamic_load_weight | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 ltd0924
					ltd0924