[XPU] xpu support mm prefix cache (#5356)

Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
ddchenhao66
2025-12-03 19:07:34 +08:00
committed by GitHub
parent a4bb3e9960
commit 4e8096bd0d
3 changed files with 184 additions and 51 deletions

View File

@@ -1694,9 +1694,9 @@ class FDConfig:
logger.info( logger.info(
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!" "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
) )
if self.device_config is not None and self.device_config.device_type != "cuda": if not current_platform.is_cuda():
self.graph_opt_config.use_cudagraph = False self.graph_opt_config.use_cudagraph = False
logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!") logger.info("CUDAGraph currently only support on GPU!")
if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size: if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
self.parallel_config.use_sequence_parallel_moe = False self.parallel_config.use_sequence_parallel_moe = False

View File

@@ -1233,10 +1233,6 @@ class EngineArgs:
all_dict = asdict(self) all_dict = asdict(self)
model_cfg = ModelConfig(all_dict) model_cfg = ModelConfig(all_dict)
# XPU currently disable prefix cache for VL model
if current_platform.is_xpu() and (self.enable_mm or model_cfg.enable_mm):
self.enable_prefix_caching = False
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"): if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
self.tensor_parallel_size = model_cfg.tensor_parallel_size self.tensor_parallel_size = model_cfg.tensor_parallel_size

View File

@@ -108,6 +108,10 @@ class XPUModelRunner(ModelRunnerBase):
"matmul_v2", "matmul_v2",
"fused_gemm_epilogue", "fused_gemm_epilogue",
] ]
if self.cache_config.max_encoder_cache > 0:
self.encoder_cache: dict[str, paddle.Tensor] = {}
else:
self.encoder_cache = None
self.device_id = device_id self.device_id = device_id
self.speculative_method = self.fd_config.speculative_config.method self.speculative_method = self.fd_config.speculative_config.method
@@ -161,6 +165,183 @@ class XPUModelRunner(ModelRunnerBase):
else: else:
return 0 return 0
def get_chunked_inputs(self, req: Request):
    """
    Get inputs in current chunk

    Token-level fields (``input_ids``, ``token_type_ids``) are sliced with the
    request's prefill window; the image-level fields are sliced with the
    image offsets stored on the request.

    Returns:
        Tuple ``(input_ids, token_type_ids, image_type_ids, images, grid_thw,
        mm_hashes)`` containing the slices that belong to this chunk.
    """
    token_window = slice(req.prefill_start_index, req.prefill_end_index)
    mm_inputs = req.multimodal_inputs
    return (
        mm_inputs["input_ids"][token_window],
        mm_inputs["token_type_ids"][token_window],
        mm_inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end],
        mm_inputs["images"][req.image_start : req.image_end],
        # grid_thw and mm_hashes share the same per-item window.
        mm_inputs["grid_thw"][req.num_image_start : req.num_image_end],
        mm_inputs["mm_hashes"][req.num_image_start : req.num_image_end],
    )
def batch_uncached_inputs(self, req: Request):
    """
    Batch uncached multimodal inputs

    Takes the current chunk from ``get_chunked_inputs``, splits the image-level
    arrays into one piece per entry of ``mm_hashes`` and keeps only the pieces
    whose hash is not a key of ``self.encoder_cache``.

    Returns:
        Tuple ``(input_ids, token_type_ids, image_type_ids, images, grid_thw,
        mm_hashes)``. The two token-level members are always int64 tensors for
        the whole chunk. The image-level members cover the uncached items only:
        they are tensors when at least one item is uncached and empty lists
        otherwise. ``mm_hashes`` is always a plain list.
    """
    (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)

    # Column 0 of grid_thw is used as the number of image_type_ids entries per item.
    type_id_cuts = np.cumsum(grid_thw[:, 0])[:-1]
    type_ids_per_item = np.array_split(image_type_ids, type_id_cuts, axis=0)
    # The product of each grid_thw row is used as the number of rows of `images` per item.
    image_cuts = np.cumsum(np.prod(grid_thw, axis=1))[:-1]
    images_per_item = np.array_split(images, image_cuts, axis=0)

    assert len(type_ids_per_item) == len(
        mm_hashes
    ), f"image_type_ids_lst length {len(type_ids_per_item)} != mm_hashes length {len(mm_hashes)}"
    assert len(images_per_item) == len(
        mm_hashes
    ), f"images_lst length {len(images_per_item)} != mm_hashes length {len(mm_hashes)}"

    # Positions of the items that still have to go through the vision encoder.
    uncached_idx = [i for i, mm_hash in enumerate(mm_hashes) if mm_hash not in self.encoder_cache]
    uncached_image_type_ids = [type_ids_per_item[i] for i in uncached_idx]
    uncached_images = [images_per_item[i] for i in uncached_idx]
    uncached_grid_thw = [grid_thw[i] for i in uncached_idx]
    uncached_mm_hashes = [mm_hashes[i] for i in uncached_idx]

    # Token-level inputs are passed through whole, whatever is cached.
    uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
    uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
    if uncached_mm_hashes:
        images_dtype = "uint8" if "ernie" in self.model_config.model_type else "bfloat16"
        uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
        uncached_images = paddle.to_tensor(np.vstack(uncached_images), dtype=images_dtype)
        uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)

    return (
        uncached_input_ids,
        uncached_token_type_ids,
        uncached_image_type_ids,
        uncached_images,
        uncached_grid_thw,
        uncached_mm_hashes,
    )
