mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-26 01:50:33 +08:00 
			
		
		
		
	[Feature] Support mm model close prefix cache (#4502)
* support mm prefix cache close * add * fix * fix * fix --------- Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
		| @@ -20,6 +20,17 @@ from fastdeploy.utils import get_logger | |||||||
|  |  | ||||||
| logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") | logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") | ||||||
|  |  | ||||||
|  | DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = { | ||||||
|  |     "Ernie5ForCausalLM", | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def is_mm_model_disable_prefix_cache(model_arch): | ||||||
|  |     """ | ||||||
|  |     check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL | ||||||
|  |     """ | ||||||
|  |     return model_arch in DISABLE_PREFIX_CACHE_MM_MODEL | ||||||
|  |  | ||||||
|  |  | ||||||
| class CacheStatus(Enum): | class CacheStatus(Enum): | ||||||
|     """ |     """ | ||||||
|   | |||||||
| @@ -55,15 +55,24 @@ class EngineClient: | |||||||
|         enable_logprob=False, |         enable_logprob=False, | ||||||
|         workers=1, |         workers=1, | ||||||
|         tool_parser=None, |         tool_parser=None, | ||||||
|  |         enable_prefix_caching=None, | ||||||
|     ): |     ): | ||||||
|         import fastdeploy.model_executor.models  # noqa: F401 |         import fastdeploy.model_executor.models  # noqa: F401 | ||||||
|  |  | ||||||
|         architectures = ModelConfig({"model": model_name_or_path}).architectures[0] |         architectures = ModelConfig({"model": model_name_or_path}).architectures[0] | ||||||
|  |         self.enable_prefix_caching = enable_prefix_caching | ||||||
|         if MultimodalRegistry.contains_model(architectures): |         if MultimodalRegistry.contains_model(architectures): | ||||||
|             self.enable_mm = True |             self.enable_mm = True | ||||||
|         else: |         else: | ||||||
|             self.enable_mm = False |             self.enable_mm = False | ||||||
|  |  | ||||||
|  |         if self.enable_mm and self.enable_prefix_caching: | ||||||
|  |             from fastdeploy.cache_manager.cache_data import ( | ||||||
|  |                 is_mm_model_disable_prefix_cache, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             self.disable_prefix_mm = is_mm_model_disable_prefix_cache(architectures) | ||||||
|  |  | ||||||
|         input_processor = InputPreprocessor( |         input_processor = InputPreprocessor( | ||||||
|             tokenizer, |             tokenizer, | ||||||
|             reasoning_parser, |             reasoning_parser, | ||||||
| @@ -127,6 +136,16 @@ class EngineClient: | |||||||
|         await self.add_requests(prompts) |         await self.add_requests(prompts) | ||||||
|         return prompts["prompt_token_ids"] |         return prompts["prompt_token_ids"] | ||||||
|  |  | ||||||
|  |     def _check_mm_disable_prefix_cache(self, task): | ||||||
|  |         is_multimodal_data = False | ||||||
|  |         if self.disable_prefix_mm: | ||||||
|  |             multimodal_inputs = task.get("multimodal_inputs", []) | ||||||
|  |             if multimodal_inputs: | ||||||
|  |                 token_type_ids = multimodal_inputs.get("token_type_ids", []) | ||||||
|  |                 if token_type_ids: | ||||||
|  |                     is_multimodal_data = np.sum(token_type_ids) > 0 | ||||||
|  |         return is_multimodal_data | ||||||
|  |  | ||||||
|     async def add_requests(self, task): |     async def add_requests(self, task): | ||||||
|         """ |         """ | ||||||
|         Add a new request to the queue. |         Add a new request to the queue. | ||||||
| @@ -146,6 +165,16 @@ class EngineClient: | |||||||
|             else: |             else: | ||||||
|                 self.data_processor.process_request_dict(task, self.max_model_len) |                 self.data_processor.process_request_dict(task, self.max_model_len) | ||||||
|  |  | ||||||
|  |             if self.enable_mm and self.enable_prefix_caching: | ||||||
|  |                 if self._check_mm_disable_prefix_cache(task): | ||||||
|  |                     api_server_logger.error( | ||||||
|  |                         "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" | ||||||
|  |                     ) | ||||||
|  |                     raise EngineError( | ||||||
|  |                         "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache", | ||||||
|  |                         error_code=400, | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|             task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) |             task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) | ||||||
|             input_ids_len = task["prompt_token_ids_len"] |             input_ids_len = task["prompt_token_ids_len"] | ||||||
|             task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) |             task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) | ||||||
|   | |||||||
| @@ -170,6 +170,7 @@ async def lifespan(app: FastAPI): | |||||||
|         enable_logprob=args.enable_logprob, |         enable_logprob=args.enable_logprob, | ||||||
|         workers=args.workers, |         workers=args.workers, | ||||||
|         tool_parser=args.tool_call_parser, |         tool_parser=args.tool_call_parser, | ||||||
|  |         enable_prefix_caching=args.enable_prefix_caching, | ||||||
|     ) |     ) | ||||||
|     await engine_client.connection_manager.initialize() |     await engine_client.connection_manager.initialize() | ||||||
|     app.state.dynamic_load_weight = args.dynamic_load_weight |     app.state.dynamic_load_weight = args.dynamic_load_weight | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 ltd0924
					ltd0924