mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support Paddle-OCR (#4396)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* init * update code * fix code style & disable thinking * adapt for common_engine.update_mm_requests_chunk_size * use 3d rope * use flash_attn_unpadded * opt siglip * update to be compatible with the latest codebase * fix typo * optim OCR performance * fix bug * fix bug * fix bug * fix bug * normlize name * modify xpu rope * revert logger * fix bug * fix bug * fix bug * support default_v1 * optim performance * fix bug --------- Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com> Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
This commit is contained in:
@@ -419,33 +419,42 @@ class ErnieVlRotaryEmbedding3D:
|
||||
self.max_position = max_position
|
||||
self.freq_allocation = freq_allocation
|
||||
|
||||
def __call__(self, position_ids):
|
||||
def __call__(self, position_ids, max_len_lst, cumsum_seqlens):
|
||||
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
|
||||
|
||||
bsz = len(cumsum_seqlens) - 1
|
||||
# position_ids_3d: [bsz, seq_len, 3]
|
||||
position_ids_3d = paddle.tile(
|
||||
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
|
||||
[1, 1, 3],
|
||||
[bsz, 1, 3],
|
||||
)
|
||||
for i in range(bsz):
|
||||
position_ids_cur = position_ids[cumsum_seqlens[i] : cumsum_seqlens[i + 1]]
|
||||
prefix_max_position_ids = paddle.max(position_ids_cur) + 1
|
||||
dec_pos_ids = paddle.tile(
|
||||
paddle.arange(max_len_lst[i], dtype="int64").unsqueeze(-1),
|
||||
[1, 3],
|
||||
)
|
||||
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
|
||||
position_ids_3d_real = paddle.concat([position_ids_cur, dec_pos_ids], axis=0)
|
||||
position_ids_3d[i, : position_ids_3d_real.shape[0], :] = position_ids_3d_real
|
||||
|
||||
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
||||
|
||||
# position_ids: [bsz, seq_len]
|
||||
# position_ids: [bsz(1), seq_len]
|
||||
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
||||
|
||||
position_ids = position_ids / self.paritial_rotary_factor
|
||||
|
||||
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
|
||||
indices = 1 / self.base ** (indices / self.rotary_dim)
|
||||
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
|
||||
# sinusoid_inp: [bsz(1), seq_len, 1, head_dim // 2]
|
||||
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
# pos_emb: [bsz(1), seq_len, 1, head_dim]
|
||||
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
|
||||
# pos_emb: [bsz, 1, seq_len, head_dim]
|
||||
# pos_emb: [bsz(1), 1, seq_len, head_dim]
|
||||
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
# pos_emb: [bsz(1), seq_len, 1, head_dim]
|
||||
pos_emb = pos_emb.transpose([0, 2, 1, 3])
|
||||
# sin: [bsz, seq_len, 1, head_dim // 2]
|
||||
# sin: [bsz(1), seq_len, 1, head_dim // 2]
|
||||
sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
|
||||
batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
|
||||
# batch_indices: [[0]]
|
||||
@@ -454,39 +463,46 @@ class ErnieVlRotaryEmbedding3D:
|
||||
sin = sin.tile([position_ids.shape[0], 1, 1, 1])
|
||||
cos = cos.tile([position_ids.shape[0], 1, 1, 1])
|
||||
|
||||
tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
|
||||
tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
|
||||
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
|
||||
tmp_pos_id_0 = position_ids_3d[..., 0].astype("int64")
|
||||
tmp_pos_id_1 = position_ids_3d[..., 1].astype("int64")
|
||||
tmp_pos_id_2 = position_ids_3d[..., 2].astype("int64")
|
||||
|
||||
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
|
||||
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
|
||||
|
||||
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2])
|
||||
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1)
|
||||
rot_emb_list = []
|
||||
for i in range(bsz):
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0[i], axis=1)[:, :, :, -self.freq_allocation :]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1[i], axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2[i], axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
|
||||
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
|
||||
|
||||
rot_emb[0] = cos_thw
|
||||
rot_emb[1] = sin_thw
|
||||
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0[i], axis=1)[:, :, :, -self.freq_allocation :]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1[i], axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2[i], axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2])
|
||||
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1)
|
||||
|
||||
if current_platform.is_iluvatar():
|
||||
rot_emb = paddle.stack([rot_emb, rot_emb], axis=-1).reshape([2, 1, self.max_position, 1, self.rotary_dim])
|
||||
rot_emb[0] = cos_thw
|
||||
rot_emb[1] = sin_thw
|
||||
|
||||
return rot_emb
|
||||
if current_platform.is_iluvatar():
|
||||
rot_emb = paddle.stack([rot_emb, rot_emb], axis=-1).reshape(
|
||||
[2, 1, self.max_position, 1, self.rotary_dim]
|
||||
)
|
||||
|
||||
rot_emb_list.append(rot_emb)
|
||||
|
||||
return rot_emb_list
|
||||
|
||||
|
||||
class QwenVlRotaryEmbedding3D:
|
||||
@@ -504,33 +520,42 @@ class QwenVlRotaryEmbedding3D:
|
||||
self.max_position = max_position
|
||||
self.freq_allocation = freq_allocation
|
||||
|
||||
def __call__(self, position_ids):
|
||||
def __call__(self, position_ids, max_len_lst, cumsum_seqlens):
|
||||
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
|
||||
|
||||
bsz = len(cumsum_seqlens) - 1
|
||||
# position_ids_3d: [bsz, seq_len, 3]
|
||||
position_ids_3d = paddle.tile(
|
||||
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
|
||||
[1, 1, 3],
|
||||
[bsz, 1, 3],
|
||||
)
|
||||
for i in range(bsz):
|
||||
position_ids_cur = position_ids[cumsum_seqlens[i] : cumsum_seqlens[i + 1]]
|
||||
prefix_max_position_ids = paddle.max(position_ids_cur) + 1
|
||||
dec_pos_ids = paddle.tile(
|
||||
paddle.arange(max_len_lst[i], dtype="int64").unsqueeze(-1),
|
||||
[1, 3],
|
||||
)
|
||||
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
|
||||
position_ids_3d_real = paddle.concat([position_ids_cur, dec_pos_ids], axis=0)
|
||||
position_ids_3d[i, : position_ids_3d_real.shape[0], :] = position_ids_3d_real
|
||||
|
||||
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
||||
|
||||
# position_ids: [bsz, seq_len]
|
||||
# position_ids: [bsz(1), seq_len]
|
||||
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
||||
|
||||
position_ids = position_ids / self.paritial_rotary_factor
|
||||
|
||||
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
|
||||
indices = 1 / self.base ** (indices / self.rotary_dim)
|
||||
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
|
||||
# sinusoid_inp: [bsz(1), seq_len, 1, head_dim // 2]
|
||||
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
# pos_emb: [bsz(1), seq_len, 1, head_dim]
|
||||
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
|
||||
# pos_emb: [bsz, 1, seq_len, head_dim]
|
||||
# pos_emb: [bsz(1), 1, seq_len, head_dim]
|
||||
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
# pos_emb: [bsz(1), seq_len, 1, head_dim]
|
||||
pos_emb = pos_emb.transpose([0, 2, 1, 3])
|
||||
# sin: [bsz, seq_len, 1, head_dim // 2]
|
||||
# sin: [bsz(1), seq_len, 1, head_dim // 2]
|
||||
sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
|
||||
batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
|
||||
# batch_indices: [[0]]
|
||||
@@ -539,9 +564,9 @@ class QwenVlRotaryEmbedding3D:
|
||||
sin = sin.tile([position_ids.shape[0], 1, 1, 1])
|
||||
cos = cos.tile([position_ids.shape[0], 1, 1, 1])
|
||||
|
||||
tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
|
||||
tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
|
||||
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
|
||||
tmp_pos_id_0 = position_ids_3d[..., 0].astype("int64")
|
||||
tmp_pos_id_1 = position_ids_3d[..., 1].astype("int64")
|
||||
tmp_pos_id_2 = position_ids_3d[..., 2].astype("int64")
|
||||
|
||||
# sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||
# sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||
@@ -559,28 +584,37 @@ class QwenVlRotaryEmbedding3D:
|
||||
section_w = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
|
||||
|
||||
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||
]
|
||||
sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
|
||||
|
||||
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||
rot_emb_list = []
|
||||
for i in range(bsz):
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0[i], axis=1)[:, :, :, :section_t]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1[i], axis=1)[
|
||||
:, :, :, section_t : section_t + section_h
|
||||
]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2[i], axis=1)[
|
||||
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||
]
|
||||
sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
|
||||
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||
]
|
||||
cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
|
||||
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||
|
||||
rot_emb[0] = cos_thw
|
||||
rot_emb[1] = sin_thw
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0[i], axis=1)[:, :, :, :section_t]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1[i], axis=1)[
|
||||
:, :, :, section_t : section_t + section_h
|
||||
]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2[i], axis=1)[
|
||||
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||
]
|
||||
cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
|
||||
|
||||
# neox style need
|
||||
rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
|
||||
return rot_emb_neox
|
||||
rot_emb[0] = cos_thw
|
||||
rot_emb[1] = sin_thw
|
||||
|
||||
# neox style need
|
||||
rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
|
||||
rot_emb_list.append(rot_emb_neox)
|
||||
|
||||
return rot_emb_list
|
||||
|
||||
|
||||
def get_rope_3d(
|
||||
@@ -591,6 +625,8 @@ def get_rope_3d(
|
||||
max_position: int,
|
||||
freq_allocation: int,
|
||||
model_type: str,
|
||||
max_len_lst: list[int],
|
||||
cumsum_seqlens: list[int],
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Pre-calculate rotary position embedding for position_ids.
|
||||
@@ -618,10 +654,14 @@ def get_rope_3d(
|
||||
rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
|
||||
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||
)
|
||||
elif "paddleocr" in model_type:
|
||||
rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
|
||||
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||
)
|
||||
else: # default ernie
|
||||
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
|
||||
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||
)
|
||||
|
||||
rotary_emb_3d = rotary_emb3d_layer(position_ids)
|
||||
rotary_emb_3d = rotary_emb3d_layer(position_ids, max_len_lst, cumsum_seqlens)
|
||||
return rotary_emb_3d
|
||||
|
||||
Reference in New Issue
Block a user