diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index fd6a241ae..a98e70a97 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1694,9 +1694,9 @@ class FDConfig:
             logger.info(
                 "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
             )
-        if self.device_config is not None and self.device_config.device_type != "cuda":
+        if not current_platform.is_cuda():
             self.graph_opt_config.use_cudagraph = False
-            logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
+            logger.info("CUDAGraph is currently only supported on GPU!")
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
             if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
                 self.parallel_config.use_sequence_parallel_moe = False
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index b0a50f94a..cf9c464f5 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -1233,10 +1233,6 @@ class EngineArgs:
         all_dict = asdict(self)
         model_cfg = ModelConfig(all_dict)
 
-        # XPU currently disable prefix cache for VL model
-        if current_platform.is_xpu() and (self.enable_mm or model_cfg.enable_mm):
-            self.enable_prefix_caching = False
-
         if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
             self.tensor_parallel_size = model_cfg.tensor_parallel_size
 
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index c5ade2af7..182ad701f 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -108,6 +108,10 @@ class XPUModelRunner(ModelRunnerBase):
             "matmul_v2",
             "fused_gemm_epilogue",
         ]
+        if self.cache_config.max_encoder_cache > 0:
+            self.encoder_cache: dict[str, paddle.Tensor] = {}
+        else:
+            self.encoder_cache = None
 
         self.device_id = device_id
         self.speculative_method = self.fd_config.speculative_config.method
@@ -161,6 +165,183 @@ class XPUModelRunner(ModelRunnerBase):
         else:
             return 0
 
+    def get_chunked_inputs(self, req: Request):
+        """
+        Get inputs in current chunk
+        """
+        prefill_start_index = req.prefill_start_index
+        prefill_end_index = req.prefill_end_index
+        inputs = req.multimodal_inputs
+        input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
+        token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
+        image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
+        images = inputs["images"][req.image_start : req.image_end]
+        grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
+        mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
+
+        return (
+            input_ids,
+            token_type_ids,
+            image_type_ids,
+            images,
+            grid_thw,
+            mm_hashes,
+        )
+
+    def batch_uncached_inputs(self, req: Request):
+        """
+        Batch uncached multimodal inputs
+        """
+        (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
+
+        image_type_ids_size = grid_thw[:, 0]
+        image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
+        image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
+
+        images_size = np.prod(grid_thw, axis=1)
+        images_split = np.cumsum(images_size)[:-1]
+        images_lst = np.array_split(images, images_split, axis=0)
+
+        assert len(image_type_ids_lst) == len(
+            mm_hashes
+        ), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
+        assert len(images_lst) == len(
+            mm_hashes
+        ), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
+
+        uncached_image_type_ids = []
+        uncached_images = []
+        uncached_grid_thw = []
+        uncached_mm_hashes = []
+        for i, mm_hash in enumerate(mm_hashes):
+            if mm_hash in self.encoder_cache:
+                continue
+            uncached_image_type_ids.append(image_type_ids_lst[i])
+            uncached_images.append(images_lst[i])
+            uncached_grid_thw.append(grid_thw[i])
+            uncached_mm_hashes.append(mm_hash)
+
+        uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
+        uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
+        if len(uncached_mm_hashes) > 0:
+            uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
+            uncached_images = paddle.to_tensor(
+                np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
+            )
+            uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
+
+        return (
+            uncached_input_ids,
+            uncached_token_type_ids,
+            uncached_image_type_ids,
+            uncached_images,
+            uncached_grid_thw,
+            uncached_mm_hashes,
+        )
+
+    def scatter_and_cache_features(self, image_features, inputs):
+        """
+        Split batched image features and cache them
+        """
+        merge_size = 2
+        grid_thw = inputs["grid_thw"]
+        mm_hashes = inputs["mm_hashes"]
+        image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
+        image_features_lst = paddle.split(image_features, image_features_size, axis=0)
+
+        assert len(image_features_lst) == len(
+            mm_hashes
+        ), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
+        for i, mm_hash in enumerate(mm_hashes):
+            self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
+
+    def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
+        """
+        Apply multimodal inputs to share_inputs
+        - add image_features, extract and cache vision features from model
+        - add rope_emb, rotate position embeddings
+        """
+        if self.encoder_cache:
+            evict_mm_hashes = request.get("evict_mm_hashes", None)
+            if evict_mm_hashes:
+                for mm_hash in evict_mm_hashes:
+                    self.encoder_cache.pop(mm_hash, None)
+
+        inputs = request.multimodal_inputs
+        if request.with_image:
+            if envs.FD_ENABLE_MAX_PREFILL:
+                multi_vision_inputs["images_lst"].append(
+                    inputs["images"][request.image_start : request.image_end].cuda()
+                )
+                multi_vision_inputs["grid_thw_lst"].extend(
+                    inputs["grid_thw"][request.num_image_start : request.num_image_end]
+                )
+                multi_vision_inputs["cu_seqlens"].extend(
+                    inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
+                )
+                multi_vision_inputs["vit_position_ids_lst"].extend(
+                    inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
+                )
+            else:
+                vision_inputs = inputs
+                if self.encoder_cache:
+                    (
+                        vision_inputs["input_ids"],
+                        vision_inputs["token_type_ids"],
+                        vision_inputs["image_type_ids"],
+                        vision_inputs["images"],
+                        vision_inputs["grid_thw"],
+                        vision_inputs["mm_hashes"],
+                    ) = self.batch_uncached_inputs(request)
+                    if len(vision_inputs["mm_hashes"]) > 0:
+                        # uncached multimodal inputs exist
+                        image_features = self.extract_vision_features(vision_inputs)
+                        self.scatter_and_cache_features(image_features, vision_inputs)
+
+                    full_image_features_lst = []
+                    for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
+                        feature = self.encoder_cache[mm_hash].cuda()
+                        full_image_features_lst.append(feature)
+                    image_features = paddle.concat(full_image_features_lst, axis=0)
+                else:
+                    (
+                        input_ids,
+                        token_type_ids,
+                        image_type_ids,
+                        images,
+                        grid_thw,
+                        mm_hashes,
+                    ) = self.get_chunked_inputs(request)
+                    vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
+                    vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
+                    vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
+                    vision_inputs["images"] = paddle.to_tensor(
+                        images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
+                    )
+                    vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
+                    vision_inputs["mm_hashes"] = mm_hashes
+
+                    image_features = self.extract_vision_features(vision_inputs)
+
+                # part of the first image may be already cached
+                if "ernie" in self.model_config.model_type:
+                    actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
+                elif "qwen" in self.model_config.model_type:
+                    actual_image_token_num = paddle.sum(
+                        vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
+                    ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
+                else:
+                    raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
+                self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
+
+        position_ids = request.multimodal_inputs["position_ids"]
+        rope_3d_position_ids["position_ids_idx"].append(request.idx)
+        rope_3d_position_ids["position_ids_lst"].append(position_ids)
+        rope_3d_position_ids["position_ids_offset"].append(
+            position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
+        )
+        rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
+
     def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -189,51 +370,7 @@ class XPUModelRunner(ModelRunnerBase):
             prefill_end_index = request.prefill_end_index
             length = prefill_end_index - prefill_start_index
             if self.enable_mm:
-                inputs = request.multimodal_inputs
-                if request.with_image:
-                    if envs.FD_ENABLE_MAX_PREFILL:
-                        multi_vision_inputs["images_lst"].append(
-                            paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
-                        )
-                        multi_vision_inputs["grid_thw_lst"].extend(
-                            inputs["grid_thw"][request.num_image_start : request.num_image_end]
-                        )
-                        multi_vision_inputs["cu_seqlens"].extend(
-                            inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
-                        )
-                        multi_vision_inputs["vit_position_ids_lst"].extend(
-                            inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
-                        )
-                    else:
-                        vision_inputs = {}
-                        vision_inputs["input_ids"] = paddle.to_tensor(
-                            inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
-                        )
-                        vision_inputs["token_type_ids"] = paddle.to_tensor(
-                            inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
-                        )
-                        vision_inputs["image_type_ids"] = paddle.to_tensor(
-                            inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
-                            dtype=paddle.int64,
-                        )
-                        vision_inputs["images"] = paddle.to_tensor(
-                            inputs["images"][request.image_start : request.image_end],
-                            dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
-                        )
-                        vision_inputs["grid_thw"] = paddle.to_tensor(
-                            inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
-                        )
-                        self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
-                else:
-                    self.share_inputs["image_features"] = None
-
-                position_ids = request.multimodal_inputs["position_ids"]
-                rope_3d_position_ids["position_ids_idx"].append(idx)
-                rope_3d_position_ids["position_ids_lst"].append(position_ids)
-                rope_3d_position_ids["position_ids_offset"].append(
-                    position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
-                )
-                rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
+                self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
 
             if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
                 # Enable thinking
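
The xpu_model_runner.py changes above implement a content-addressed encoder cache: vision features are stored per mm_hash, only uncached images are pushed through the vision encoder, and the full feature sequence is rebuilt from the cache during chunked prefill. Below is a minimal standalone sketch of that pattern (hypothetical names such as EncoderCache and encode_images; not FastDeploy APIs), mirroring what batch_uncached_inputs, scatter_and_cache_features, and _apply_mm_inputs do together:

```python
# Sketch only: illustrates the mm_hash-keyed feature cache, not the real runner.
from typing import Dict, List

import numpy as np


def encode_images(images: List[np.ndarray]) -> List[np.ndarray]:
    # Placeholder for the real vision encoder; one feature row per image here.
    return [img.mean(axis=0, keepdims=True) for img in images]


class EncoderCache:
    def __init__(self) -> None:
        self._cache: Dict[str, np.ndarray] = {}

    def evict(self, mm_hashes: List[str]) -> None:
        # Analogous to dropping entries listed in evict_mm_hashes.
        for h in mm_hashes:
            self._cache.pop(h, None)

    def get_features(self, mm_hashes: List[str], images: List[np.ndarray]) -> np.ndarray:
        # 1) Encode only the images whose hash is not cached yet.
        uncached = [(h, img) for h, img in zip(mm_hashes, images) if h not in self._cache]
        if uncached:
            feats = encode_images([img for _, img in uncached])
            for (h, _), feat in zip(uncached, feats):
                self._cache[h] = feat
        # 2) Reassemble the full feature sequence in request order from the cache.
        return np.concatenate([self._cache[h] for h in mm_hashes], axis=0)


cache = EncoderCache()
imgs = [np.ones((4, 8)), np.zeros((4, 8))]
out1 = cache.get_features(["img-a", "img-b"], imgs)
out2 = cache.get_features(["img-a", "img-b"], imgs)  # second call is a pure cache hit
assert np.array_equal(out1, out2)
```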