mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[XPU] xpu support mm prefix cache (#5356)
Co-authored-by: ddchenhao66 <dhaochen163.com>
@@ -1694,9 +1694,9 @@ class FDConfig:
             logger.info(
                 "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
             )
-        if self.device_config is not None and self.device_config.device_type != "cuda":
+        if not current_platform.is_cuda():
             self.graph_opt_config.use_cudagraph = False
-            logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
+            logger.info("CUDAGraph currently only support on GPU!")
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
             if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
                 self.parallel_config.use_sequence_parallel_moe = False
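The new guard keys off the platform helper instead of the device_config string, so any non-CUDA backend (XPU included) simply runs without CUDAGraph. A minimal standalone sketch of that fallback, assuming only that the platform object exposes is_cuda()/is_xpu() as in the diff; SimplePlatform and GraphOptConfig below are illustrative stand-ins, not FastDeploy classes:

# Minimal sketch of the platform-gated CUDAGraph fallback.
from dataclasses import dataclass

@dataclass
class SimplePlatform:
    device: str
    def is_cuda(self) -> bool:
        return self.device == "cuda"
    def is_xpu(self) -> bool:
        return self.device == "xpu"

@dataclass
class GraphOptConfig:
    use_cudagraph: bool = True

def apply_platform_fallback(platform: SimplePlatform, cfg: GraphOptConfig) -> GraphOptConfig:
    # Mirrors the new check: anything that is not CUDA loses CUDAGraph.
    if not platform.is_cuda():
        cfg.use_cudagraph = False
    return cfg

print(apply_platform_fallback(SimplePlatform("xpu"), GraphOptConfig()).use_cudagraph)   # False
print(apply_platform_fallback(SimplePlatform("cuda"), GraphOptConfig()).use_cudagraph)  # True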
@@ -1233,10 +1233,6 @@ class EngineArgs:
         all_dict = asdict(self)
         model_cfg = ModelConfig(all_dict)
 
-        # XPU currently disable prefix cache for VL model
-        if current_platform.is_xpu() and (self.enable_mm or model_cfg.enable_mm):
-            self.enable_prefix_caching = False
-
         if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
             self.tensor_parallel_size = model_cfg.tensor_parallel_size
 
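With this guard removed, multimodal (VL) models on XPU no longer have enable_prefix_caching forced off; whatever the user passes through EngineArgs is kept. A hedged sketch of the intended configuration, where the import path and the exact constructor fields are assumptions; only enable_mm and enable_prefix_caching are visible in this diff:

# Hedged sketch: a VL deployment on XPU can now keep prefix caching enabled.
from fastdeploy.engine.args_utils import EngineArgs  # import path is an assumption

args = EngineArgs(
    model="path/to/your/vl-model",   # placeholder path
    enable_mm=True,                  # multimodal model
    enable_prefix_caching=True,      # previously forced off on XPU, now honored
)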
@@ -108,6 +108,10 @@ class XPUModelRunner(ModelRunnerBase):
             "matmul_v2",
             "fused_gemm_epilogue",
         ]
+        if self.cache_config.max_encoder_cache > 0:
+            self.encoder_cache: dict[str, paddle.Tensor] = {}
+        else:
+            self.encoder_cache = None
 
         self.device_id = device_id
         self.speculative_method = self.fd_config.speculative_config.method
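The encoder cache is an in-memory dict keyed by the multimodal content hash (mm_hash), holding vision-encoder outputs so a prefix-cached image does not have to be re-encoded; it is only created when cache_config.max_encoder_cache > 0. A standalone sketch of that lookup/insert/evict pattern, with fake_encode and numpy arrays standing in for the real encoder and paddle tensors:

# Standalone sketch of the hash-keyed encoder cache pattern used above.
import numpy as np

encoder_cache: dict[str, np.ndarray] = {}

def fake_encode(image: np.ndarray) -> np.ndarray:
    return image.mean(axis=0, keepdims=True)  # placeholder "features"

def get_features(mm_hash: str, image: np.ndarray) -> np.ndarray:
    if mm_hash not in encoder_cache:          # cache miss: run the encoder once
        encoder_cache[mm_hash] = fake_encode(image)
    return encoder_cache[mm_hash]             # cache hit: reuse stored features

def evict(mm_hashes: list[str]) -> None:
    for mm_hash in mm_hashes:                 # mirrors encoder_cache.pop(mm_hash, None)
        encoder_cache.pop(mm_hash, None)

img = np.ones((4, 8), dtype=np.float32)
get_features("img-abc", img)
get_features("img-abc", img)   # second call is served from the cache
evict(["img-abc"])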
@@ -161,6 +165,183 @@ class XPUModelRunner(ModelRunnerBase):
         else:
             return 0
 
+    def get_chunked_inputs(self, req: Request):
+        """
+        Get inputs in current chunk
+        """
+        prefill_start_index = req.prefill_start_index
+        prefill_end_index = req.prefill_end_index
+        inputs = req.multimodal_inputs
+        input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
+        token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
+        image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
+        images = inputs["images"][req.image_start : req.image_end]
+        grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
+        mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
+
+        return (
+            input_ids,
+            token_type_ids,
+            image_type_ids,
+            images,
+            grid_thw,
+            mm_hashes,
+        )
+
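get_chunked_inputs is plain slicing: the request carries per-chunk cursors (prefill_start_index/prefill_end_index for tokens, image_start/image_end for flattened image rows, num_image_start/num_image_end for per-image entries), and the method returns the slices that belong to the current prefill chunk. A toy illustration with plain lists, using a namespace instead of the real Request; the concrete values are made up:

# Toy illustration of the per-chunk slicing.
from types import SimpleNamespace

req = SimpleNamespace(
    prefill_start_index=2, prefill_end_index=6,   # token window of this chunk
    num_image_start=0, num_image_end=1,           # per-image entries in this chunk
    multimodal_inputs={
        "input_ids": [101, 102, 7, 7, 7, 7, 103, 104],
        "grid_thw": [[1, 4, 4], [1, 8, 8]],
        "mm_hashes": ["img-a", "img-b"],
    },
)

inputs = req.multimodal_inputs
chunk_ids = inputs["input_ids"][req.prefill_start_index : req.prefill_end_index]
chunk_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
chunk_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
print(chunk_ids, chunk_thw, chunk_hashes)  # [7, 7, 7, 7] [[1, 4, 4]] ['img-a']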
+    def batch_uncached_inputs(self, req: Request):
+        """
+        Batch uncached multimodal inputs
+        """
+        (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
+
+        image_type_ids_size = grid_thw[:, 0]
+        image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
+        image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
+
+        images_size = np.prod(grid_thw, axis=1)
+        images_split = np.cumsum(images_size)[:-1]
+        images_lst = np.array_split(images, images_split, axis=0)
+
+        assert len(image_type_ids_lst) == len(
+            mm_hashes
+        ), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
+        assert len(images_lst) == len(
+            mm_hashes
+        ), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
+
+        uncached_image_type_ids = []
+        uncached_images = []
+        uncached_grid_thw = []
+        uncached_mm_hashes = []
+        for i, mm_hash in enumerate(mm_hashes):
+            if mm_hash in self.encoder_cache:
+                continue
+            uncached_image_type_ids.append(image_type_ids_lst[i])
+            uncached_images.append(images_lst[i])
+            uncached_grid_thw.append(grid_thw[i])
+            uncached_mm_hashes.append(mm_hash)
+
+        uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
+        uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
+        if len(uncached_mm_hashes) > 0:
+            uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
+            uncached_images = paddle.to_tensor(
+                np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
+            )
+            uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
+
+        return (
+            uncached_input_ids,
+            uncached_token_type_ids,
+            uncached_image_type_ids,
+            uncached_images,
+            uncached_grid_thw,
+            uncached_mm_hashes,
+        )
+
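The grouping logic above recovers per-image slices from flattened arrays: grid_thw[:, 0] gives the frame-entry count per image and np.prod(grid_thw, axis=1) the patch-row count per image, so a cumulative sum (minus its last element) yields the split points that np.array_split needs; any image whose mm_hash is already cached is then dropped from the batch. A small numpy-only demonstration of that split, with made-up shapes and hashes:

# Numpy-only demonstration of the cumsum/array_split grouping used above.
import numpy as np

# Two images: 1x2x2 (4 patch rows) and 1x3x2 (6 patch rows).
grid_thw = np.array([[1, 2, 2], [1, 3, 2]])
images = np.arange(10 * 3).reshape(10, 3)        # 4 + 6 flattened patch rows

images_size = np.prod(grid_thw, axis=1)          # [4, 6] rows per image
images_split = np.cumsum(images_size)[:-1]       # [4] -> split after row 4
images_lst = np.array_split(images, images_split, axis=0)

cached = {"img-0"}                               # pretend the first image is cached
mm_hashes = ["img-0", "img-1"]
uncached = [chunk for h, chunk in zip(mm_hashes, images_lst) if h not in cached]
print([c.shape for c in images_lst], len(uncached))  # [(4, 3), (6, 3)] 1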
+    def scatter_and_cache_features(self, image_features, inputs):
+        """
+        Split batched image features and cache them
+        """
+        merge_size = 2
+        grid_thw = inputs["grid_thw"]
+        mm_hashes = inputs["mm_hashes"]
+        image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
+        image_features_lst = paddle.split(image_features, image_features_size, axis=0)
+
+        assert len(image_features_lst) == len(
+            mm_hashes
+        ), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
+        for i, mm_hash in enumerate(mm_hashes):
+            self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
+
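This is the inverse step: the encoder returns one concatenated feature tensor for the whole uncached batch, and each image owns prod(grid_thw[i, 1:]) // merge_size**2 rows of it (2x2 spatial merge), so splitting by those sizes and storing each piece under its mm_hash repopulates the cache. A numpy stand-in for the paddle.split call, with made-up sizes and hashes:

# Numpy stand-in for splitting batched features back into per-image cache entries.
import numpy as np

merge_size = 2
grid_thw = np.array([[1, 4, 4], [1, 8, 4]])                           # two images
sizes = (np.prod(grid_thw[:, 1:], axis=1) // merge_size**2).tolist()  # [4, 8] feature rows

hidden = 16
image_features = np.random.rand(sum(sizes), hidden)       # what the encoder would return
splits = np.cumsum(sizes)[:-1]
features_lst = np.split(image_features, splits, axis=0)   # same effect as paddle.split with sizes

encoder_cache = {}
for mm_hash, feats in zip(["img-a", "img-b"], features_lst):
    encoder_cache[mm_hash] = feats                         # .cpu() in the real code
print({k: v.shape for k, v in encoder_cache.items()})      # {'img-a': (4, 16), 'img-b': (8, 16)}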
+    def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
+        """
+        Apply multimodal inputs to share_inputs
+        - add image_features, extract and cache vision features from model
+        - add rope_emb, rotate position embeddings
+        """
+        if self.encoder_cache:
+            evict_mm_hashes = request.get("evict_mm_hashes", None)
+            if evict_mm_hashes:
+                for mm_hash in evict_mm_hashes:
+                    self.encoder_cache.pop(mm_hash, None)
+
+        inputs = request.multimodal_inputs
+        if request.with_image:
+            if envs.FD_ENABLE_MAX_PREFILL:
+                multi_vision_inputs["images_lst"].append(
+                    inputs["images"][request.image_start : request.image_end].cuda()
+                )
+                multi_vision_inputs["grid_thw_lst"].extend(
+                    inputs["grid_thw"][request.num_image_start : request.num_image_end]
+                )
+                multi_vision_inputs["cu_seqlens"].extend(
+                    inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
+                )
+                multi_vision_inputs["vit_position_ids_lst"].extend(
+                    inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
+                )
+            else:
+                vision_inputs = inputs
+                if self.encoder_cache:
+                    (
+                        vision_inputs["input_ids"],
+                        vision_inputs["token_type_ids"],
+                        vision_inputs["image_type_ids"],
+                        vision_inputs["images"],
+                        vision_inputs["grid_thw"],
+                        vision_inputs["mm_hashes"],
+                    ) = self.batch_uncached_inputs(request)
+                    if len(vision_inputs["mm_hashes"]) > 0:
+                        # uncached multimodal inputs exist
+                        image_features = self.extract_vision_features(vision_inputs)
+                        self.scatter_and_cache_features(image_features, vision_inputs)
+
+                    full_image_features_lst = []
+                    for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
+                        feature = self.encoder_cache[mm_hash].cuda()
+                        full_image_features_lst.append(feature)
+                    image_features = paddle.concat(full_image_features_lst, axis=0)
+                else:
+                    (
+                        input_ids,
+                        token_type_ids,
+                        image_type_ids,
+                        images,
+                        grid_thw,
+                        mm_hashes,
+                    ) = self.get_chunked_inputs(request)
+                    vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
+                    vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
+                    vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
+                    vision_inputs["images"] = paddle.to_tensor(
+                        images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
+                    )
+                    vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
+                    vision_inputs["mm_hashes"] = mm_hashes
+
+                    image_features = self.extract_vision_features(vision_inputs)
+
+                # part of the first image may be already cached
+                if "ernie" in self.model_config.model_type:
+                    actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
+                elif "qwen" in self.model_config.model_type:
+                    actual_image_token_num = paddle.sum(
+                        vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
+                    ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
+                else:
+                    raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
+                self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
+
+        position_ids = request.multimodal_inputs["position_ids"]
+        rope_3d_position_ids["position_ids_idx"].append(request.idx)
+        rope_3d_position_ids["position_ids_lst"].append(position_ids)
+        rope_3d_position_ids["position_ids_offset"].append(
+            position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
+        )
+        rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
+
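Putting it together, the per-request flow when the encoder cache is enabled is: honor any evict_mm_hashes from the scheduler, encode only the images whose hashes are missing, cache their features, then assemble the full feature tensor for this chunk from the cache and keep only the trailing actual_image_token_num rows (the leading rows of the first image may already be covered by the KV prefix cache). A compact, framework-free sketch of that control flow; every helper below is an illustrative stand-in, not FastDeploy API:

# Framework-free sketch of the cache-aware flow in _apply_mm_inputs.
import numpy as np

encoder_cache: dict[str, np.ndarray] = {}

def encode(images: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Stand-in for extract_vision_features + scatter_and_cache_features.
    return {h: img * 2.0 for h, img in images.items()}

def apply_mm_inputs(chunk_hashes, all_images, evict_mm_hashes=(), cached_prefix_rows=0):
    for h in evict_mm_hashes:                     # scheduler told us these entries are gone
        encoder_cache.pop(h, None)
    uncached = {h: all_images[h] for h in chunk_hashes if h not in encoder_cache}
    if uncached:                                  # encode only what is missing
        encoder_cache.update(encode(uncached))
    full = np.concatenate([encoder_cache[h] for h in chunk_hashes], axis=0)
    return full[cached_prefix_rows:]              # drop rows already covered by the KV prefix cache

imgs = {"img-a": np.ones((4, 8)), "img-b": np.ones((6, 8))}
out = apply_mm_inputs(["img-a", "img-b"], imgs, cached_prefix_rows=2)
print(out.shape)  # (8, 8): 10 feature rows minus the 2 already-cached ones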
     def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -189,51 +370,7 @@ class XPUModelRunner(ModelRunnerBase):
             prefill_end_index = request.prefill_end_index
             length = prefill_end_index - prefill_start_index
             if self.enable_mm:
-                inputs = request.multimodal_inputs
-                if request.with_image:
-                    if envs.FD_ENABLE_MAX_PREFILL:
-                        multi_vision_inputs["images_lst"].append(
-                            paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
-                        )
-                        multi_vision_inputs["grid_thw_lst"].extend(
-                            inputs["grid_thw"][request.num_image_start : request.num_image_end]
-                        )
-                        multi_vision_inputs["cu_seqlens"].extend(
-                            inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
-                        )
-                        multi_vision_inputs["vit_position_ids_lst"].extend(
-                            inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
-                        )
-                    else:
-                        vision_inputs = {}
-                        vision_inputs["input_ids"] = paddle.to_tensor(
-                            inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
-                        )
-                        vision_inputs["token_type_ids"] = paddle.to_tensor(
-                            inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
-                        )
-                        vision_inputs["image_type_ids"] = paddle.to_tensor(
-                            inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
-                            dtype=paddle.int64,
-                        )
-                        vision_inputs["images"] = paddle.to_tensor(
-                            inputs["images"][request.image_start : request.image_end],
-                            dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
-                        )
-                        vision_inputs["grid_thw"] = paddle.to_tensor(
-                            inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
-                        )
-                        self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
-                else:
-                    self.share_inputs["image_features"] = None
-
-                position_ids = request.multimodal_inputs["position_ids"]
-                rope_3d_position_ids["position_ids_idx"].append(idx)
-                rope_3d_position_ids["position_ids_lst"].append(position_ids)
-                rope_3d_position_ids["position_ids_offset"].append(
-                    position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
-                )
-                rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
+                self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
 
             if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
                 # Enable thinking