[Model] Support qwen2_5_vl (#3557)

* adapt qwen_2_5_vl model

* adapt qwen_2_5_vl VIT model

* adapt qwen2_5_vl images_embeds

* adapt qwen2_5_vl 3D rope

* adapt qwen2_5_vl 3D rope v2

* adapt qwen2_5_vl processor

* adapt qwen2_5_vl bypass resampler_model

* adapt qwen2_5_vl: bypass parts of the ernie logic

* adapt qwen2_5_vl: bypass parts of the ernie logic v2

* adapt qwen2_5_vl: weight loading and naming changes

* adapt qwen2_5_vl: make think_end_id optional

* adapt qwen2_5_vl: separate extract_vision_features per model type

* fix: adapt qwen2_5_vl model

* adapt qwen2_5_vl norm

* adapt qwen2_5_vl: processor update

* adapt qwen2_5_vl image and video success

* adapt qwen2_5_vl: partial code cleanup

* adapt qwen2_5_vl: multi-GPU support

* adapt qwen2_5_vl on latest develop

* adapt qwen2_5_vl RL

* adapt qwen2_5_vl: code cleanup

* support neox rope3d

* adapt qwen2_5_vl add init.py

* adapt qwen2_5_vl add init.py v2

* adapt qwen2_5_vl remove space

* adapt qwen2_5_vl remove space v2

* adapt qwen2_5_vl pre-commit

* adapt qwen2_5_vl update

* adapt qwen2_5_vl pre-commit v2

* adapt qwen2_5_vl modify comments

* adapt qwen2_5_vl fix indentation

* adapt qwen2_5_vl fix indentation v2

---------

Co-authored-by: wangyafeng <wangyafeng@baidu.com>
Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com>
Co-authored-by: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com>
This commit is contained in:
zhouchong
2025-08-29 18:28:39 +08:00
committed by GitHub
parent 65425bf858
commit ccd52b5596
10 changed files with 1718 additions and 17 deletions

View File

@@ -325,8 +325,6 @@ class ErnieVlRotaryEmbedding3D:
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
# position_ids: [bsz, seq_len]
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
@@ -383,6 +381,100 @@ class ErnieVlRotaryEmbedding3D:
return rot_emb
class QwenVlRotaryEmbedding3D:
def __init__(
self,
rotary_dim,
base,
partial_rotary_factor,
max_position,
freq_allocation,
):
self.rotary_dim = rotary_dim
self.base = base
self.partial_rotary_factor = partial_rotary_factor
self.max_position = max_position
self.freq_allocation = freq_allocation
def __call__(self, position_ids):
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
# position_ids_3d: [bsz, seq_len, 3]
position_ids_3d = paddle.tile(
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
# position_ids: [bsz, seq_len]
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
position_ids = position_ids / self.partial_rotary_factor
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
indices = 1 / self.base ** (indices / self.rotary_dim)
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
# pos_emb: [bsz, 1, seq_len, head_dim]
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = pos_emb.transpose([0, 2, 1, 3])
# sin: [bsz, seq_len, 1, head_dim // 2]
sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
# batch_indices: [[0]]
batch_indices = batch_indices[..., None]
# sin, cos: [3, seq_len, 1, head_dim // 2]
sin = sin.tile([position_ids.shape[0], 1, 1, 1])
cos = cos.tile([position_ids.shape[0], 1, 1, 1])
tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
section_t = self.freq_allocation # 16
section_h = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
section_w = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, section_t + section_h : section_t + section_h + section_w
]
sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, section_t + section_h : section_t + section_h + section_w
]
cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
rot_emb[0] = cos_thw
rot_emb[1] = sin_thw
# neox-style rope needs the sin/cos table duplicated along the last axis
rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
return rot_emb_neox
def get_rope_3d(
rotary_dim: int,
base: float,
@@ -390,6 +482,7 @@ def get_rope_3d(
partial_rotary_factor: float,
max_position: int,
freq_allocation: int,
model_type: str,
) -> paddle.Tensor:
"""
Pre-calculate rotary position embedding for position_ids.
@@ -407,9 +500,20 @@ def get_rope_3d(
Default: 1 (apply to all dimensions).
max_position: Maximum position index to precompute.
freq_allocation: Number of rotary dimensions allocated to temporal axis
model_type: Model type, such as 'ernie4_5_moe_vl' or 'qwen2_5_vl'.
"""
if "ernie" in model_type:
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
elif "qwen" in model_type:
rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
else: # default ernie
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
rotary_emb_3d = rotary_emb3d_layer(position_ids)
return rotary_emb_3d
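For intuition, here is a small standalone sketch (illustrative, not part of the diff) of the t/h/w channel split used by QwenVlRotaryEmbedding3D above, assuming rotary_dim=128 and freq_allocation=16 as in the inline comments:

import paddle

rotary_dim, freq_allocation, base = 128, 16, 10000.0
half = rotary_dim // 2                  # 64 sin (and cos) channels per position
section_t = freq_allocation             # first 16 channels follow the temporal position
section_h = (half - section_t) // 2     # next 24 channels follow the height position
section_w = (half - section_t) // 2     # last 24 channels follow the width position
assert section_t + section_h + section_w == half

t, h, w = 5, 3, 7                       # hypothetical 3D position of one token
inv_freq = 1.0 / base ** (paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
angles = paddle.concat([
    t * inv_freq[:section_t],
    h * inv_freq[section_t : section_t + section_h],
    w * inv_freq[section_t + section_h :],
])
# the neox-style table is then the duplicated sin/cos of these angles
cos_thw, sin_thw = paddle.cos(angles), paddle.sin(angles)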

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,23 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .configuration import DFNRopeVisionTransformerConfig
from .modeling import DFNRopeVisionTransformerPretrainedModel
__all__ = [
"DFNRopeVisionTransformerConfig",
"DFNRopeVisionTransformerPretrainedModel",
]

View File

@@ -0,0 +1,277 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from collections import OrderedDict
import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn
class NewGELUActivation(nn.Layer):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return (
0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
)
class GELUActivation(nn.Layer):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
"""_summary_
Args:
use_gelu_python (bool, optional): _description_. Defaults to False.
"""
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return self.act(input)
class FastGELUActivation(nn.Layer):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Layer):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return input * F.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Layer):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
"""_summary_
Args:
x (Tensor): _description_
Returns:
Tensor: _description_
"""
return paddle.clip(gelu(x), self.min, self.max)
class SiLUActivation(nn.Layer):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return F.silu(input)
class MishActivation(nn.Layer):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return F.mish(input)
class LinearActivation(nn.Layer):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
"""_summary_
Args:
input (Tensor): _description_
Returns:
Tensor: _description_
"""
return input
class ClassInstantier(OrderedDict):
"""_summary_
Args:
OrderedDict (_type_): _description_
"""
def __getitem__(self, key):
"""_summary_
Args:
key (_type_): _description_
Returns:
_type_: _description_
"""
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
"""_summary_
Args:
activation_string (_type_): _description_
Raises:
KeyError: _description_
Returns:
_type_: _description_
"""
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
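For reference, a minimal usage sketch of this activation registry (assumes a working Paddle install):

import paddle

act = get_activation("quick_gelu")      # a fresh QuickGELUActivation instance
x = paddle.randn([2, 8], dtype="float32")
y = act(x)                              # equivalent to x * sigmoid(1.702 * x)
assert y.shape == x.shape

silu = ACT2FN["silu"]                   # ClassInstantier builds a new layer per lookup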

View File

@@ -0,0 +1,96 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from paddleformers.transformers.configuration_utils import PretrainedConfig
__all__ = [
"DFNRopeVisionTransformerConfig",
]
# qwen2_5 vision parameters
"""
"vision_config": {
"depth": 32,
"hidden_act": "silu",
"hidden_size": 1280,
"intermediate_size": 3420,
"num_heads": 16,
"in_chans": 3,
"out_hidden_size": 3584,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"window_size": 112,
"fullatt_block_indexes": [
7,
15,
23,
31
],
"tokens_per_second": 2,
"temporal_patch_size": 2
},
"""
# qwen:
# hidden_size -> embed_dim
# out_hidden_size -> hidden_size
# intermediate_size -> mlp/mlp_hidden_dim in qwen_vision_block
# fullatt_block_indexes distinguishes the ViT layer indices that use full (vs. windowed) attention
# spatial_patch_size and tokens_per_second are not used here (nor in vLLM)
class DFNRopeVisionTransformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DFNRopeVisionTransformerPretrainedModel`]. It is
used to instantiate the Qwen2.5-VL vision transformer according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults yields a configuration similar to the Qwen2.5-VL-7B
vision tower.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "DFNRope_vision_transformer"
def __init__(
self,
depth=32,
hidden_size=1280,
out_hidden_size=3584,
intermediate_size=3420,
hidden_act="silu",
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
window_size=112,
fullatt_block_indexes=[7, 15, 23, 31],
temporal_patch_size=2,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.out_hidden_size = out_hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.temporal_patch_size = temporal_patch_size
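A small illustrative sketch of building this config from the raw vision_config dict quoted above. Note that "in_chans" is not a declared parameter; it is absorbed by PretrainedConfig's kwargs and stored as an attribute, which is what the modeling code below reads:

raw_vision_config = {
    "depth": 32, "hidden_size": 1280, "out_hidden_size": 3584,
    "intermediate_size": 3420, "hidden_act": "silu", "num_heads": 16,
    "in_chans": 3, "patch_size": 14, "spatial_merge_size": 2,
    "window_size": 112, "fullatt_block_indexes": [7, 15, 23, 31],
    "temporal_patch_size": 2,
}
config = DFNRopeVisionTransformerConfig(**raw_vision_config)
assert config.hidden_size // config.num_heads == 80   # per-head dim used by the rotary embedding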

View File

@@ -0,0 +1,706 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from functools import partial
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from paddle.nn.functional.flash_attention import (
flash_attn_unpadded as flash_attn_varlen_func,
)
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor:
"""_summary_
Args:
tensor (paddle.Tensor): _description_
freqs (paddle.Tensor): _description_
Returns:
paddle.Tensor: _description_
"""
orig_dtype = tensor.dtype
with paddle.amp.auto_cast(False):
tensor = tensor.astype(dtype="float32")
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
output = tensor * cos + rotate_half(tensor) * sin
output = paddle.cast(output, orig_dtype)
return output
class VisionFlashAttention2(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
super().__init__()
self.num_heads = num_heads
self.tensor_parallel_degree = tensor_parallel_degree
if tensor_parallel_degree > 1:
self.qkv = ColumnParallelLinear(
dim,
dim * 3,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True,
fuse_matmul_bias=True,
gather_output=False,
)
self.proj = RowParallelLinear(
dim,
dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
input_is_parallel=True,
has_bias=True,
)
else:
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim, bias_attr=True)
self.head_dim = dim // num_heads  # needed for the softmax scale in forward
def forward(
self,
hidden_states: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rotary_pos_emb: paddle.Tensor = None,
) -> paddle.Tensor:
"""_summary_
Args:
hidden_states (paddle.Tensor): _description_
cu_seqlens (paddle.Tensor): _description_
rotary_pos_emb (paddle.Tensor, optional): _description_. Defaults to None.
Returns:
paddle.Tensor: _description_
"""
seq_length = hidden_states.shape[0]
qkv = (
self.qkv(hidden_states)
.reshape(
[
seq_length,
3,
self.num_heads // self.tensor_parallel_degree,
-1,
]
)
.transpose(perm=[1, 0, 2, 3])
)
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
softmax_scale = self.head_dim**-0.5
attn_output = (
flash_attn_varlen_func( # flash_attn_unpadded
q, # flash attention kernels do not support float32 inputs
k,
v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
scale=softmax_scale,
)[0]
.squeeze(0)
.reshape([seq_length, -1])
)
attn_output = attn_output.astype(paddle.float32)
attn_output = self.proj(attn_output)
return attn_output
class PatchEmbed(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
hidden_size: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.in_channels = in_channels
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.layer.Conv3D(
in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias_attr=False
)
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
"""_summary_
Args:
hidden_states (paddle.Tensor): _description_
Returns:
paddle.Tensor: _description_
"""
target_dtype = self.proj.weight.dtype
hidden_states = hidden_states.reshape(
[-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size]
)
hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)).reshape([-1, self.hidden_size])
return hidden_states
class VisionMlp(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(
self,
dim: int,
hidden_dim: int,
bias: bool = False,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
) -> None:
super().__init__()
self.tensor_parallel_degree = tensor_parallel_degree
if self.tensor_parallel_degree > 1:
self.gate_proj = ColumnParallelLinear(
dim,
hidden_dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
gather_output=False,
has_bias=bias,
)
self.up_proj = ColumnParallelLinear(
dim,
hidden_dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
gather_output=False,
has_bias=bias,
)
self.down_proj = RowParallelLinear(
hidden_dim,
dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
input_is_parallel=True,
has_bias=bias,
)
else:
self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias)
self.act = ACT2FN[hidden_act]
def forward(self, x) -> paddle.Tensor:
"""_summary_
Args:
x (_type_): _description_
Returns:
paddle.Tensor: _description_
"""
x_gate = self.gate_proj(x)
x_gate = self.act(x_gate)
x_up = self.up_proj(x)
x_down = self.down_proj(x_gate * x_up)
return x_down
class VisionRotaryEmbedding(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, dim: int, theta: float = 10000.0) -> None:
"""_summary_
Args:
dim (int): _description_
theta (float, optional): _description_. Defaults to 10000.0.
"""
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim)
self.register_buffer("inv_freq", inv_freq, persistable=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (self.theta ** (paddle.arange(0, self.dim, 2, dtype="float32") / self.dim))
seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype)
freqs = paddle.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> paddle.Tensor:
"""_summary_
Args:
seqlen (int): _description_
Returns:
paddle.Tensor: _description_
"""
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2RMSNorm(nn.Layer):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = paddle.create_parameter(
shape=[hidden_size],
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)
self.variance_epsilon = eps
def forward(self, hidden_states):
if paddle.in_dynamic_mode():
with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
else:
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
hidden_states = paddle.cast(hidden_states, self.weight.dtype)
return hidden_states * self.weight
class DFNRopeVisionBlock(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
attn_implementation: str = "sdpa",
) -> None:
"""_summary_
Args:
config (_type_): _description_
attn_implementation (str, optional): _description_. Defaults to "sdpa".
"""
super().__init__()
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
self.attn = VisionFlashAttention2(
dim=dim,
num_heads=num_heads,
tensor_parallel_degree=tensor_parallel_degree,
)
self.mlp = VisionMlp(
dim=dim,
hidden_dim=mlp_hidden_dim,
bias=True,
hidden_act=hidden_act,
tensor_parallel_degree=tensor_parallel_degree,
)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
"""_summary_
Args:
hidden_states (_type_): _description_
cu_seqlens (_type_): _description_
rotary_pos_emb (_type_): _description_
Returns:
paddle.Tensor: _description_
"""
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
class PatchMerger(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
"""_summary_
Args:
dim (int): _description_
context_dim (int): _description_
spatial_merge_size (int, optional): _description_. Defaults to 2.
"""
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True),
nn.GELU(),
nn.Linear(self.hidden_size, dim, bias_attr=True),
)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""_summary_
Args:
x (paddle.Tensor): _description_
Returns:
paddle.Tensor: _description_
"""
x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
return x
class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
"""_summary_
Args:
PretrainedModel (_type_): _description_
Returns:
_type_: _description_
"""
config_class = DFNRopeVisionTransformerConfig
def __init__(self, config, prefix_name: str = "") -> None:
super().__init__(config.vision_config)
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.prefix_name = prefix_name
# args for get_window_index_thw
self.window_size = config.vision_config.window_size
self.patch_size = config.vision_config.patch_size
self.fullatt_block_indexes = config.vision_config.fullatt_block_indexes
self.spatial_merge_unit = self.spatial_merge_size**2
self.patch_embed = PatchEmbed(
patch_size=config.vision_config.patch_size,
temporal_patch_size=config.vision_config.temporal_patch_size,
in_channels=config.vision_config.in_chans,
hidden_size=config.vision_config.hidden_size,
)
head_dim = config.vision_config.hidden_size // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.LayerList(
[
DFNRopeVisionBlock(
dim=config.vision_config.hidden_size,
num_heads=config.vision_config.num_heads,
mlp_hidden_dim=config.vision_config.intermediate_size,
hidden_act=config.vision_config.hidden_act,
tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree,
)
for _ in range(config.vision_config.depth)
]
)
self.merger = PatchMerger(
dim=config.vision_config.out_hidden_size, context_dim=config.vision_config.hidden_size
)
@property
def device(self):
"""Return the place (device) of the model parameters."""
return self.patch_embed.proj.weight.place
def get_dtype(self) -> paddle.dtype:
"""_summary_
Returns:
paddle.dtype: _description_
"""
return self.blocks[0].mlp.fc2.weight.dtype
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size)
index = paddle.arange(end=grid_t * llm_grid_h * llm_grid_w).reshape([grid_t, llm_grid_h, llm_grid_w])
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = paddle.nn.functional.pad(
x=index, pad=(0, pad_w, 0, pad_h), mode="constant", value=-100, pad_from_left_axis=False
)
index_padded = index_padded.reshape(
[grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size]
)
index_padded = index_padded.transpose(perm=[0, 1, 3, 2, 4]).reshape(
[grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size]
)
seqlens = (index_padded != -100).sum(axis=[2, 3]).reshape([-1])
index_padded = index_padded.reshape([-1])
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(axis=0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = paddle.concat(x=window_index, axis=0)
return window_index, cu_window_seqlens
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w])
hpos_ids = hpos_ids.reshape(
[
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
]
)
hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3])
hpos_ids = hpos_ids.flatten()
wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1])
wpos_ids = wpos_ids.reshape(
[
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
]
)
wpos_ids = wpos_ids.transpose([0, 2, 1, 3])
wpos_ids = wpos_ids.flatten()
pos_ids.append(paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1]))
pos_ids = paddle.concat(x=pos_ids, axis=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
return rotary_pos_emb
def get_rope_by_thw(self, t, h, w):
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w)
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_axis=0, stop_axis=1)
cu_seqlens_thw = paddle.repeat_interleave(paddle.to_tensor([h * w], dtype="int32"), t)
return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw)
def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor:
"""_summary_
Args:
hidden_states (paddle.Tensor): _description_
grid_thw (paddle.Tensor): _description_
Returns:
paddle.Tensor: _description_
"""
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = paddle.to_tensor(data=cu_window_seqlens, dtype="int32", place=hidden_states.place)
cu_window_seqlens = paddle.unique_consecutive(x=cu_window_seqlens)
seq_len, _ = tuple(hidden_states.shape)
hidden_states = hidden_states.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape([seq_len, -1])
rotary_pos_emb = rotary_pos_emb.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape([seq_len, -1])
cu_seqlens = paddle.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype="int32"
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
)
# adapter
hidden_states = self.merger(hidden_states)
reverse_indices = paddle.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
def extract_feature(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor:
"""_summary_
Args:
hidden_states (paddle.Tensor): _description_
grid_thw (paddle.Tensor): _description_
Returns:
paddle.Tensor: _description_
"""
return self.forward(hidden_states, grid_thw)
@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):
"""
dummy
"""
from paddleformers.transformers.conversion_utils import split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
)
vision_config = config.vision_config
def split_qkv_weight(x):
head_dim = vision_config.hidden_size // vision_config.num_heads
x = x.reshape(
[
vision_config.hidden_size,
3,
vision_config.num_heads,
head_dim,
]
)
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
x = x.reshape([vision_config.hidden_size, -1])
return x
def split_qkv_bias(x):
head_dim = vision_config.hidden_size // vision_config.num_heads
x = x.reshape([3, vision_config.num_heads, head_dim])
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
x = x.reshape([-1])
return x
def get_tensor_parallel_split_mappings(depth):
final_actions = {}
base_actions = {
"visual.blocks.0.attn.proj.weight": partial(fn, is_column=False),
"visual.blocks.0.mlp.gate_proj.weight": partial(fn, is_column=True),
"visual.blocks.0.mlp.gate_proj.bias": partial(fn, is_column=True),
"visual.blocks.0.mlp.up_proj.weight": partial(fn, is_column=True),
"visual.blocks.0.mlp.up_proj.bias": partial(fn, is_column=True),
"visual.blocks.0.mlp.down_proj.weight": partial(fn, is_column=False),
"visual.blocks.0.qkv.weight": split_qkv_weight,
"visual.blocks.0.qkv.bias": split_qkv_bias,
}
for key, action in base_actions.items():
if "blocks.0." in key:
for i in range(depth):
newkey = key.replace("blocks.0.", f"blocks.{i}.")
final_actions[newkey] = action
return final_actions
mappings = get_tensor_parallel_split_mappings(vision_config.depth)
return mappings
def load_state_dict(self, state_dict):
params_dict = dict(self.named_parameters())
for param_name, param in params_dict.items():
state_dict_key = f"{self.prefix_name}.{param_name}"
if state_dict_key not in state_dict:
raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
tensor = get_tensor(state_dict.pop(state_dict_key))
if param.shape != tensor.shape:
raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
else:
param.copy_(tensor, False)
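As a quick illustration (not part of the diff) of how forward derives the full-attention varlen boundaries from grid_thw: each image or video contributes grid_t sequences of grid_h * grid_w patches.

import paddle
import paddle.nn.functional as F

# hypothetical batch: a 2-frame video of 8x12 patches plus a 16x16-patch image
grid_thw = paddle.to_tensor([[2, 8, 12], [1, 16, 16]], dtype="int64")

cu_seqlens = paddle.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(axis=0, dtype="int32")
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens.tolist())   # [0, 96, 192, 448] -> one boundary per frame/image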

View File

@@ -0,0 +1,390 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from functools import partial
from typing import Dict, Optional, Union
import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import extract_text_token_output
from fastdeploy.model_executor.forward_meta import ForwardMeta
@support_graph_optimization
class Qwen2_5_VLModel(nn.Layer):
def __init__(
self,
fd_config: FDConfig = None,
):
"""
Initializer for the Qwen2_5_VLModel class.
Args:
fd_config (FDConfig): configurations for the inference engine.
"""
super().__init__()
self.num_layers = fd_config.model_config.num_hidden_layers
self.image_token_id = fd_config.model_config.image_token_id
self.video_token_id = fd_config.model_config.video_token_id
self._dtype = fd_config.model_config.dtype
fd_config.model_config.pretrained_config.prefix_name = "model"
self.fd_config = fd_config
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
params_dtype=paddle.get_default_dtype(),
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
)
self.layers = nn.LayerList(
[
Qwen2DecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
]
)
self.norm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
)
def load_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
Args:
state_dict (dict[str, np.ndarray | paddle.Tensor]):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.layers[i].load_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,
image_features: Optional[paddle.Tensor],
forward_meta: ForwardMeta,
):
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
# -----------------------
# replace the image/video placeholder tokens in input_embeds with image_embeds
image_mask = ids_remove_padding == self.image_token_id
image_token_num = image_mask.sum()
video_mask = ids_remove_padding == self.video_token_id
video_token_num = video_mask.sum()
# the framework only carries image_features, so mixed image + video inputs are not supported yet
# TODO(wangyafeng) consider supporting a separate video_features input later
if image_token_num > 0:
hidden_states[image_mask] = image_features.cast(self._dtype)
if video_token_num > 0:
hidden_states[video_mask] = image_features.cast(self._dtype)
# -----------------------
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](
forward_meta,
hidden_states,
residual,
)
hidden_states = hidden_states + residual
# -----------------------
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
hidden_states = extract_text_token_output(
max_seq_len,
max_seq_len_index.cast("int32"),
image_token_num.cast("int32"),
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
hidden_states.cast("float32"),
).cast(self._dtype)
# -----------------------
out = self.norm(hidden_states)
return out
class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
"""
Qwen2_5_VLForConditionalGeneration
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Qwen2_5_VLForConditionalGeneration, self).__init__(fd_config)
# ----------- vision model ------------
self.visual = self._init_vision_model(fd_config.model_config)
# ----------- language model -------------
self.model = Qwen2_5_VLModel(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = ParallelLMHead(
fd_config=fd_config,
embedding_dim=fd_config.model_config.hidden_size,
num_embeddings=fd_config.model_config.vocab_size,
prefix="lm_head",
)
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
def _init_vision_model(self, model_config) -> nn.Layer:
from fastdeploy.model_executor.models.qwen2_5_vl.dfnrope.modeling import (
DFNRopeVisionTransformerPretrainedModel,
)
visual = DFNRopeVisionTransformerPretrainedModel(model_config, prefix_name="visual")
visual = paddle.amp.decorate(models=visual, level="O2", dtype="bfloat16")
visual.eval()
return visual
@classmethod
def name(cls):
return "Qwen2_5_VLForConditionalGeneration"
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
Load model parameters from a given state dictionary.
Args:
state_dict (dict[str, np.ndarray | paddle.Tensor]):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.model.load_state_dict(state_dict)
self.visual.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
def empty_input_forward(self):
"""
empty_input_forward
"""
fake_hidden_states = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
for i in range(
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
self.model.layers[i].mlp.text_fused_moe(fake_hidden_states)
self.model.layers[i].mlp.image_fused_moe(fake_hidden_states)
def forward(
self,
ids_remove_padding: paddle.Tensor,
image_features: Optional[paddle.Tensor],
forward_meta: ForwardMeta,
):
hidden_states = self.model(
ids_remove_padding=ids_remove_padding,
image_features=image_features,
forward_meta=forward_meta,
)
return hidden_states
class Qwen2_5_VLPretrainedModel(PretrainedModel):
"""
Qwen2_5_VLPretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def arch_name(cls):
return "Qwen2_5_VLForConditionalGeneration"
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
weight_infos = [
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.weight", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.bias", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.weight", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.bias", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.weight", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.bias", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False),
WeightMeta(".embed_tokens.weight", False),
WeightMeta("lm_head.weight", True),
]
weight_vision = [
# vision
WeightMeta(
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight",
False,
),
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True),
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.bias", True),
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True),
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.bias", True),
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False),
WeightMeta(
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight",
True,
tsm.GQA,
),
WeightMeta(
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias",
True,
tsm.GQA,
),
]
@classmethod
def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
logger.info("qwen2_5_vl inference model _get_tensor_parallel_mappings")
from fastdeploy.model_executor.models.tp_utils import (
build_expanded_keys,
has_prefix,
split_or_merge_func_v1,
)
fn = split_or_merge_func_v1(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
vision_fn = split_or_merge_func_v1(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.vision_config.get("num_heads"),
num_key_value_heads=config.vision_config.get("num_heads"),
head_dim=config.vision_config.get("hidden_size") // config.vision_config.get("num_heads"),
)
def get_tensor_parallel_split_mappings(
num_layers: int,
prefix_name: str,
):
base_actions = {}
for weight_name, is_column, extra in cls.weight_infos:
params = {
"is_column": is_column,
**({extra.value: True} if extra else {}),
}
if "lm_head.weight" or "" in weight_name:
key = weight_name
elif not has_prefix(prefix_name, weight_name):
key = f"{prefix_name}{weight_name}"
else:
key = weight_name
base_actions[key] = partial(fn, **params)
final_actions = {}
final_actions = build_expanded_keys(
base_actions,
num_layers,
)
return final_actions
def get_vision_parallel_split_mappings(num_layers: int):
base_actions = {}
for weight_name, is_column, extra in cls.weight_vision:
params = {
"is_column": is_column,
**({extra.value: True} if extra else {}),
}
base_actions[weight_name] = partial(vision_fn, **params)
final_actions = {}
final_actions = build_expanded_keys(
base_actions,
num_layers,
)
return final_actions
mappings = get_tensor_parallel_split_mappings(
config.num_hidden_layers,
config.prefix_name,
)
vision_mappings = get_vision_parallel_split_mappings(config.vision_config.get("depth"))
return {**mappings, **vision_mappings}
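A minimal sketch (shapes and the token id are made up for illustration) of the placeholder substitution performed in Qwen2_5_VLModel.forward: vision features are scattered into the embedding stream wherever the image/video token id appears.

import paddle

image_token_id = 151655                                 # assumed; the real id comes from model_config
hidden = paddle.randn([6, 4], dtype="float32")          # packed token embeddings, hidden_size=4
ids = paddle.to_tensor([1, image_token_id, image_token_id, 2, 3, 4])
image_features = paddle.ones([2, 4], dtype="float32")   # one row per placeholder token

mask = ids == image_token_id
hidden[mask] = image_features                           # boolean-mask scatter, as in the model forward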

View File

@@ -20,7 +20,11 @@ class MultimodalRegistry:
A registry for multimodal models
"""
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration", "Ernie5MoeForCausalLM"}
mm_models: set[str] = {
"Ernie4_5_VLMoeForConditionalGeneration",
"Ernie5MoeForCausalLM",
"Qwen2_5_VLForConditionalGeneration",
}
@classmethod
def contains_model(cls, name: str) -> bool:
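With this entry registered, the multimodal code paths are enabled for Qwen2.5-VL, e.g.:

assert MultimodalRegistry.contains_model("Qwen2_5_VLForConditionalGeneration")
assert not MultimodalRegistry.contains_model("Qwen2ForCausalLM")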

View File

@@ -33,6 +33,10 @@ from fastdeploy.model_executor.models.qwen2 import (
Qwen2ForCausalLM,
Qwen2PretrainedModel,
)
from fastdeploy.model_executor.models.qwen2_5_vl.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLPretrainedModel,
)
from fastdeploy.model_executor.models.qwen3 import (
Qwen3ForCausalLM,
Qwen3PretrainedModel,
@@ -477,3 +481,51 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
self._complete_missing_mappings()
return self.infer_to_train_mapping
class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, BaseRLModel):
"""
Qwen2_5_VLForConditionalGenerationRL
"""
_get_tensor_parallel_mappings = Qwen2_5_VLPretrainedModel._get_tensor_parallel_mappings
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Qwen2_5_VLForConditionalGenerationRL, self).__init__(fd_config)
@classmethod
def name(cls) -> str:
"""name"""
return "Qwen2_5_VLForConditionalGenerationRL"
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]
# Initialize mapping dictionary
self._update_base_mappings("model")
base_name = "model.layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx):
# FFN mappings
for ph in place_holders:
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = (
f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
)
for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
_add_layer_mappings(layer_idx)
self._complete_missing_mappings()
return self.infer_to_train_mapping
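For reference, the FFN entries this builds map inference-side fused names back to training-side names (hypothetical construction; an initialized fd_config is assumed):

rl_model = Qwen2_5_VLForConditionalGenerationRL(fd_config)
mapping = rl_model.get_name_mappings_to_training()
assert mapping["model.layers.0.mlp.up_gate_proj.weight"] == (
    "model.layers.0.mlp.gate_up_fused_proj.weight"
)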

View File

@@ -103,6 +103,7 @@ class GPUModelRunner(ModelRunnerBase):
# VL model config:
if self.enable_mm:
if "ernie" in self.fd_config.model_config.model_type:
self._init_image_preprocess()
self.amp_black = [
@@ -242,7 +243,8 @@ class GPUModelRunner(ModelRunnerBase):
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end], dtype="uint8"
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
@@ -797,6 +799,11 @@ class GPUModelRunner(ModelRunnerBase):
if self.enable_mm:
head_dim = self.model_config.head_dim
if "qwen" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim
else: # neox style = False
rope_head_dim = head_dim // 2
self.share_inputs["rope_emb"] = paddle.full(
shape=[
max_num_seqs,
@@ -804,14 +811,16 @@ class GPUModelRunner(ModelRunnerBase):
1,
self.parallel_config.max_model_len,
1,
head_dim // 2,
rope_head_dim,
],
fill_value=0,
dtype="float32",
)
self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
self.share_inputs["enable_thinking"] = paddle.full(
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
)
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
def _prepare_inputs(self) -> None:
@@ -1186,7 +1195,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
@@ -1476,7 +1485,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
@@ -1720,7 +1729,7 @@ class GPUModelRunner(ModelRunnerBase):
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
images = paddle.to_tensor(images, dtype="uint8")
images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16")
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
else:
image_type_ids = None
@@ -1742,12 +1751,10 @@ class GPUModelRunner(ModelRunnerBase):
)
return result
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
# ernie-vl normalizes raw uint8 images here
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
@@ -1772,6 +1779,7 @@ class GPUModelRunner(ModelRunnerBase):
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features, axis=-1) # split features across model-parallel ranks
image_features = image_features.reshape([S, -1])
# ernie-vl applies its resampler on top of the vision features
image_features = self.model.resampler_model(
image_features,
image_mask,
@@ -1781,6 +1789,31 @@ class GPUModelRunner(ModelRunnerBase):
)
return image_features
def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.parallel_config.dtype,
):
image_features = self.model.visual.extract_feature(images, grid_thw)
return image_features
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
if "ernie" in self.model_config.model_type:
return self.extract_vision_features_ernie(inputs)
elif "qwen" in self.model_config.model_type:
return self.extract_vision_features_qwen(inputs)
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
"""prepare_rope3d"""
@@ -1800,5 +1833,6 @@ class GPUModelRunner(ModelRunnerBase):
base=self.model_config.rope_theta,
max_position=self.parallel_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
)
return rope_emb
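A short standalone sketch (sizes assumed for illustration) of why the rope buffer's last dimension differs by model family: QwenVlRotaryEmbedding3D returns a neox-style table whose last axis is already doubled to head_dim, while the ernie path stores head_dim // 2.

import paddle

head_dim, max_model_len, max_num_seqs = 128, 4096, 2
for model_type in ("ernie4_5_moe_vl", "qwen2_5_vl"):
    rope_head_dim = head_dim if "qwen" in model_type else head_dim // 2
    rope_emb = paddle.full(
        shape=[max_num_seqs, 2, 1, max_model_len, 1, rope_head_dim],
        fill_value=0,
        dtype="float32",
    )
    print(model_type, rope_emb.shape[-1])   # 64 for ernie, 128 for qwen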