[Feature] Support Paddle-OCR (#4396)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* init

* update code

* fix code style & disable thinking

* adapt for common_engine.update_mm_requests_chunk_size

* use 3d rope

* use flash_attn_unpadded

* opt siglip

* update to be compatible with the latest codebase

* fix typo

* optim OCR performance

* fix bug

* fix bug

* fix bug

* fix bug

* normlize name

* modify xpu rope

* revert logger

* fix bug

* fix bug

* fix bug

* support default_v1

* optim performance

* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
This commit is contained in:
ming1753
2025-10-24 23:34:30 +08:00
committed by GitHub
parent 822dea8d5f
commit e4e3cede7f
21 changed files with 2869 additions and 175 deletions

View File

@@ -310,6 +310,14 @@ class GPUModelRunner(ModelRunnerBase):
req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
rope_3d_position_ids = {
"position_ids_idx": [],
"position_ids_lst": [],
"position_ids_offset": [0],
"max_tokens_lst": [],
}
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
@@ -320,39 +328,49 @@ class GPUModelRunner(ModelRunnerBase):
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
if inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
).unsqueeze([0])
else:
position_ids = None
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
@@ -466,6 +484,21 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
packed_position_ids = paddle.to_tensor(
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
)
rope_3d_lst = self.prepare_rope3d(
packed_position_ids,
rope_3d_position_ids["max_tokens_lst"],
rope_3d_position_ids["position_ids_offset"],
)
for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
@@ -545,7 +578,7 @@ class GPUModelRunner(ModelRunnerBase):
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
).unsqueeze([0])
)
else:
position_ids = None
token_chunk_size = inputs["input_ids"].shape[1]
@@ -582,8 +615,8 @@ class GPUModelRunner(ModelRunnerBase):
if self.enable_mm:
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
@@ -994,7 +1027,9 @@ class GPUModelRunner(ModelRunnerBase):
if self.enable_mm:
head_dim = self.model_config.head_dim
if "qwen" in self.model_config.model_type: # neox style = True
if (
"qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type
): # neox style = True
rope_head_dim = head_dim
else: # neox style = False
rope_head_dim = head_dim // 2
@@ -2221,7 +2256,7 @@ class GPUModelRunner(ModelRunnerBase):
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64")
else:
position_ids = None
@@ -2288,6 +2323,49 @@ class GPUModelRunner(ModelRunnerBase):
return image_features
def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
if envs.FD_ENABLE_MAX_PREFILL:
inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64")
cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32")
else:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"]
position_ids = []
cu_seqlens = [0]
for idx, thw in enumerate(grid_thw):
numel = np.prod(np.array(thw))
position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
cu_seqlens.append(cu_seqlens[-1] + numel)
position_ids = paddle.concat(position_ids, axis=0).to(images.place)
cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.model_config.dtype,
):
image_features = self.model.visual(
pixel_values=images,
image_grid_thw=grid_thw,
position_ids=position_ids,
interpolate_pos_encoding=True,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
image_features = self.model.projector(image_features, grid_thw)
image_features = paddle.concat(image_features, axis=0)
return image_features
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
@@ -2295,28 +2373,26 @@ class GPUModelRunner(ModelRunnerBase):
return self.extract_vision_features_ernie(inputs)
elif "qwen" in self.model_config.model_type:
return self.extract_vision_features_qwen(inputs)
elif "paddleocr" in self.model_config.model_type:
return self.extract_vision_features_paddleocr(inputs)
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
def prepare_rope3d(
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
) -> list[paddle.Tensor]:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rope_emb_lst = get_rope_3d(
position_ids=position_ids,
rotary_dim=self.model_config.head_dim,
partial_rotary_factor=1.0,
base=self.model_config.rope_theta,
max_position=self.model_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
max_len_lst=max_len_lst,
cumsum_seqlens=cumsum_seqlens,
)
return rope_emb
return rope_emb_lst