mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] Support PaddleOCR-VL model for XPU (#4529)
* [XPU] support PaddleOCR-VL in XPU * [XPU] fix PaddleOCR-VL pos_emb_type
This commit is contained in:
@@ -83,7 +83,9 @@ class XPUAttentionBackend(AttentionBackend):
|
||||
self.rope_theta: float = (
|
||||
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
|
||||
)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
|
||||
fd_config.model_config, "use_3d_rope", False
|
||||
)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
@@ -420,6 +420,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
req_len = len(req_dicts)
|
||||
has_prefill_task = False
|
||||
has_decode_task = False
|
||||
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
|
||||
rope_3d_position_ids = {
|
||||
"position_ids_idx": [],
|
||||
"position_ids_lst": [],
|
||||
@@ -436,24 +437,39 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
if self.enable_mm:
|
||||
inputs = request.multimodal_inputs
|
||||
if request.with_image:
|
||||
vision_inputs = {}
|
||||
vision_inputs["input_ids"] = paddle.to_tensor(
|
||||
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["token_type_ids"] = paddle.to_tensor(
|
||||
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["image_type_ids"] = paddle.to_tensor(
|
||||
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
vision_inputs["images"] = paddle.to_tensor(
|
||||
inputs["images"][request.image_start : request.image_end], dtype="uint8"
|
||||
)
|
||||
vision_inputs["grid_thw"] = paddle.to_tensor(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
||||
)
|
||||
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
multi_vision_inputs["images_lst"].append(
|
||||
paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
|
||||
)
|
||||
multi_vision_inputs["grid_thw_lst"].extend(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["cu_seqlens"].extend(
|
||||
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["vit_position_ids_lst"].extend(
|
||||
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
else:
|
||||
vision_inputs = {}
|
||||
vision_inputs["input_ids"] = paddle.to_tensor(
|
||||
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["token_type_ids"] = paddle.to_tensor(
|
||||
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["image_type_ids"] = paddle.to_tensor(
|
||||
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
vision_inputs["images"] = paddle.to_tensor(
|
||||
inputs["images"][request.image_start : request.image_end],
|
||||
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
|
||||
)
|
||||
vision_inputs["grid_thw"] = paddle.to_tensor(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
||||
)
|
||||
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
|
||||
else:
|
||||
self.share_inputs["image_features"] = None
|
||||
|
||||
@@ -570,6 +586,9 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
|
||||
|
||||
if len(multi_vision_inputs["images_lst"]) > 0:
|
||||
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
|
||||
|
||||
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
|
||||
packed_position_ids = paddle.to_tensor(
|
||||
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
|
||||
@@ -826,6 +845,16 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
if "paddleocr" in self.model_config.model_type: # neox style = True
|
||||
rope_head_dim = head_dim
|
||||
else: # neox style = False
|
||||
rope_head_dim = head_dim // 2
|
||||
|
||||
if head_dim == self.model_config.head_dim:
|
||||
self.share_inputs["pos_emb_type"] = "NORMAL"
|
||||
else:
|
||||
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
|
||||
|
||||
self.share_inputs["rope_emb"] = paddle.full(
|
||||
shape=[
|
||||
max_num_seqs,
|
||||
@@ -833,7 +862,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
1,
|
||||
self.model_config.max_model_len,
|
||||
1,
|
||||
head_dim // 2,
|
||||
rope_head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype="float32",
|
||||
@@ -866,8 +895,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
if self.enable_mm: # pos_emb_type is different in EB and VL
|
||||
self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
|
||||
if self.enable_mm:
|
||||
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
|
||||
self.forward_meta.attn_backend = self.attn_backends[0]
|
||||
self.initialize_attention_backend()
|
||||
|
||||
@@ -1338,12 +1367,10 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
return result
|
||||
|
||||
@paddle.no_grad()
|
||||
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
"""extract_vision_features"""
|
||||
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
assert inputs["images"] is not None
|
||||
grid_thw = inputs["grid_thw"]
|
||||
|
||||
# ernie-vl has images norm
|
||||
images = inputs["images"].cast("float32")
|
||||
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
|
||||
images = images / self.image_preprocess.image_std_tensor
|
||||
@@ -1353,7 +1380,6 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
token_type_ids_w_video = token_type_ids
|
||||
input_ids = inputs["input_ids"]
|
||||
# convert to img patch id
|
||||
# TODO(lulinjun): may need to check model_config and model_cfg
|
||||
image_mask = input_ids == self.model_config.im_patch_id
|
||||
image_type_ids = inputs["image_type_ids"]
|
||||
with paddle.amp.auto_cast(
|
||||
@@ -1369,6 +1395,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
|
||||
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
|
||||
image_features = image_features.reshape([S, -1])
|
||||
# ernie-vl has resampler_model
|
||||
image_features = self.model.resampler_model(
|
||||
image_features,
|
||||
image_mask,
|
||||
@@ -1378,6 +1405,59 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
return image_features
|
||||
|
||||
def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
|
||||
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
|
||||
grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
|
||||
position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64")
|
||||
cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32")
|
||||
else:
|
||||
assert inputs["images"] is not None
|
||||
grid_thw = inputs["grid_thw"]
|
||||
images = inputs["images"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlens = [0]
|
||||
for idx, thw in enumerate(grid_thw):
|
||||
numel = np.prod(np.array(thw))
|
||||
position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
|
||||
cu_seqlens.append(cu_seqlens[-1] + numel)
|
||||
|
||||
position_ids = paddle.concat(position_ids, axis=0).to(images.place)
|
||||
cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
|
||||
|
||||
with paddle.amp.auto_cast(
|
||||
True,
|
||||
custom_black_list=self.amp_black,
|
||||
custom_white_list=self.amp_white,
|
||||
level="O2",
|
||||
dtype=self.model_config.dtype,
|
||||
):
|
||||
image_features = self.model.visual(
|
||||
pixel_values=images,
|
||||
image_grid_thw=grid_thw,
|
||||
position_ids=position_ids,
|
||||
interpolate_pos_encoding=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_rope=True,
|
||||
window_size=-1,
|
||||
)
|
||||
image_features = self.model.projector(image_features, grid_thw)
|
||||
image_features = paddle.concat(image_features, axis=0)
|
||||
|
||||
return image_features
|
||||
|
||||
@paddle.no_grad()
|
||||
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
"""extract_vision_features"""
|
||||
if "ernie" in self.model_config.model_type:
|
||||
return self.extract_vision_features_ernie(inputs)
|
||||
elif "paddleocr" in self.model_config.model_type:
|
||||
return self.extract_vision_features_paddleocr(inputs)
|
||||
else:
|
||||
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
|
||||
|
||||
@paddle.no_grad()
|
||||
def prepare_rope3d(
|
||||
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
|
||||
|
||||
Reference in New Issue
Block a user