mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Model]support qwen2_5_vl (#3557)
* adapt qwen_2_5_vl model * adapt qwen_2_5_vl VIT model * adapt qwen2_5_vl images_embeds * adapt qwen2_5_vl 3D rope * adapt qwen2_5_vl 3D rope v2 * adapt qwen2_5_vl processor * adapt qwen2_5_vl bypass resampler_model * adapt qwen2_5_vl 绕过部分ernie逻辑 * adapt qwen2_5_vl 绕过部分ernie逻辑 v2 * adapt qwen2_5_vl 权重加载与命名修改 * adapt qwen2_5_vl 非必须think_end_id * adapt qwen2_5_vl 区分多种模型的extract_vision_features * fix:adapt qwen2_5_vl model * adapt qwen2_5_vl norm * adapt qwen2_5_vl processor 更新 * adapt qwen2_5_vl image and video success * adapt qwen2_5_vl 部分整理代码 * adapt qwen2_5_vl 支持多卡 * adapt qwen2_5_vl on latest develop * adapt qwen2_5_vl RL * adapt qwen2_5_vl 整理代码 * support noex rope3d * adapt qwen2_5_vl add init.py * adapt qwen2_5_vl add init.py v2 * adapt qwen2_5_vl remove space * adapt qwen2_5_vl remove space v2 * adapt qwen2_5_vl pre-commit * adapt qwen2_5_vl update * adapt qwen2_5_vl pre-commit v2 * adapt qwen2_5_vl modify comments * adapt qwen2_5_vl fix indentation * adapt qwen2_5_vl fix indentation v2 --------- Co-authored-by: wangyafeng <wangyafeng@baidu.com> Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Co-authored-by: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com>
This commit is contained in:
@@ -325,8 +325,6 @@ class ErnieVlRotaryEmbedding3D:
|
|||||||
|
|
||||||
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
||||||
|
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
|
|
||||||
# position_ids: [bsz, seq_len]
|
# position_ids: [bsz, seq_len]
|
||||||
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
||||||
|
|
||||||
@@ -383,6 +381,100 @@ class ErnieVlRotaryEmbedding3D:
|
|||||||
return rot_emb
|
return rot_emb
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVlRotaryEmbedding3D:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rotary_dim,
|
||||||
|
base,
|
||||||
|
partial_rotary_factor,
|
||||||
|
max_position,
|
||||||
|
freq_allocation,
|
||||||
|
):
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
|
self.base = base
|
||||||
|
self.paritial_rotary_factor = partial_rotary_factor
|
||||||
|
self.max_position = max_position
|
||||||
|
self.freq_allocation = freq_allocation
|
||||||
|
|
||||||
|
def __call__(self, position_ids):
|
||||||
|
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
|
||||||
|
|
||||||
|
# position_ids_3d: [bsz, seq_len, 3]
|
||||||
|
position_ids_3d = paddle.tile(
|
||||||
|
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
|
||||||
|
[1, 1, 3],
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
||||||
|
|
||||||
|
# position_ids: [bsz, seq_len]
|
||||||
|
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
||||||
|
|
||||||
|
position_ids = position_ids / self.paritial_rotary_factor
|
||||||
|
|
||||||
|
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
|
||||||
|
indices = 1 / self.base ** (indices / self.rotary_dim)
|
||||||
|
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
|
||||||
|
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
|
||||||
|
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||||
|
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
|
||||||
|
# pos_emb: [bsz, 1, seq_len, head_dim]
|
||||||
|
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
|
||||||
|
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||||
|
pos_emb = pos_emb.transpose([0, 2, 1, 3])
|
||||||
|
# sin: [bsz, seq_len, 1, head_dim // 2]
|
||||||
|
sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
|
||||||
|
batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
|
||||||
|
# batch_indices: [[0]]
|
||||||
|
batch_indices = batch_indices[..., None]
|
||||||
|
# sin, cos: [3, seq_len, 1, head_dim // 2]
|
||||||
|
sin = sin.tile([position_ids.shape[0], 1, 1, 1])
|
||||||
|
cos = cos.tile([position_ids.shape[0], 1, 1, 1])
|
||||||
|
|
||||||
|
tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
|
||||||
|
tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
|
||||||
|
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
|
||||||
|
|
||||||
|
# sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||||
|
# sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||||
|
# sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[
|
||||||
|
# :, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||||
|
# ]
|
||||||
|
# sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
|
||||||
|
# :, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||||
|
# ]
|
||||||
|
# sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
|
||||||
|
# sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
|
||||||
|
|
||||||
|
section_t = self.freq_allocation # 16
|
||||||
|
section_h = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
|
||||||
|
section_w = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
|
||||||
|
|
||||||
|
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||||
|
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
|
||||||
|
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
|
||||||
|
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
|
||||||
|
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||||
|
]
|
||||||
|
sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
|
||||||
|
|
||||||
|
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||||
|
|
||||||
|
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
|
||||||
|
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
|
||||||
|
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
|
||||||
|
:, :, :, section_t + section_h : section_t + section_h + section_w
|
||||||
|
]
|
||||||
|
cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
|
||||||
|
|
||||||
|
rot_emb[0] = cos_thw
|
||||||
|
rot_emb[1] = sin_thw
|
||||||
|
|
||||||
|
# neox style need
|
||||||
|
rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
|
||||||
|
return rot_emb_neox
|
||||||
|
|
||||||
|
|
||||||
def get_rope_3d(
|
def get_rope_3d(
|
||||||
rotary_dim: int,
|
rotary_dim: int,
|
||||||
base: float,
|
base: float,
|
||||||
@@ -390,6 +482,7 @@ def get_rope_3d(
|
|||||||
partial_rotary_factor: float,
|
partial_rotary_factor: float,
|
||||||
max_position: int,
|
max_position: int,
|
||||||
freq_allocation: int,
|
freq_allocation: int,
|
||||||
|
model_type: str,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Pre-calculate rotary position embedding for position_ids.
|
Pre-calculate rotary position embedding for position_ids.
|
||||||
@@ -407,9 +500,20 @@ def get_rope_3d(
|
|||||||
Default: 1 (apply to all dimensions).
|
Default: 1 (apply to all dimensions).
|
||||||
max_position: Maximum position index to precompute.
|
max_position: Maximum position index to precompute.
|
||||||
freq_allocation: Number of rotary dimensions allocated to temporal axis
|
freq_allocation: Number of rotary dimensions allocated to temporal axis
|
||||||
|
model_type: Model type, such as 'ernie4_5_moe_vl' or 'qwen2_5_vl'.
|
||||||
"""
|
"""
|
||||||
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
|
if "ernie" in model_type:
|
||||||
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
|
||||||
)
|
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||||
|
)
|
||||||
|
elif "qwen" in model_type:
|
||||||
|
rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
|
||||||
|
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||||
|
)
|
||||||
|
else: # default ernie
|
||||||
|
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
|
||||||
|
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||||
|
)
|
||||||
|
|
||||||
rotary_emb_3d = rotary_emb3d_layer(position_ids)
|
rotary_emb_3d = rotary_emb3d_layer(position_ids)
|
||||||
return rotary_emb_3d
|
return rotary_emb_3d
|
||||||
|
15
fastdeploy/model_executor/models/qwen2_5_vl/__init__.py
Normal file
15
fastdeploy/model_executor/models/qwen2_5_vl/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .configuration import DFNRopeVisionTransformerConfig
|
||||||
|
from .modeling import DFNRopeVisionTransformerPretrainedModel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DFNRopeVisionTransformerConfig",
|
||||||
|
"DFNRopeVisionTransformerPretrainedModel",
|
||||||
|
]
|
@@ -0,0 +1,277 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
|
class NewGELUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||||
|
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GELUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
||||||
|
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
||||||
|
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
||||||
|
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, use_gelu_python: bool = False):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_gelu_python (bool, optional): _description_. Defaults to False.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if use_gelu_python:
|
||||||
|
self.act = self._gelu_python
|
||||||
|
else:
|
||||||
|
self.act = nn.functional.gelu
|
||||||
|
|
||||||
|
def _gelu_python(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return self.act(input)
|
||||||
|
|
||||||
|
|
||||||
|
class FastGELUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return input * F.sigmoid(1.702 * input)
|
||||||
|
|
||||||
|
|
||||||
|
class ClippedGELUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
|
||||||
|
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
|
||||||
|
https://arxiv.org/abs/2004.09602.
|
||||||
|
|
||||||
|
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
||||||
|
initially created.
|
||||||
|
|
||||||
|
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
||||||
|
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, min: float, max: float):
|
||||||
|
if min > max:
|
||||||
|
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return paddle.clip(gelu(x), self.min, self.max)
|
||||||
|
|
||||||
|
|
||||||
|
class SiLUActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
||||||
|
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
||||||
|
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
||||||
|
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
||||||
|
later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return F.silu(input)
|
||||||
|
|
||||||
|
|
||||||
|
class MishActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
||||||
|
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return F.mish(input)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearActivation(nn.Layer):
|
||||||
|
"""
|
||||||
|
Applies the linear activation function, i.e. forwarding input directly to output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: _description_
|
||||||
|
"""
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
class ClassInstantier(OrderedDict):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
OrderedDict (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (_type_): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
content = super().__getitem__(key)
|
||||||
|
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
ACT2CLS = {
|
||||||
|
"gelu": GELUActivation,
|
||||||
|
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
||||||
|
"gelu_fast": FastGELUActivation,
|
||||||
|
"gelu_new": NewGELUActivation,
|
||||||
|
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
||||||
|
"linear": LinearActivation,
|
||||||
|
"mish": MishActivation,
|
||||||
|
"quick_gelu": QuickGELUActivation,
|
||||||
|
"relu": nn.ReLU,
|
||||||
|
"relu6": nn.ReLU6,
|
||||||
|
"sigmoid": nn.Sigmoid,
|
||||||
|
"silu": SiLUActivation,
|
||||||
|
"swish": SiLUActivation,
|
||||||
|
"tanh": nn.Tanh,
|
||||||
|
}
|
||||||
|
ACT2FN = ClassInstantier(ACT2CLS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(activation_string):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activation_string (_type_): _description_
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
if activation_string in ACT2FN:
|
||||||
|
return ACT2FN[activation_string]
|
||||||
|
else:
|
||||||
|
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
# For backwards compatibility with: from activations import gelu_python
|
||||||
|
gelu_python = get_activation("gelu_python")
|
||||||
|
gelu_new = get_activation("gelu_new")
|
||||||
|
gelu = get_activation("gelu")
|
||||||
|
gelu_fast = get_activation("gelu_fast")
|
||||||
|
quick_gelu = get_activation("quick_gelu")
|
||||||
|
silu = get_activation("silu")
|
||||||
|
mish = get_activation("mish")
|
||||||
|
linear_act = get_activation("linear")
|
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DFNRopeVisionTransformerConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# qwen2_5 视觉参数
|
||||||
|
"""
|
||||||
|
"vision_config": {
|
||||||
|
"depth": 32,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"intermediate_size": 3420,
|
||||||
|
"num_heads": 16,
|
||||||
|
"in_chans": 3,
|
||||||
|
"out_hidden_size": 3584,
|
||||||
|
"patch_size": 14,
|
||||||
|
"spatial_merge_size": 2,
|
||||||
|
"spatial_patch_size": 14,
|
||||||
|
"window_size": 112,
|
||||||
|
"fullatt_block_indexes": [
|
||||||
|
7,
|
||||||
|
15,
|
||||||
|
23,
|
||||||
|
31
|
||||||
|
],
|
||||||
|
"tokens_per_second": 2,
|
||||||
|
"temporal_patch_size": 2
|
||||||
|
},
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# qwen:
|
||||||
|
# hidden_size -> embed_dim
|
||||||
|
# out_hidden_size -> hidden_size
|
||||||
|
# intermediate_size -> qwen_vision_block 中 mlp/mlp_hidden_dim
|
||||||
|
# fullatt_block_indexes 区分vit部分不同attention的layer_index
|
||||||
|
# spatial_patch_size 和 tokens_per_second 在vllm中没用到
|
||||||
|
class DFNRopeVisionTransformerConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`~ErnieModel`]. It is used to instantiate an Ernie
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the Ernie-7B.
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "DFNRope_vision_transformer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depth=32,
|
||||||
|
hidden_size=1280,
|
||||||
|
out_hidden_size=3584,
|
||||||
|
intermediate_size=3420,
|
||||||
|
hidden_act="silu",
|
||||||
|
num_heads=16,
|
||||||
|
in_channels=3,
|
||||||
|
patch_size=14,
|
||||||
|
spatial_merge_size=2,
|
||||||
|
window_size=112,
|
||||||
|
fullatt_block_indexes=[7, 15, 23, 31],
|
||||||
|
temporal_patch_size=2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.out_hidden_size = out_hidden_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
self.window_size = window_size
|
||||||
|
self.fullatt_block_indexes = fullatt_block_indexes
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
706
fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py
Normal file
706
fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py
Normal file
@@ -0,0 +1,706 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.distributed import fleet
|
||||||
|
from paddle.distributed.fleet.meta_parallel import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from paddle.nn.functional.flash_attention import (
|
||||||
|
flash_attn_unpadded as flash_attn_varlen_func,
|
||||||
|
)
|
||||||
|
from paddleformers.transformers.model_utils import PretrainedModel
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
|
|
||||||
|
from .activation import ACT2FN
|
||||||
|
from .configuration import DFNRopeVisionTransformerConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (paddle.Tensor): _description_
|
||||||
|
freqs (paddle.Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
|
||||||
|
with paddle.amp.auto_cast(False):
|
||||||
|
tensor = tensor.astype(dtype="float32")
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
|
||||||
|
sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
|
||||||
|
output = tensor * cos + rotate_half(tensor) * sin
|
||||||
|
output = paddle.cast(output, orig_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionFlashAttention2(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.tensor_parallel_degree = tensor_parallel_degree
|
||||||
|
|
||||||
|
if tensor_parallel_degree > 1:
|
||||||
|
self.qkv = ColumnParallelLinear(
|
||||||
|
dim,
|
||||||
|
dim * 3,
|
||||||
|
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||||
|
weight_attr=None,
|
||||||
|
has_bias=True,
|
||||||
|
fuse_matmul_bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
)
|
||||||
|
self.proj = RowParallelLinear(
|
||||||
|
dim,
|
||||||
|
dim,
|
||||||
|
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||||
|
input_is_parallel=True,
|
||||||
|
has_bias=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
|
||||||
|
self.proj = nn.Linear(dim, dim, bias_attr=True)
|
||||||
|
|
||||||
|
self.head_dim = dim // num_heads # must added
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
cu_seqlens: paddle.Tensor,
|
||||||
|
rotary_pos_emb: paddle.Tensor = None,
|
||||||
|
) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (paddle.Tensor): _description_
|
||||||
|
cu_seqlens (paddle.Tensor): _description_
|
||||||
|
rotary_pos_emb (paddle.Tensor, optional): _description_. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
seq_length = hidden_states.shape[0]
|
||||||
|
qkv = (
|
||||||
|
self.qkv(hidden_states)
|
||||||
|
.reshape(
|
||||||
|
[
|
||||||
|
seq_length,
|
||||||
|
3,
|
||||||
|
self.num_heads // self.tensor_parallel_degree,
|
||||||
|
-1,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.transpose(perm=[1, 0, 2, 3])
|
||||||
|
)
|
||||||
|
q, k, v = qkv.unbind(axis=0)
|
||||||
|
|
||||||
|
q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
|
||||||
|
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
|
||||||
|
|
||||||
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
|
|
||||||
|
softmax_scale = self.head_dim**-0.5
|
||||||
|
|
||||||
|
attn_output = (
|
||||||
|
flash_attn_varlen_func( # flash_attn_unpadded
|
||||||
|
q, # 不支持float32
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
max_seqlen,
|
||||||
|
scale=softmax_scale,
|
||||||
|
)[0]
|
||||||
|
.squeeze(0)
|
||||||
|
.reshape([seq_length, -1])
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.astype(paddle.float32)
|
||||||
|
attn_output = self.proj(attn_output)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 14,
|
||||||
|
temporal_patch_size: int = 2,
|
||||||
|
in_channels: int = 3,
|
||||||
|
hidden_size: int = 1152,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||||
|
self.proj = nn.layer.Conv3D(
|
||||||
|
in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias_attr=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (paddle.Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
target_dtype = self.proj.weight.dtype
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
[-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)).reshape([-1, self.hidden_size])
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class VisionMlp(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
bias: bool = False,
|
||||||
|
hidden_act: str = "gelu",
|
||||||
|
tensor_parallel_degree: int = 1,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.tensor_parallel_degree = tensor_parallel_degree
|
||||||
|
|
||||||
|
if self.tensor_parallel_degree > 1:
|
||||||
|
self.gate_proj = ColumnParallelLinear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||||
|
gather_output=False,
|
||||||
|
has_bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_proj = ColumnParallelLinear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||||
|
gather_output=False,
|
||||||
|
has_bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
hidden_dim,
|
||||||
|
dim,
|
||||||
|
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||||
|
input_is_parallel=True,
|
||||||
|
has_bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
|
||||||
|
self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
|
||||||
|
self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias)
|
||||||
|
self.act = ACT2FN[hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (_type_): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
x_gate = self.gate_proj(x)
|
||||||
|
x_gate = self.act(x_gate)
|
||||||
|
x_up = self.up_proj(x)
|
||||||
|
x_down = self.down_proj(x_gate * x_up)
|
||||||
|
return x_down
|
||||||
|
|
||||||
|
|
||||||
|
class VisionRotaryEmbedding(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): _description_
|
||||||
|
theta (float, optional): _description_. Defaults to 10000.0.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
inv_freq = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistable=False)
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._freqs_cached = None
|
||||||
|
|
||||||
|
def update_freqs_cache(self, seqlen: int) -> None:
|
||||||
|
if seqlen > self._seq_len_cached:
|
||||||
|
seqlen *= 2
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
self.inv_freq = 1.0 / (self.theta ** (paddle.arange(0, self.dim, 2, dtype="float32") / self.dim))
|
||||||
|
seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = paddle.outer(seq, self.inv_freq)
|
||||||
|
self._freqs_cached = freqs
|
||||||
|
|
||||||
|
def forward(self, seqlen: int) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seqlen (int): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
self.update_freqs_cache(seqlen)
|
||||||
|
return self._freqs_cached[:seqlen]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2RMSNorm(nn.Layer):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = paddle.create_parameter(
|
||||||
|
shape=[hidden_size],
|
||||||
|
dtype=paddle.get_default_dtype(),
|
||||||
|
default_initializer=nn.initializer.Constant(1.0),
|
||||||
|
)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if paddle.in_dynamic_mode():
|
||||||
|
with paddle.amp.auto_cast(False):
|
||||||
|
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
|
||||||
|
else:
|
||||||
|
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
|
||||||
|
|
||||||
|
if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
|
||||||
|
hidden_states = paddle.cast(hidden_states, self.weight.dtype)
|
||||||
|
return hidden_states * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class DFNRopeVisionBlock(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_hidden_dim: int,
|
||||||
|
hidden_act: str = "gelu",
|
||||||
|
tensor_parallel_degree: int = 1,
|
||||||
|
attn_implementation: str = "sdpa",
|
||||||
|
) -> None:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (_type_): _description_
|
||||||
|
attn_implementation (str, optional): _description_. Defaults to "sdpa".
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||||
|
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||||
|
|
||||||
|
self.attn = VisionFlashAttention2(
|
||||||
|
dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
tensor_parallel_degree=tensor_parallel_degree,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp = VisionMlp(
|
||||||
|
dim=dim,
|
||||||
|
hidden_dim=mlp_hidden_dim,
|
||||||
|
bias=True,
|
||||||
|
hidden_act=hidden_act,
|
||||||
|
tensor_parallel_degree=tensor_parallel_degree,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (_type_): _description_
|
||||||
|
cu_seqlens (_type_): _description_
|
||||||
|
rotary_pos_emb (_type_): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.attn(
|
||||||
|
self.norm1(hidden_states),
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
rotary_pos_emb=rotary_pos_emb,
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerger(nn.Layer):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nn (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): _description_
|
||||||
|
context_dim (int): _description_
|
||||||
|
spatial_merge_size (int, optional): _description_. Defaults to 2.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||||
|
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(self.hidden_size, dim, bias_attr=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
PretrainedModel (_type_): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = DFNRopeVisionTransformerConfig
|
||||||
|
|
||||||
|
def __init__(self, config, prefix_name: str = "") -> None:
|
||||||
|
super().__init__(config.vision_config)
|
||||||
|
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||||
|
self.prefix_name = prefix_name
|
||||||
|
|
||||||
|
# args for get_window_index_thw
|
||||||
|
self.window_size = config.vision_config.window_size
|
||||||
|
self.patch_size = config.vision_config.patch_size
|
||||||
|
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||||
|
self.fullatt_block_indexes = config.vision_config.fullatt_block_indexes
|
||||||
|
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
patch_size=config.vision_config.patch_size,
|
||||||
|
temporal_patch_size=config.vision_config.temporal_patch_size,
|
||||||
|
in_channels=config.vision_config.in_chans,
|
||||||
|
hidden_size=config.vision_config.hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
head_dim = config.vision_config.hidden_size // config.vision_config.num_heads
|
||||||
|
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
||||||
|
|
||||||
|
self.blocks = nn.LayerList(
|
||||||
|
[
|
||||||
|
DFNRopeVisionBlock(
|
||||||
|
dim=config.vision_config.hidden_size,
|
||||||
|
num_heads=config.vision_config.num_heads,
|
||||||
|
mlp_hidden_dim=config.vision_config.intermediate_size,
|
||||||
|
hidden_act=config.vision_config.hidden_act,
|
||||||
|
tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree,
|
||||||
|
)
|
||||||
|
for _ in range(config.vision_config.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.merger = PatchMerger(
|
||||||
|
dim=config.vision_config.out_hidden_size, context_dim=config.vision_config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> paddle.device:
|
||||||
|
return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
|
def get_dtype(self) -> paddle.dtype:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.dtype: _description_
|
||||||
|
"""
|
||||||
|
return self.blocks[0].mlp.fc2.weight.dtype
|
||||||
|
|
||||||
|
def get_window_index(self, grid_thw):
|
||||||
|
window_index: list = []
|
||||||
|
cu_window_seqlens: list = [0]
|
||||||
|
window_index_id = 0
|
||||||
|
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
|
||||||
|
for grid_t, grid_h, grid_w in grid_thw:
|
||||||
|
llm_grid_h, llm_grid_w = (grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size)
|
||||||
|
index = paddle.arange(end=grid_t * llm_grid_h * llm_grid_w).reshape([grid_t, llm_grid_h, llm_grid_w])
|
||||||
|
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||||
|
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||||
|
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||||
|
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||||
|
index_padded = paddle.nn.functional.pad(
|
||||||
|
x=index, pad=(0, pad_w, 0, pad_h), mode="constant", value=-100, pad_from_left_axis=False
|
||||||
|
)
|
||||||
|
index_padded = index_padded.reshape(
|
||||||
|
[grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size]
|
||||||
|
)
|
||||||
|
index_padded = index_padded.transpose(perm=[0, 1, 3, 2, 4]).reshape(
|
||||||
|
[grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size]
|
||||||
|
)
|
||||||
|
seqlens = (index_padded != -100).sum(axis=[2, 3]).reshape([-1])
|
||||||
|
index_padded = index_padded.reshape([-1])
|
||||||
|
index_new = index_padded[index_padded != -100]
|
||||||
|
window_index.append(index_new + window_index_id)
|
||||||
|
cu_seqlens_tmp = seqlens.cumsum(axis=0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||||
|
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||||
|
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||||
|
window_index = paddle.concat(x=window_index, axis=0)
|
||||||
|
return window_index, cu_window_seqlens
|
||||||
|
|
||||||
|
def rot_pos_emb(self, grid_thw):
|
||||||
|
pos_ids = []
|
||||||
|
for t, h, w in grid_thw:
|
||||||
|
hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w])
|
||||||
|
hpos_ids = hpos_ids.reshape(
|
||||||
|
[
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3])
|
||||||
|
hpos_ids = hpos_ids.flatten()
|
||||||
|
|
||||||
|
wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1])
|
||||||
|
wpos_ids = wpos_ids.reshape(
|
||||||
|
[
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
wpos_ids = wpos_ids.transpose([0, 2, 1, 3])
|
||||||
|
wpos_ids = wpos_ids.flatten()
|
||||||
|
pos_ids.append(paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1]))
|
||||||
|
pos_ids = paddle.concat(x=pos_ids, axis=0)
|
||||||
|
max_grid_size = grid_thw[:, 1:].max()
|
||||||
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
|
||||||
|
return rotary_pos_emb
|
||||||
|
|
||||||
|
def get_rope_by_thw(self, t, h, w):
|
||||||
|
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w)
|
||||||
|
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
|
||||||
|
rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
|
||||||
|
rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
|
||||||
|
cu_seqlens_thw = paddle.repeat_interleave(paddle.tensor([h * w], dtype=paddle.int32), t)
|
||||||
|
return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (paddle.Tensor): _description_
|
||||||
|
grid_thw (paddle.Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_states = self.patch_embed(hidden_states)
|
||||||
|
|
||||||
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||||
|
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||||
|
cu_window_seqlens = paddle.to_tensor(data=cu_window_seqlens, dtype="int32", place=hidden_states.place)
|
||||||
|
cu_window_seqlens = paddle.unique_consecutive(x=cu_window_seqlens)
|
||||||
|
seq_len, _ = tuple(hidden_states.shape)
|
||||||
|
hidden_states = hidden_states.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
|
||||||
|
hidden_states = hidden_states[window_index, :, :]
|
||||||
|
hidden_states = hidden_states.reshape([seq_len, -1])
|
||||||
|
rotary_pos_emb = rotary_pos_emb.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
|
||||||
|
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||||
|
rotary_pos_emb = rotary_pos_emb.reshape([seq_len, -1])
|
||||||
|
|
||||||
|
cu_seqlens = paddle.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||||
|
axis=0, dtype="int32"
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
|
for layer_num, blk in enumerate(self.blocks):
|
||||||
|
if layer_num in self.fullatt_block_indexes:
|
||||||
|
cu_seqlens_now = cu_seqlens
|
||||||
|
else:
|
||||||
|
cu_seqlens_now = cu_window_seqlens
|
||||||
|
|
||||||
|
hidden_states = blk(
|
||||||
|
hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens_now,
|
||||||
|
rotary_pos_emb=rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
# adapter
|
||||||
|
hidden_states = self.merger(hidden_states)
|
||||||
|
reverse_indices = paddle.argsort(window_index)
|
||||||
|
hidden_states = hidden_states[reverse_indices, :]
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def extract_feature(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (paddle.Tensor): _description_
|
||||||
|
grid_thw (paddle.Tensor): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: _description_
|
||||||
|
"""
|
||||||
|
return self.forward(hidden_states, grid_thw)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||||
|
"""
|
||||||
|
dummy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from paddleformers.transformers.conversion_utils import split_or_merge_func
|
||||||
|
|
||||||
|
fn = split_or_merge_func(
|
||||||
|
is_split=is_split,
|
||||||
|
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||||
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
|
)
|
||||||
|
vision_config = config.vision_config
|
||||||
|
|
||||||
|
def split_qkv_weight(x):
|
||||||
|
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||||
|
x = x.reshape(
|
||||||
|
[
|
||||||
|
vision_config.hidden_size,
|
||||||
|
3,
|
||||||
|
vision_config.num_heads,
|
||||||
|
head_dim,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||||
|
x = x.reshape([vision_config.hidden_size, -1])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def split_qkv_bias(x):
|
||||||
|
head_dim = vision_config.hidden_size // vision_config.num_heads
|
||||||
|
x = x.reshape([3, vision_config.num_heads, head_dim])
|
||||||
|
x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank]
|
||||||
|
x = x.reshape([-1])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_tensor_parallel_split_mappings(depth):
|
||||||
|
final_actions = {}
|
||||||
|
base_actions = {
|
||||||
|
"visual.blocks.0.attn.proj.weight": partial(fn, is_column=False),
|
||||||
|
"visual.blocks.0.mlp.gate_proj.weight": partial(fn, is_column=True),
|
||||||
|
"visual.blocks.0.mlp.gate_proj.bias": partial(fn, is_column=True),
|
||||||
|
"visual.blocks.0.mlp.up_proj.weight": partial(fn, is_column=True),
|
||||||
|
"visual.blocks.0.mlp.up_proj.bias": partial(fn, is_column=True),
|
||||||
|
"visual.blocks.0.mlp.down_proj.weight": partial(fn, is_column=False),
|
||||||
|
"visual.blocks.0.qkv.weight": split_qkv_weight,
|
||||||
|
"visual.blocks.0.qkv.bias": split_qkv_bias,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, action in base_actions.items():
|
||||||
|
if "blocks.0." in key:
|
||||||
|
for i in range(depth):
|
||||||
|
newkey = key.replace("blocks.0.", f"blocks.{i}.")
|
||||||
|
final_actions[newkey] = action
|
||||||
|
return final_actions
|
||||||
|
|
||||||
|
mappings = get_tensor_parallel_split_mappings(vision_config.depth)
|
||||||
|
return mappings
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for param_name, param in params_dict.items():
|
||||||
|
state_dict_key = f"{self.prefix_name}.{param_name}"
|
||||||
|
if state_dict_key not in state_dict:
|
||||||
|
raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
|
||||||
|
tensor = get_tensor(state_dict.pop(state_dict_key))
|
||||||
|
if param.shape != tensor.shape:
|
||||||
|
raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
|
||||||
|
else:
|
||||||
|
param.copy_(tensor, False)
|
390
fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
Normal file
390
fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddleformers.transformers import PretrainedModel
|
||||||
|
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||||
|
from paddleformers.utils.log import logger
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||||
|
support_graph_optimization,
|
||||||
|
)
|
||||||
|
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||||
|
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||||
|
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||||
|
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||||
|
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
|
||||||
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
from fastdeploy.model_executor.ops.gpu import extract_text_token_output
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
|
|
||||||
|
|
||||||
|
@support_graph_optimization
|
||||||
|
class Qwen2_5_VLModel(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fd_config: FDConfig = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializer for the Ernie4_5_VLModel class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_layers = fd_config.model_config.num_hidden_layers
|
||||||
|
self.image_token_id = fd_config.model_config.image_token_id
|
||||||
|
self.video_token_id = fd_config.model_config.video_token_id
|
||||||
|
self._dtype = fd_config.model_config.dtype
|
||||||
|
fd_config.model_config.pretrained_config.prefix_name = "model"
|
||||||
|
self.fd_config = fd_config
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
fd_config=fd_config,
|
||||||
|
num_embeddings=fd_config.model_config.vocab_size,
|
||||||
|
embedding_dim=fd_config.model_config.hidden_size,
|
||||||
|
params_dtype=paddle.get_default_dtype,
|
||||||
|
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.LayerList(
|
||||||
|
[
|
||||||
|
Qwen2DecoderLayer(
|
||||||
|
fd_config=fd_config,
|
||||||
|
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
|
||||||
|
)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(
|
||||||
|
fd_config,
|
||||||
|
hidden_size=fd_config.model_config.hidden_size,
|
||||||
|
eps=fd_config.model_config.rms_norm_eps,
|
||||||
|
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""
|
||||||
|
Load model parameters from a given state dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict[str, np.ndarray | paddle.Tensor]):
|
||||||
|
A dictionary containing model parameters, where keys are parameter names
|
||||||
|
and values are NumPy arrays or PaddlePaddle tensors.
|
||||||
|
"""
|
||||||
|
self.embed_tokens.load_state_dict(state_dict)
|
||||||
|
self.norm.load_state_dict(state_dict)
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
logger.info(f"Start load layer {i}")
|
||||||
|
self.layers[i].load_state_dict(state_dict)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
ids_remove_padding: paddle.Tensor,
|
||||||
|
image_features: Optional[paddle.Tensor],
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
):
|
||||||
|
|
||||||
|
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# 将 image_embeds 替换 input_embeds 里的 image video 占位符
|
||||||
|
image_mask = ids_remove_padding == self.image_token_id
|
||||||
|
image_token_num = image_mask.sum()
|
||||||
|
|
||||||
|
video_mask = ids_remove_padding == self.video_token_id
|
||||||
|
video_token_num = video_mask.sum()
|
||||||
|
|
||||||
|
# 由于框架只有 image_features,所以目前不支持图片和视频混合
|
||||||
|
# TODO(wangyafeng) 后续考虑支持传入 video_features
|
||||||
|
if image_token_num > 0:
|
||||||
|
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||||
|
if video_token_num > 0:
|
||||||
|
hidden_states[video_mask] = image_features.cast(self._dtype)
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
hidden_states, residual = self.layers[i](
|
||||||
|
forward_meta,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||||
|
hidden_states = extract_text_token_output(
|
||||||
|
max_seq_len,
|
||||||
|
max_seq_len_index.cast("int32"),
|
||||||
|
image_token_num.cast("int32"),
|
||||||
|
forward_meta.seq_lens_this_time,
|
||||||
|
forward_meta.cu_seqlens_q,
|
||||||
|
hidden_states.cast("float32"),
|
||||||
|
).cast(self._dtype)
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
out = self.norm(hidden_states)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
|
||||||
|
"""
|
||||||
|
Qwen2_5_VLForConditionalGeneration
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fd_config: FDConfig):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fd_config (FDConfig): Configurations for the LLM model.
|
||||||
|
"""
|
||||||
|
super(Qwen2_5_VLForConditionalGeneration, self).__init__(fd_config)
|
||||||
|
# ----------- vision model ------------
|
||||||
|
self.visual = self._init_vision_model(fd_config.model_config)
|
||||||
|
# ----------- language model -------------
|
||||||
|
self.model = Qwen2_5_VLModel(fd_config=fd_config)
|
||||||
|
|
||||||
|
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||||
|
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
fd_config=fd_config,
|
||||||
|
embedding_dim=fd_config.model_config.hidden_size,
|
||||||
|
num_embeddings=fd_config.model_config.vocab_size,
|
||||||
|
prefix="lm_head",
|
||||||
|
)
|
||||||
|
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
|
||||||
|
|
||||||
|
def _init_vision_model(self, model_config) -> nn.Layer:
|
||||||
|
from fastdeploy.model_executor.models.qwen2_5_vl.dfnrope.modeling import (
|
||||||
|
DFNRopeVisionTransformerPretrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
visual = DFNRopeVisionTransformerPretrainedModel(model_config, prefix_name="visual")
|
||||||
|
visual = paddle.amp.decorate(models=visual, level="O2", dtype="bfloat16")
|
||||||
|
visual.eval()
|
||||||
|
return visual
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(self):
|
||||||
|
return "Qwen2_5_VLForConditionalGeneration"
|
||||||
|
|
||||||
|
@paddle.no_grad()
|
||||||
|
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
|
||||||
|
"""
|
||||||
|
Load model parameters from a given state dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict[str, np.ndarray | paddle.Tensor]):
|
||||||
|
A dictionary containing model parameters, where keys are parameter names
|
||||||
|
and values are NumPy arrays or PaddlePaddle tensors.
|
||||||
|
"""
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
self.visual.load_state_dict(state_dict)
|
||||||
|
if self.tie_word_embeddings:
|
||||||
|
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
|
||||||
|
else:
|
||||||
|
self.lm_head.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: paddle.Tensor):
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = paddle.cast(logits, paddle.float32)
|
||||||
|
logits[:, self.ori_vocab_size :] = -float("inf")
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def empty_input_forward(self):
|
||||||
|
"""
|
||||||
|
empty_input_forward
|
||||||
|
"""
|
||||||
|
fake_hidden_states = paddle.empty(
|
||||||
|
shape=[0, self.fd_config.model_config.hidden_size],
|
||||||
|
dtype=paddle.get_default_dtype(),
|
||||||
|
)
|
||||||
|
for i in range(
|
||||||
|
self.fd_config.model_config.moe_layer_start_index,
|
||||||
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
|
):
|
||||||
|
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||||
|
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
ids_remove_padding: paddle.Tensor,
|
||||||
|
image_features: Optional[paddle.Tensor],
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
):
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
ids_remove_padding=ids_remove_padding,
|
||||||
|
image_features=image_features,
|
||||||
|
forward_meta=forward_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLPretrainedModel(PretrainedModel):
|
||||||
|
"""
|
||||||
|
Qwen2_PretrainedModel
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = FDConfig
|
||||||
|
|
||||||
|
def _init_weight(self, layer):
|
||||||
|
"""
|
||||||
|
_init_weight
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def arch_name(self):
|
||||||
|
return "Qwen2_5_VLForConditionalGeneration"
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||||
|
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
|
||||||
|
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||||
|
|
||||||
|
weight_infos = [
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.weight", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.bias", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.weight", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.bias", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.weight", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.bias", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True),
|
||||||
|
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False),
|
||||||
|
WeightMeta(".embed_tokens.weight", False),
|
||||||
|
WeightMeta("lm_head.weight", True),
|
||||||
|
]
|
||||||
|
|
||||||
|
weight_vison = [
|
||||||
|
# vision
|
||||||
|
WeightMeta(
|
||||||
|
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True),
|
||||||
|
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.bias", True),
|
||||||
|
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True),
|
||||||
|
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.bias", True),
|
||||||
|
WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False),
|
||||||
|
WeightMeta(
|
||||||
|
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight",
|
||||||
|
True,
|
||||||
|
tsm.GQA,
|
||||||
|
),
|
||||||
|
WeightMeta(
|
||||||
|
f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias",
|
||||||
|
True,
|
||||||
|
tsm.GQA,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
|
||||||
|
"""
|
||||||
|
get_tensor_parallel_mappings
|
||||||
|
"""
|
||||||
|
logger.info("qwen2_5_vl inference model _get_tensor_parallel_mappings")
|
||||||
|
from fastdeploy.model_executor.models.tp_utils import (
|
||||||
|
build_expanded_keys,
|
||||||
|
has_prefix,
|
||||||
|
split_or_merge_func_v1,
|
||||||
|
)
|
||||||
|
|
||||||
|
fn = split_or_merge_func_v1(
|
||||||
|
is_split=is_split,
|
||||||
|
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||||
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
|
num_attention_heads=config.num_attention_heads,
|
||||||
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
head_dim=config.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
vision_fn = split_or_merge_func_v1(
|
||||||
|
is_split=is_split,
|
||||||
|
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||||
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
|
num_attention_heads=config.vision_config.get("num_heads"),
|
||||||
|
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||||
|
head_dim=config.vision_config.get("hidden_size") // config.vision_config.get("num_heads"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tensor_parallel_split_mappings(
|
||||||
|
num_layers: int,
|
||||||
|
prefix_name: str,
|
||||||
|
):
|
||||||
|
base_actions = {}
|
||||||
|
for weight_name, is_column, extra in cls.weight_infos:
|
||||||
|
params = {
|
||||||
|
"is_column": is_column,
|
||||||
|
**({extra.value: True} if extra else {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "lm_head.weight" or "" in weight_name:
|
||||||
|
key = weight_name
|
||||||
|
elif not has_prefix(prefix_name, weight_name):
|
||||||
|
key = f"{prefix_name}{weight_name}"
|
||||||
|
else:
|
||||||
|
key = weight_name
|
||||||
|
base_actions[key] = partial(fn, **params)
|
||||||
|
final_actions = {}
|
||||||
|
final_actions = build_expanded_keys(
|
||||||
|
base_actions,
|
||||||
|
num_layers,
|
||||||
|
)
|
||||||
|
return final_actions
|
||||||
|
|
||||||
|
def get_vison_parallel_split_mappings(num_layers: int):
|
||||||
|
base_actions = {}
|
||||||
|
for weight_name, is_column, extra in cls.weight_vison:
|
||||||
|
params = {
|
||||||
|
"is_column": is_column,
|
||||||
|
**({extra.value: True} if extra else {}),
|
||||||
|
}
|
||||||
|
base_actions[weight_name] = partial(vision_fn, **params)
|
||||||
|
final_actions = {}
|
||||||
|
final_actions = build_expanded_keys(
|
||||||
|
base_actions,
|
||||||
|
num_layers,
|
||||||
|
)
|
||||||
|
return final_actions
|
||||||
|
|
||||||
|
mappings = get_tensor_parallel_split_mappings(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
config.prefix_name,
|
||||||
|
)
|
||||||
|
vision_mappings = get_vison_parallel_split_mappings(config.vision_config.get("depth"))
|
||||||
|
|
||||||
|
return {**mappings, **vision_mappings}
|
@@ -20,7 +20,11 @@ class MultimodalRegistry:
|
|||||||
A registry for multimodal models
|
A registry for multimodal models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration", "Ernie5MoeForCausalLM"}
|
mm_models: set[str] = {
|
||||||
|
"Ernie4_5_VLMoeForConditionalGeneration",
|
||||||
|
"Ernie5MoeForCausalLM",
|
||||||
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def contains_model(cls, name: str) -> bool:
|
def contains_model(cls, name: str) -> bool:
|
||||||
|
@@ -33,6 +33,10 @@ from fastdeploy.model_executor.models.qwen2 import (
|
|||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
Qwen2PretrainedModel,
|
Qwen2PretrainedModel,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.model_executor.models.qwen2_5_vl.qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
Qwen2_5_VLPretrainedModel,
|
||||||
|
)
|
||||||
from fastdeploy.model_executor.models.qwen3 import (
|
from fastdeploy.model_executor.models.qwen3 import (
|
||||||
Qwen3ForCausalLM,
|
Qwen3ForCausalLM,
|
||||||
Qwen3PretrainedModel,
|
Qwen3PretrainedModel,
|
||||||
@@ -477,3 +481,51 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
|
|||||||
self._complete_missing_mappings()
|
self._complete_missing_mappings()
|
||||||
|
|
||||||
return self.infer_to_train_mapping
|
return self.infer_to_train_mapping
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, BaseRLModel):
|
||||||
|
"""
|
||||||
|
Qwen2_5_VLForConditionalGenerationRL
|
||||||
|
"""
|
||||||
|
|
||||||
|
_get_tensor_parallel_mappings = Qwen2_5_VLPretrainedModel._get_tensor_parallel_mappings
|
||||||
|
|
||||||
|
def __init__(self, fd_config: FDConfig):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fd_config (FDConfig): Configurations for the LLM model.
|
||||||
|
"""
|
||||||
|
super(Qwen2_5_VLForConditionalGenerationRL, self).__init__(fd_config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""name"""
|
||||||
|
return "Qwen2_5_VLForConditionalGenerationRL"
|
||||||
|
|
||||||
|
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||||
|
if self._mappings_built:
|
||||||
|
return self.infer_to_train_mapping
|
||||||
|
|
||||||
|
self.infer_to_train_mapping = {}
|
||||||
|
self._mappings_built = True
|
||||||
|
# Prepare placeholders
|
||||||
|
place_holders = ["weight"]
|
||||||
|
|
||||||
|
# Initialize mapping dictionary
|
||||||
|
self._update_base_mappings("model")
|
||||||
|
base_name = "model.layers"
|
||||||
|
|
||||||
|
# Helper function to add layer mappings
|
||||||
|
def _add_layer_mappings(layer_idx):
|
||||||
|
# FFN mappings
|
||||||
|
for ph in place_holders:
|
||||||
|
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = (
|
||||||
|
f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
|
||||||
|
_add_layer_mappings(layer_idx)
|
||||||
|
|
||||||
|
self._complete_missing_mappings()
|
||||||
|
|
||||||
|
return self.infer_to_train_mapping
|
||||||
|
@@ -103,7 +103,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
# VL model config:
|
# VL model config:
|
||||||
if self.enable_mm:
|
if self.enable_mm:
|
||||||
self._init_image_preprocess()
|
if "ernie" in self.fd_config.model_config.model_type:
|
||||||
|
self._init_image_preprocess()
|
||||||
|
|
||||||
self.amp_black = [
|
self.amp_black = [
|
||||||
"reduce_sum",
|
"reduce_sum",
|
||||||
@@ -242,7 +243,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
dtype=paddle.int64,
|
dtype=paddle.int64,
|
||||||
)
|
)
|
||||||
vision_inputs["images"] = paddle.to_tensor(
|
vision_inputs["images"] = paddle.to_tensor(
|
||||||
inputs["images"][request.image_start : request.image_end], dtype="uint8"
|
inputs["images"][request.image_start : request.image_end],
|
||||||
|
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
|
||||||
)
|
)
|
||||||
vision_inputs["grid_thw"] = paddle.to_tensor(
|
vision_inputs["grid_thw"] = paddle.to_tensor(
|
||||||
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
||||||
@@ -797,6 +799,11 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
if self.enable_mm:
|
if self.enable_mm:
|
||||||
head_dim = self.model_config.head_dim
|
head_dim = self.model_config.head_dim
|
||||||
|
if "qwen" in self.model_config.model_type: # neox style = True
|
||||||
|
rope_head_dim = head_dim
|
||||||
|
else: # neox style = False
|
||||||
|
rope_head_dim = head_dim // 2
|
||||||
|
|
||||||
self.share_inputs["rope_emb"] = paddle.full(
|
self.share_inputs["rope_emb"] = paddle.full(
|
||||||
shape=[
|
shape=[
|
||||||
max_num_seqs,
|
max_num_seqs,
|
||||||
@@ -804,14 +811,16 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
1,
|
1,
|
||||||
self.parallel_config.max_model_len,
|
self.parallel_config.max_model_len,
|
||||||
1,
|
1,
|
||||||
head_dim // 2,
|
rope_head_dim,
|
||||||
],
|
],
|
||||||
fill_value=0,
|
fill_value=0,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
self.share_inputs["image_features"] = None
|
self.share_inputs["image_features"] = None
|
||||||
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||||
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
|
self.share_inputs["enable_thinking"] = paddle.full(
|
||||||
|
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
|
||||||
|
)
|
||||||
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||||
|
|
||||||
def _prepare_inputs(self) -> None:
|
def _prepare_inputs(self) -> None:
|
||||||
@@ -1186,7 +1195,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
||||||
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
||||||
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
||||||
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
|
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
|
||||||
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
|
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
|
||||||
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
|
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
|
||||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||||
@@ -1476,7 +1485,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
||||||
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
||||||
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
||||||
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
|
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
|
||||||
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
|
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
|
||||||
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
|
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
|
||||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||||
@@ -1720,7 +1729,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
image_type_ids = one["image_type_ids"][np.newaxis, :]
|
image_type_ids = one["image_type_ids"][np.newaxis, :]
|
||||||
images = one["images"]
|
images = one["images"]
|
||||||
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
|
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
|
||||||
images = paddle.to_tensor(images, dtype="uint8")
|
images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16")
|
||||||
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
|
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
|
||||||
else:
|
else:
|
||||||
image_type_ids = None
|
image_type_ids = None
|
||||||
@@ -1742,12 +1751,10 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@paddle.no_grad()
|
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||||
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
|
||||||
"""extract_vision_features"""
|
|
||||||
assert inputs["images"] is not None
|
assert inputs["images"] is not None
|
||||||
grid_thw = inputs["grid_thw"]
|
grid_thw = inputs["grid_thw"]
|
||||||
|
# ernie-vl has images norm
|
||||||
images = inputs["images"].cast("float32")
|
images = inputs["images"].cast("float32")
|
||||||
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
|
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
|
||||||
images = images / self.image_preprocess.image_std_tensor
|
images = images / self.image_preprocess.image_std_tensor
|
||||||
@@ -1772,6 +1779,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
|
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
|
||||||
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
|
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
|
||||||
image_features = image_features.reshape([S, -1])
|
image_features = image_features.reshape([S, -1])
|
||||||
|
# ernie-vl has resampler_model
|
||||||
image_features = self.model.resampler_model(
|
image_features = self.model.resampler_model(
|
||||||
image_features,
|
image_features,
|
||||||
image_mask,
|
image_mask,
|
||||||
@@ -1781,6 +1789,31 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||||
|
assert inputs["images"] is not None
|
||||||
|
grid_thw = inputs["grid_thw"]
|
||||||
|
images = inputs["images"]
|
||||||
|
with paddle.amp.auto_cast(
|
||||||
|
True,
|
||||||
|
custom_black_list=self.amp_black,
|
||||||
|
custom_white_list=self.amp_white,
|
||||||
|
level="O2",
|
||||||
|
dtype=self.parallel_config.dtype,
|
||||||
|
):
|
||||||
|
image_features = self.model.visual.extract_feature(images, grid_thw)
|
||||||
|
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
@paddle.no_grad()
|
||||||
|
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||||
|
"""extract_vision_features"""
|
||||||
|
if "ernie" in self.model_config.model_type:
|
||||||
|
return self.extract_vision_features_ernie(inputs)
|
||||||
|
elif "qwen" in self.model_config.model_type:
|
||||||
|
return self.extract_vision_features_qwen(inputs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
|
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
|
||||||
"""prepare_rope3d"""
|
"""prepare_rope3d"""
|
||||||
@@ -1800,5 +1833,6 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
base=self.model_config.rope_theta,
|
base=self.model_config.rope_theta,
|
||||||
max_position=self.parallel_config.max_model_len,
|
max_position=self.parallel_config.max_model_len,
|
||||||
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
|
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
|
||||||
|
model_type=self.model_config.model_type,
|
||||||
)
|
)
|
||||||
return rope_emb
|
return rope_emb
|
||||||
|
Reference in New Issue
Block a user