[XPU] xpu support mm prefix cache (#5356)

Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
ddchenhao66
2025-12-03 19:07:34 +08:00
committed by GitHub
parent a4bb3e9960
commit 4e8096bd0d
3 changed files with 184 additions and 51 deletions

View File

@@ -1694,9 +1694,9 @@ class FDConfig:
logger.info(
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
)
if self.device_config is not None and self.device_config.device_type != "cuda":
if not current_platform.is_cuda():
self.graph_opt_config.use_cudagraph = False
logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
logger.info("CUDAGraph currently only support on GPU!")
if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
self.parallel_config.use_sequence_parallel_moe = False

View File

@@ -1233,10 +1233,6 @@ class EngineArgs:
all_dict = asdict(self)
model_cfg = ModelConfig(all_dict)
# XPU currently disable prefix cache for VL model
if current_platform.is_xpu() and (self.enable_mm or model_cfg.enable_mm):
self.enable_prefix_caching = False
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
self.tensor_parallel_size = model_cfg.tensor_parallel_size

View File

@@ -108,6 +108,10 @@ class XPUModelRunner(ModelRunnerBase):
"matmul_v2",
"fused_gemm_epilogue",
]
if self.cache_config.max_encoder_cache > 0:
self.encoder_cache: dict[str, paddle.Tensor] = {}
else:
self.encoder_cache = None
self.device_id = device_id
self.speculative_method = self.fd_config.speculative_config.method
@@ -161,6 +165,183 @@ class XPUModelRunner(ModelRunnerBase):
else:
return 0
def get_chunked_inputs(self, req: Request):
"""
Get inputs in current chunk
"""
prefill_start_index = req.prefill_start_index
prefill_end_index = req.prefill_end_index
inputs = req.multimodal_inputs
input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
images = inputs["images"][req.image_start : req.image_end]
grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
return (
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
)
def batch_uncached_inputs(self, req: Request):
"""
Batch uncached multimodal inputs
"""
(input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
image_type_ids_size = grid_thw[:, 0]
image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
images_size = np.prod(grid_thw, axis=1)
images_split = np.cumsum(images_size)[:-1]
images_lst = np.array_split(images, images_split, axis=0)
assert len(image_type_ids_lst) == len(
mm_hashes
), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
assert len(images_lst) == len(
mm_hashes
), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
uncached_image_type_ids = []
uncached_images = []
uncached_grid_thw = []
uncached_mm_hashes = []
for i, mm_hash in enumerate(mm_hashes):
if mm_hash in self.encoder_cache:
continue
uncached_image_type_ids.append(image_type_ids_lst[i])
uncached_images.append(images_lst[i])
uncached_grid_thw.append(grid_thw[i])
uncached_mm_hashes.append(mm_hash)
uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if len(uncached_mm_hashes) > 0:
uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
uncached_images = paddle.to_tensor(
np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
)
uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
return (
uncached_input_ids,
uncached_token_type_ids,
uncached_image_type_ids,
uncached_images,
uncached_grid_thw,
uncached_mm_hashes,
)
def scatter_and_cache_features(self, image_features, inputs):
"""
Split batched image features and cache them
"""
merge_size = 2
grid_thw = inputs["grid_thw"]
mm_hashes = inputs["mm_hashes"]
image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
image_features_lst = paddle.split(image_features, image_features_size, axis=0)
assert len(image_features_lst) == len(
mm_hashes
), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
for i, mm_hash in enumerate(mm_hashes):
self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
"""
Apply multimodal inputs to share_inputs
- add image_features, extract and cache vision features from model
- add rope_emb, rotate position embeddings
"""
if self.encoder_cache:
evict_mm_hashes = request.get("evict_mm_hashes", None)
if evict_mm_hashes:
for mm_hash in evict_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
inputs = request.multimodal_inputs
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = inputs
if self.encoder_cache:
(
vision_inputs["input_ids"],
vision_inputs["token_type_ids"],
vision_inputs["image_type_ids"],
vision_inputs["images"],
vision_inputs["grid_thw"],
vision_inputs["mm_hashes"],
) = self.batch_uncached_inputs(request)
if len(vision_inputs["mm_hashes"]) > 0:
# uncached multimodal inputs exist
image_features = self.extract_vision_features(vision_inputs)
self.scatter_and_cache_features(image_features, vision_inputs)
full_image_features_lst = []
for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
feature = self.encoder_cache[mm_hash].cuda()
full_image_features_lst.append(feature)
image_features = paddle.concat(full_image_features_lst, axis=0)
else:
(
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
) = self.get_chunked_inputs(request)
vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
vision_inputs["images"] = paddle.to_tensor(
images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
)
vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
vision_inputs["mm_hashes"] = mm_hashes
image_features = self.extract_vision_features(vision_inputs)
# part of the first image may be already cached
if "ernie" in self.model_config.model_type:
actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
elif "qwen" in self.model_config.model_type:
actual_image_token_num = paddle.sum(
vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(request.idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -189,51 +370,7 @@ class XPUModelRunner(ModelRunnerBase):
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking