[XPU] Support PaddleOCR-VL model for XPU (#4529)

* [XPU] support PaddleOCR-VL in XPU

* [XPU] fix PaddleOCR-VL pos_emb_type
This commit is contained in:
Lucas
2025-10-28 20:35:04 +08:00
committed by GitHub
parent 2a9ed72533
commit 0a0c74e717
2 changed files with 109 additions and 27 deletions

View File

@@ -83,7 +83,9 @@ class XPUAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank

View File

@@ -420,6 +420,7 @@ class XPUModelRunner(ModelRunnerBase):
req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
rope_3d_position_ids = {
"position_ids_idx": [],
"position_ids_lst": [],
@@ -436,24 +437,39 @@ class XPUModelRunner(ModelRunnerBase):
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end], dtype="uint8"
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
@@ -570,6 +586,9 @@ class XPUModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
packed_position_ids = paddle.to_tensor(
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
@@ -826,6 +845,16 @@ class XPUModelRunner(ModelRunnerBase):
if self.enable_mm:
head_dim = self.model_config.head_dim
if "paddleocr" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim
else: # neox style = False
rope_head_dim = head_dim // 2
if head_dim == self.model_config.head_dim:
self.share_inputs["pos_emb_type"] = "NORMAL"
else:
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
self.share_inputs["rope_emb"] = paddle.full(
shape=[
max_num_seqs,
@@ -833,7 +862,7 @@ class XPUModelRunner(ModelRunnerBase):
1,
self.model_config.max_model_len,
1,
head_dim // 2,
rope_head_dim,
],
fill_value=0,
dtype="float32",
@@ -866,8 +895,8 @@ class XPUModelRunner(ModelRunnerBase):
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
if self.enable_mm: # pos_emb_type is different in EB and VL
self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
if self.enable_mm:
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend()
@@ -1338,12 +1367,10 @@ class XPUModelRunner(ModelRunnerBase):
)
return result
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
# ernie-vl has images norm
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
@@ -1353,7 +1380,6 @@ class XPUModelRunner(ModelRunnerBase):
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
# TODO(lulinjun): may need to check model_config and model_cfg
image_mask = input_ids == self.model_config.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
@@ -1369,6 +1395,7 @@ class XPUModelRunner(ModelRunnerBase):
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
image_features = image_features.reshape([S, -1])
# ernie-vl has resampler_model
image_features = self.model.resampler_model(
image_features,
image_mask,
@@ -1378,6 +1405,59 @@ class XPUModelRunner(ModelRunnerBase):
)
return image_features
def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
if envs.FD_ENABLE_MAX_PREFILL:
inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64")
cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32")
else:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"]
position_ids = []
cu_seqlens = [0]
for idx, thw in enumerate(grid_thw):
numel = np.prod(np.array(thw))
position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
cu_seqlens.append(cu_seqlens[-1] + numel)
position_ids = paddle.concat(position_ids, axis=0).to(images.place)
cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.model_config.dtype,
):
image_features = self.model.visual(
pixel_values=images,
image_grid_thw=grid_thw,
position_ids=position_ids,
interpolate_pos_encoding=True,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
image_features = self.model.projector(image_features, grid_thw)
image_features = paddle.concat(image_features, axis=0)
return image_features
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
if "ernie" in self.model_config.model_type:
return self.extract_vision_features_ernie(inputs)
elif "paddleocr" in self.model_config.model_type:
return self.extract_vision_features_paddleocr(inputs)
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
@paddle.no_grad()
def prepare_rope3d(
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]