def scatter_and_cache_features(self, image_features, inputs):
    """
    Split batched image features and cache them

    ``image_features`` is split along axis 0 into one piece per entry of
    ``inputs["mm_hashes"]``; each piece is stored in ``self.encoder_cache``
    under its hash after a ``.cpu()`` call.

    Args:
        image_features: Batched features; axis 0 is the axis that gets split.
        inputs: Mapping that provides ``"grid_thw"`` and ``"mm_hashes"``.
    """
    merge_size = 2
    item_grid = inputs["grid_thw"]
    item_hashes = inputs["mm_hashes"]

    # Rows per item: product of the grid_thw columns after the first, divided by merge_size squared.
    rows_per_item = (paddle.prod(item_grid[:, 1:], axis=1) // (merge_size**2)).tolist()
    feature_chunks = paddle.split(image_features, rows_per_item, axis=0)
    assert len(feature_chunks) == len(
        item_hashes
    ), f"image_features_lst length {len(feature_chunks)} != mm_hashes length {len(item_hashes)}"

    for mm_hash, chunk in zip(item_hashes, feature_chunks):
        self.encoder_cache[mm_hash] = chunk.cpu()
def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
    """
    Apply multimodal inputs to share_inputs
    - add image_features, extract and cache vision features from model
    - add rope_emb, rotate position embeddings

    Args:
        request: Request being inserted; read for its multimodal inputs, chunk
            offsets, ``idx`` and the ``evict_mm_hashes`` / ``max_tokens`` options.
        multi_vision_inputs: Accumulator whose ``images_lst``, ``grid_thw_lst``,
            ``cu_seqlens`` and ``vit_position_ids_lst`` lists are extended when
            ``envs.FD_ENABLE_MAX_PREFILL`` is set.
        rope_3d_position_ids: Accumulator whose ``position_ids_idx``,
            ``position_ids_lst``, ``position_ids_offset`` and ``max_tokens_lst``
            lists get one new entry per call.

    Raises:
        ValueError: On the non-max-prefill image path, when the model type
            contains neither "ernie" nor "qwen".
    """
    # Truthiness is enough here: a disabled (None) or empty cache has nothing to evict.
    if self.encoder_cache:
        evict_mm_hashes = request.get("evict_mm_hashes", None)
        if evict_mm_hashes:
            for mm_hash in evict_mm_hashes:
                self.encoder_cache.pop(mm_hash, None)

    inputs = request.multimodal_inputs
    if request.with_image:
        if envs.FD_ENABLE_MAX_PREFILL:
            # Only collect the raw vision inputs here; no image_features are produced in this branch.
            multi_vision_inputs["images_lst"].append(
                inputs["images"][request.image_start : request.image_end].cuda()
            )
            multi_vision_inputs["grid_thw_lst"].extend(
                inputs["grid_thw"][request.num_image_start : request.num_image_end]
            )
            multi_vision_inputs["cu_seqlens"].extend(
                inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
            )
            multi_vision_inputs["vit_position_ids_lst"].extend(
                inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
            )
        else:
            # NOTE(review): vision_inputs is the same object as request.multimodal_inputs, so the
            # assignments below overwrite the request's own fields - confirm nothing re-reads them.
            vision_inputs = inputs
            # Compare with None instead of using truthiness: an enabled cache starts out as an empty
            # dict, which is falsy, so a truthiness test would never let the first entry be stored.
            if self.encoder_cache is not None:
                # Capture this chunk's hashes before the unpacking below rebinds inputs["mm_hashes"]
                # (through the alias) to the uncached subset only.
                chunk_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
                (
                    vision_inputs["input_ids"],
                    vision_inputs["token_type_ids"],
                    vision_inputs["image_type_ids"],
                    vision_inputs["images"],
                    vision_inputs["grid_thw"],
                    vision_inputs["mm_hashes"],
                ) = self.batch_uncached_inputs(request)
                if len(vision_inputs["mm_hashes"]) > 0:
                    # uncached multimodal inputs exist
                    image_features = self.extract_vision_features(vision_inputs)
                    self.scatter_and_cache_features(image_features, vision_inputs)

                # Every item of the chunk is in the cache now; rebuild the features in chunk order.
                full_image_features_lst = [self.encoder_cache[mm_hash].cuda() for mm_hash in chunk_mm_hashes]
                image_features = paddle.concat(full_image_features_lst, axis=0)
            else:
                (
                    input_ids,
                    token_type_ids,
                    image_type_ids,
                    images,
                    grid_thw,
                    mm_hashes,
                ) = self.get_chunked_inputs(request)
                vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
                vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
                vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
                vision_inputs["images"] = paddle.to_tensor(
                    images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
                )
                vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
                vision_inputs["mm_hashes"] = mm_hashes
                image_features = self.extract_vision_features(vision_inputs)

            # part of the first image may be already cached
            if "ernie" in self.model_config.model_type:
                actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
            elif "qwen" in self.model_config.model_type:
                actual_image_token_num = paddle.sum(
                    vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
                ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
            else:
                raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
            # Keep only the trailing features that match image tokens present in this chunk.
            self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
    # NOTE(review): share_inputs["image_features"] is left untouched when the request has no image;
    # confirm it is reset elsewhere.

    position_ids = request.multimodal_inputs["position_ids"]
    rope_3d_position_ids["position_ids_idx"].append(request.idx)
    rope_3d_position_ids["position_ids_lst"].append(position_ids)
    # Offsets are cumulative: previous offset plus this request's position count.
    rope_3d_position_ids["position_ids_offset"].append(
        position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
    )
    rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
""" """
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -189,51 +370,7 @@ class XPUModelRunner(ModelRunnerBase):
prefill_end_index = request.prefill_end_index prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index length = prefill_end_index - prefill_start_index
if self.enable_mm: if self.enable_mm:
inputs = request.multimodal_inputs self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking # Enable thinking