[Feature] Optim PaddleOCR-VL (#4873)

* [Feature] Optim PaddleOCR-VL

* fix bug
This commit is contained in:
ming1753
2025-11-07 14:56:44 +08:00
committed by GitHub
parent bbe0820555
commit cba185f1fe
12 changed files with 535 additions and 112 deletions

View File

@@ -22,7 +22,6 @@ import paddle
import paddle.nn as nn
from paddleformers.transformers import PretrainedModel
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -136,12 +135,8 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
)
# Persistent buffers for CUDA graphs.
if envs.FD_ENABLE_MAX_PREFILL:
max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len
else:
max_length = fd_config.model_config.max_model_len
self._input_embeddings = paddle.zeros(
[max_length, fd_config.model_config.hidden_size],
self._decoder_input_embeddings = paddle.zeros(
[fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
@@ -247,12 +242,19 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
self._input_embeddings.copy_(input_embeddings, False)
hidden_states = self.model(
input_embeddings=self._input_embeddings,
forward_meta=forward_meta,
)
if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
hidden_states = self.model(
input_embeddings=self._decoder_input_embeddings,
forward_meta=forward_meta,
)
else:
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)
return hidden_states

View File

@@ -21,39 +21,13 @@ import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleformers.transformers.activations import ACT2FN
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from .config import PaddleOCRVisionConfig
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def _ensure_cos_sin_dim(cos, sin, dim_needed):
last = cos.shape[-1]
if last == dim_needed:
return cos, sin
elif last * 2 == dim_needed:
cos = paddle.concat([cos, cos], axis=-1)
sin = paddle.concat([sin, sin], axis=-1)
return cos, sin
else:
raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
from .siglip_ops import get_activation_fn, neox_rope_embedding
class SiglipAttention(nn.Layer):
@@ -147,29 +121,12 @@ class SiglipAttention(nn.Layer):
output_attentions: Optional[bool] = False,
cu_seqlens: Optional[List[paddle.Tensor]] = None,
max_seqlen: Optional[paddle.Tensor] = None,
rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, # (cos, sin)
cos_emb: Optional[paddle.Tensor] = None, # (cos, sin)
sin_emb: Optional[paddle.Tensor] = None, # (cos, sin)
):
B, seq_length, D = hidden_states.shape
qkv = (
self.qkv_proj(hidden_states)
.reshape(
[
seq_length,
3,
self.num_heads,
-1,
]
)
.transpose(perm=[1, 0, 2, 3])
)
q, k, v = qkv.unbind(axis=0)
cos, sin = rope_emb
# --------
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
qkv = self.qkv_proj(hidden_states)
q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim)
attn_output = self.flash_attn_func(
q,
k,
@@ -181,11 +138,9 @@ class SiglipAttention(nn.Layer):
causal=False,
**self.flash_attn_kwargs,
)[0]
# --------
attn_output = attn_output.reshape((seq_length, -1))
attn_output = self.out_proj(attn_output)
return attn_output
@@ -327,11 +282,7 @@ class SiglipMLP(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
if config.hidden_act == "gelu_pytorch_tanh":
config.hidden_act = "gelu_new"
self.activation_fn = ACT2FN[config.hidden_act]
self.activation_fn = get_activation_fn(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -353,7 +304,7 @@ class SiglipMLP(nn.Layer):
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.activation_fn(hidden_states[0])
hidden_states = self.fc2(hidden_states)
return hidden_states
@@ -375,7 +326,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=False,
cu_seqlens=None,
max_seqlen=None,
rope_emb=None,
cos_emb=None,
sin_emb=None,
):
residual = hidden_states
@@ -388,7 +340,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=output_attentions,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
rope_emb=rope_emb,
cos_emb=cos_emb,
sin_emb=sin_emb,
)
hs_post_attn = residual + x
@@ -545,13 +498,13 @@ class SiglipEncoder(nn.Layer):
rope_emb = rope_emb_max_grid[pids].flatten(1)
rope_emb = rope_emb.tile((1, 2))
cos = rope_emb.cos().astype("float32")
sin = rope_emb.sin().astype("float32")
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
rope_emb = (cos, sin)
cos_emb = rope_emb.cos().astype("float32")
sin_emb = rope_emb.sin().astype("float32")
cos_emb = cos_emb.unsqueeze(-2)
sin_emb = sin_emb.unsqueeze(-2)
else:
rope_emb = None
cos_emb = None
sin_emb = None
window_indices, cu_seqlens_within_windows = None, None
@@ -588,7 +541,8 @@ class SiglipEncoder(nn.Layer):
output_attentions=output_attentions,
cu_seqlens=attn_cu_seqlens,
max_seqlen=max_seqlen,
rope_emb=rope_emb,
cos_emb=cos_emb,
sin_emb=sin_emb,
)
hidden_states = layer_outputs[0]

View File

@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import List
import paddle
from paddleformers.transformers.activations import ACT2FN
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
def native_neox_rope_embedding(qkv, cos, sin, num_heads):
B, seq_length, D = qkv.shape
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
return q, k, v
def neox_rope_embedding(
qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
) -> List[paddle.Tensor]:
if current_platform.is_cuda():
return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
else:
return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
def get_activation_fn(hidden_act: str):
if hidden_act == "gelu_pytorch_tanh":
if current_platform.is_cuda():
return gelu_tanh
else:
return ACT2FN["gelu_new"]
else:
return ACT2FN[hidden_act]