diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py
index c0e2b5a14..e52add837 100644
--- a/fastdeploy/model_executor/layers/rotary_embedding.py
+++ b/fastdeploy/model_executor/layers/rotary_embedding.py
@@ -325,8 +325,6 @@ class ErnieVlRotaryEmbedding3D:
 
         position_ids_3d[:, : position_ids.shape[1], :] = position_ids
 
-        # import pdb;pdb.set_trace()
-
         # position_ids: [bsz, seq_len]
         position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
 
@@ -383,6 +381,89 @@ class ErnieVlRotaryEmbedding3D:
         return rot_emb
 
 
+class QwenVlRotaryEmbedding3D:
+    def __init__(
+        self,
+        rotary_dim,
+        base,
+        partial_rotary_factor,
+        max_position,
+        freq_allocation,
+    ):
+        self.rotary_dim = rotary_dim
+        self.base = base
+        self.partial_rotary_factor = partial_rotary_factor
+        self.max_position = max_position
+        self.freq_allocation = freq_allocation
+
+    def __call__(self, position_ids):
+        rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
+
+        # position_ids_3d: [bsz, seq_len, 3]
+        position_ids_3d = paddle.tile(
+            paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
+            [1, 1, 3],
+        )
+
+        position_ids_3d[:, : position_ids.shape[1], :] = position_ids
+
+        # position_ids: [bsz, seq_len]
+        position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
+
+        position_ids = position_ids / self.partial_rotary_factor
+
+        indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
+        indices = 1 / self.base ** (indices / self.rotary_dim)
+        # sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
+        sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
+        # pos_emb: [bsz, seq_len, 1, head_dim]
+        pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
+        # pos_emb: [bsz, 1, seq_len, head_dim]
+        pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
+        # pos_emb: [bsz, seq_len, 1, head_dim]
+        pos_emb = pos_emb.transpose([0, 2, 1, 3])
+        # sin: [bsz, seq_len, 1, head_dim // 2]
+        sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
+        batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
+        # batch_indices: [[0]]
+        batch_indices = batch_indices[..., None]
+        # sin, cos: [3, seq_len, 1, head_dim // 2]
+        sin = sin.tile([position_ids.shape[0], 1, 1, 1])
+        cos = cos.tile([position_ids.shape[0], 1, 1, 1])
+
+        tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
+        tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
+        tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
+
+        section_t = self.freq_allocation  # 16
+        section_h = (self.rotary_dim // 2 - self.freq_allocation) // 2  # 24
+        section_w = (self.rotary_dim // 2 - self.freq_allocation) // 2  # 24
+
+        sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
+        sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
+        sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
+        sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
+            :, :, :, section_t + section_h : section_t + section_h + section_w
+        ]
+        sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
+
+        cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
+
+        cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
+        cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
+        cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
+            :, :, :, section_t + section_h : section_t + section_h + section_w
+        ]
+        cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
+
+        rot_emb[0] = cos_thw
+        rot_emb[1] = sin_thw
+
+        # neox-style rope consumes cos/sin duplicated along the last axis
+        rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
+        return rot_emb_neox
+
+
 def get_rope_3d(
     rotary_dim: int,
     base: float,
@@ -390,6 +471,7 @@ def get_rope_3d(
     partial_rotary_factor: float,
     max_position: int,
     freq_allocation: int,
+    model_type: str,
 ) -> paddle.Tensor:
     """
     Pre-calculate rotary position embedding for position_ids.
@@ -407,9 +489,20 @@ def get_rope_3d(
         Default: 1 (apply to all dimensions).
         max_position: Maximum position index to precompute.
         freq_allocation: Number of rotary dimensions allocated to temporal axis
+        model_type: Model type, such as 'ernie4_5_moe_vl' or 'qwen2_5_vl'.
     """
-    rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
-        rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
-    )
+    if "ernie" in model_type:
+        rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
+            rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
+        )
+    elif "qwen" in model_type:
+        rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
+            rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
+        )
+    else:  # unknown model types fall back to the Ernie implementation
+        rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
+            rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
+        )
+
     rotary_emb_3d = rotary_emb3d_layer(position_ids)
     return rotary_emb_3d
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/__init__.py b/fastdeploy/model_executor/models/qwen2_5_vl/__init__.py
new file mode 100644
index 000000000..f4ede9062
--- /dev/null
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/__init__.py
@@ -0,0 +1,15 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/__init__.py b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/__init__.py
new file mode 100644
index 000000000..4c283de51
--- /dev/null
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/__init__.py
@@ -0,0 +1,23 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from .configuration import DFNRopeVisionTransformerConfig +from .modeling import DFNRopeVisionTransformerPretrainedModel + +__all__ = [ + "DFNRopeVisionTransformerConfig", + "DFNRopeVisionTransformerPretrainedModel", +] diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/activation.py b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/activation.py new file mode 100644 index 000000000..1c3b22ae1 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/activation.py @@ -0,0 +1,277 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +from collections import OrderedDict + +import paddle +import paddle.nn.functional as F +from paddle import Tensor, nn + + +class NewGELUActivation(nn.Layer): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return ( + 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0)))) + ) + + +class GELUActivation(nn.Layer): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + """_summary_ + + Args: + use_gelu_python (bool, optional): _description_. Defaults to False. 
+ """ + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return self.act(input) + + +class FastGELUActivation(nn.Layer): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Layer): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return input * F.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Layer): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + """_summary_ + + Args: + x (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return paddle.clip(gelu(x), self.min, self.max) + + +class SiLUActivation(nn.Layer): + """ + See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear + Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function + Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated + Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with + later. + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return F.silu(input) + + +class MishActivation(nn.Layer): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return F.mish(input) + + +class LinearActivation(nn.Layer): + """ + Applies the linear activation function, i.e. 
forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + """_summary_ + + Args: + input (Tensor): _description_ + + Returns: + Tensor: _description_ + """ + return input + + +class ClassInstantier(OrderedDict): + """_summary_ + + Args: + OrderedDict (_type_): _description_ + """ + + def __getitem__(self, key): + """_summary_ + + Args: + key (_type_): _description_ + + Returns: + _type_: _description_ + """ + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": SiLUActivation, + "swish": SiLUActivation, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation(activation_string): + """_summary_ + + Args: + activation_string (_type_): _description_ + + Raises: + KeyError: _description_ + + Returns: + _type_: _description_ + """ + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/configuration.py b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/configuration.py new file mode 100644 index 000000000..73b14a9f8 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/configuration.py @@ -0,0 +1,96 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+
+from paddleformers.transformers.configuration_utils import PretrainedConfig
+
+__all__ = [
+    "DFNRopeVisionTransformerConfig",
+]
+
+
+# qwen2_5 vision config parameters, for reference:
+"""
+    "vision_config": {
+        "depth": 32,
+        "hidden_act": "silu",
+        "hidden_size": 1280,
+        "intermediate_size": 3420,
+        "num_heads": 16,
+        "in_chans": 3,
+        "out_hidden_size": 3584,
+        "patch_size": 14,
+        "spatial_merge_size": 2,
+        "spatial_patch_size": 14,
+        "window_size": 112,
+        "fullatt_block_indexes": [
+            7,
+            15,
+            23,
+            31
+        ],
+        "tokens_per_second": 2,
+        "temporal_patch_size": 2
+    },
+"""
+
+
+# Name mapping from the Qwen config to this implementation:
+# hidden_size -> embed_dim
+# out_hidden_size -> hidden_size
+# intermediate_size -> mlp / mlp_hidden_dim of the qwen vision block
+# fullatt_block_indexes marks the ViT layer indices that use full (non-windowed) attention
+# spatial_patch_size and tokens_per_second are unused here (and also unused in vLLM)
+class DFNRopeVisionTransformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DFNRopeVisionTransformerPretrainedModel`].
+    It is used to instantiate the Qwen2.5-VL vision transformer according to the specified arguments, defining
+    the model architecture.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    """
+
+    model_type = "DFNRope_vision_transformer"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=1280,
+        out_hidden_size=3584,
+        intermediate_size=3420,
+        hidden_act="silu",
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        window_size=112,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.out_hidden_size = out_hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.temporal_patch_size = temporal_patch_size
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py
new file mode 100644
index 000000000..0188ee868
--- /dev/null
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py
@@ -0,0 +1,706 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +from functools import partial + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +from paddle.nn.functional.flash_attention import ( + flash_attn_unpadded as flash_attn_varlen_func, +) +from paddleformers.transformers.model_utils import PretrainedModel + +from fastdeploy.model_executor.layers.utils import get_tensor + +from .activation import ACT2FN +from .configuration import DFNRopeVisionTransformerConfig + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor: + """_summary_ + + Args: + tensor (paddle.Tensor): _description_ + freqs (paddle.Tensor): _description_ + + Returns: + paddle.Tensor: _description_ + """ + orig_dtype = tensor.dtype + + with paddle.amp.auto_cast(False): + tensor = tensor.astype(dtype="float32") + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + output = tensor * cos + rotate_half(tensor) * sin + output = paddle.cast(output, orig_dtype) + return output + + +class VisionFlashAttention2(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None: + super().__init__() + self.num_heads = num_heads + self.tensor_parallel_degree = tensor_parallel_degree + + if tensor_parallel_degree > 1: + self.qkv = ColumnParallelLinear( + dim, + dim * 3, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + weight_attr=None, + has_bias=True, + fuse_matmul_bias=True, + gather_output=False, + ) + self.proj = RowParallelLinear( + dim, + dim, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + input_is_parallel=True, + has_bias=True, + ) + else: + self.qkv = nn.Linear(dim, dim * 3, bias_attr=True) + self.proj = nn.Linear(dim, dim, bias_attr=True) + + self.head_dim = dim // num_heads # must added + + def forward( + self, + hidden_states: paddle.Tensor, + cu_seqlens: paddle.Tensor, + rotary_pos_emb: paddle.Tensor = None, + ) -> paddle.Tensor: + """_summary_ + + Args: + hidden_states (paddle.Tensor): _description_ + cu_seqlens (paddle.Tensor): _description_ + rotary_pos_emb (paddle.Tensor, optional): _description_. Defaults to None. 
+ + Returns: + paddle.Tensor: _description_ + """ + seq_length = hidden_states.shape[0] + qkv = ( + self.qkv(hidden_states) + .reshape( + [ + seq_length, + 3, + self.num_heads // self.tensor_parallel_degree, + -1, + ] + ) + .transpose(perm=[1, 0, 2, 3]) + ) + q, k, v = qkv.unbind(axis=0) + + q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + softmax_scale = self.head_dim**-0.5 + + attn_output = ( + flash_attn_varlen_func( # flash_attn_unpadded + q, # 不支持float32 + k, + v, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + scale=softmax_scale, + )[0] + .squeeze(0) + .reshape([seq_length, -1]) + ) + + attn_output = attn_output.astype(paddle.float32) + attn_output = self.proj(attn_output) + return attn_output + + +class PatchEmbed(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.layer.Conv3D( + in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias_attr=False + ) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """_summary_ + + Args: + hidden_states (paddle.Tensor): _description_ + + Returns: + paddle.Tensor: _description_ + """ + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.reshape( + [-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size] + ) + + hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)).reshape([-1, self.hidden_size]) + return hidden_states + + +class VisionMlp(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + bias: bool = False, + hidden_act: str = "gelu", + tensor_parallel_degree: int = 1, + ) -> None: + super().__init__() + self.tensor_parallel_degree = tensor_parallel_degree + + if self.tensor_parallel_degree > 1: + self.gate_proj = ColumnParallelLinear( + dim, + hidden_dim, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + gather_output=False, + has_bias=bias, + ) + + self.up_proj = ColumnParallelLinear( + dim, + hidden_dim, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + gather_output=False, + has_bias=bias, + ) + + self.down_proj = RowParallelLinear( + hidden_dim, + dim, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + input_is_parallel=True, + has_bias=bias, + ) + + else: + self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias) + self.act = ACT2FN[hidden_act] + + def forward(self, x) -> paddle.Tensor: + """_summary_ + + Args: + x (_type_): _description_ + + Returns: + paddle.Tensor: _description_ + """ + x_gate = self.gate_proj(x) + x_gate = self.act(x_gate) + x_up = self.up_proj(x) + x_down = self.down_proj(x_gate * x_up) + return x_down + + +class VisionRotaryEmbedding(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__(self, dim: 
int, theta: float = 10000.0) -> None: + """_summary_ + + Args: + dim (int): _description_ + theta (float, optional): _description_. Defaults to 10000.0. + """ + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim) + self.register_buffer("inv_freq", inv_freq, persistable=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta ** (paddle.arange(0, self.dim, 2, dtype="float32") / self.dim)) + seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype) + freqs = paddle.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> paddle.Tensor: + """_summary_ + + Args: + seqlen (int): _description_ + + Returns: + paddle.Tensor: _description_ + """ + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2RMSNorm(nn.Layer): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = paddle.create_parameter( + shape=[hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class DFNRopeVisionBlock(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + hidden_act: str = "gelu", + tensor_parallel_degree: int = 1, + attn_implementation: str = "sdpa", + ) -> None: + """_summary_ + + Args: + config (_type_): _description_ + attn_implementation (str, optional): _description_. Defaults to "sdpa". 
+ """ + super().__init__() + + self.norm1 = Qwen2RMSNorm(dim, eps=1e-6) + self.norm2 = Qwen2RMSNorm(dim, eps=1e-6) + + self.attn = VisionFlashAttention2( + dim=dim, + num_heads=num_heads, + tensor_parallel_degree=tensor_parallel_degree, + ) + + self.mlp = VisionMlp( + dim=dim, + hidden_dim=mlp_hidden_dim, + bias=True, + hidden_act=hidden_act, + tensor_parallel_degree=tensor_parallel_degree, + ) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor: + """_summary_ + + Args: + hidden_states (_type_): _description_ + cu_seqlens (_type_): _description_ + rotary_pos_emb (_type_): _description_ + + Returns: + paddle.Tensor: _description_ + """ + + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class PatchMerger(nn.Layer): + """_summary_ + + Args: + nn (_type_): _description_ + """ + + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + """_summary_ + + Args: + dim (int): _description_ + context_dim (int): _description_ + spatial_merge_size (int, optional): _description_. Defaults to 2. + """ + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True), + nn.GELU(), + nn.Linear(self.hidden_size, dim, bias_attr=True), + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """_summary_ + + Args: + x (paddle.Tensor): _description_ + + Returns: + paddle.Tensor: _description_ + """ + x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size])) + + return x + + +class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): + """_summary_ + + Args: + PretrainedModel (_type_): _description_ + + Returns: + _type_: _description_ + """ + + config_class = DFNRopeVisionTransformerConfig + + def __init__(self, config, prefix_name: str = "") -> None: + super().__init__(config.vision_config) + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.prefix_name = prefix_name + + # args for get_window_index_thw + self.window_size = config.vision_config.window_size + self.patch_size = config.vision_config.patch_size + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.fullatt_block_indexes = config.vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = PatchEmbed( + patch_size=config.vision_config.patch_size, + temporal_patch_size=config.vision_config.temporal_patch_size, + in_channels=config.vision_config.in_chans, + hidden_size=config.vision_config.hidden_size, + ) + + head_dim = config.vision_config.hidden_size // config.vision_config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.LayerList( + [ + DFNRopeVisionBlock( + dim=config.vision_config.hidden_size, + num_heads=config.vision_config.num_heads, + mlp_hidden_dim=config.vision_config.intermediate_size, + hidden_act=config.vision_config.hidden_act, + tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree, + ) + for _ in range(config.vision_config.depth) + ] + ) + + self.merger = PatchMerger( + dim=config.vision_config.out_hidden_size, context_dim=config.vision_config.hidden_size + ) + + @property + def device(self) -> paddle.device: + return self.patch_embed.proj.weight.device + + def 
get_dtype(self) -> paddle.dtype: + """_summary_ + + Returns: + paddle.dtype: _description_ + """ + return self.blocks[0].mlp.fc2.weight.dtype + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = (grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size) + index = paddle.arange(end=grid_t * llm_grid_h * llm_grid_w).reshape([grid_t, llm_grid_h, llm_grid_w]) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = paddle.nn.functional.pad( + x=index, pad=(0, pad_w, 0, pad_h), mode="constant", value=-100, pad_from_left_axis=False + ) + index_padded = index_padded.reshape( + [grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size] + ) + index_padded = index_padded.transpose(perm=[0, 1, 3, 2, 4]).reshape( + [grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size] + ) + seqlens = (index_padded != -100).sum(axis=[2, 3]).reshape([-1]) + index_padded = index_padded.reshape([-1]) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(axis=0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = paddle.concat(x=window_index, axis=0) + return window_index, cu_window_seqlens + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w]) + hpos_ids = hpos_ids.reshape( + [ + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ] + ) + hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3]) + hpos_ids = hpos_ids.flatten() + + wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1]) + wpos_ids = wpos_ids.reshape( + [ + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ] + ) + wpos_ids = wpos_ids.transpose([0, 2, 1, 3]) + wpos_ids = wpos_ids.flatten() + pos_ids.append(paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1])) + pos_ids = paddle.concat(x=pos_ids, axis=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1) + return rotary_pos_emb + + def get_rope_by_thw(self, t, h, w): + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) + rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) + rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] + rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) + cu_seqlens_thw = paddle.repeat_interleave(paddle.tensor([h * w], dtype=paddle.int32), t) + return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw) + + def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor: + """_summary_ + + Args: + hidden_states (paddle.Tensor): _description_ + grid_thw (paddle.Tensor): _description_ + + Returns: 
+ paddle.Tensor: _description_ + """ + + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = paddle.to_tensor(data=cu_window_seqlens, dtype="int32", place=hidden_states.place) + cu_window_seqlens = paddle.unique_consecutive(x=cu_window_seqlens) + seq_len, _ = tuple(hidden_states.shape) + hidden_states = hidden_states.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1]) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape([seq_len, -1]) + rotary_pos_emb = rotary_pos_emb.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1]) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape([seq_len, -1]) + + cu_seqlens = paddle.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + axis=0, dtype="int32" + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) + + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = paddle.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + def extract_feature(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor: + """_summary_ + + Args: + hidden_states (paddle.Tensor): _description_ + grid_thw (paddle.Tensor): _description_ + + Returns: + paddle.Tensor: _description_ + """ + return self.forward(hidden_states, grid_thw) + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + """ + dummy + """ + + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + ) + vision_config = config.vision_config + + def split_qkv_weight(x): + head_dim = vision_config.hidden_size // vision_config.num_heads + x = x.reshape( + [ + vision_config.hidden_size, + 3, + vision_config.num_heads, + head_dim, + ] + ) + x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] + x = x.reshape([vision_config.hidden_size, -1]) + return x + + def split_qkv_bias(x): + head_dim = vision_config.hidden_size // vision_config.num_heads + x = x.reshape([3, vision_config.num_heads, head_dim]) + x = np.split(x, vision_config.tensor_parallel_degree, axis=-2)[vision_config.tensor_parallel_rank] + x = x.reshape([-1]) + return x + + def get_tensor_parallel_split_mappings(depth): + final_actions = {} + base_actions = { + "visual.blocks.0.attn.proj.weight": partial(fn, is_column=False), + "visual.blocks.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "visual.blocks.0.mlp.gate_proj.bias": partial(fn, is_column=True), + "visual.blocks.0.mlp.up_proj.weight": partial(fn, is_column=True), + "visual.blocks.0.mlp.up_proj.bias": partial(fn, is_column=True), + "visual.blocks.0.mlp.down_proj.weight": partial(fn, is_column=False), + "visual.blocks.0.qkv.weight": split_qkv_weight, + "visual.blocks.0.qkv.bias": split_qkv_bias, + } + + for key, action in base_actions.items(): + if "blocks.0." 
in key: + for i in range(depth): + newkey = key.replace("blocks.0.", f"blocks.{i}.") + final_actions[newkey] = action + return final_actions + + mappings = get_tensor_parallel_split_mappings(vision_config.depth) + return mappings + + def load_state_dict(self, state_dict): + params_dict = dict(self.named_parameters()) + for param_name, param in params_dict.items(): + state_dict_key = f"{self.prefix_name}.{param_name}" + if state_dict_key not in state_dict: + raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ") + tensor = get_tensor(state_dict.pop(state_dict_key)) + if param.shape != tensor.shape: + raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}") + else: + param.copy_(tensor, False) diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py new file mode 100644 index 000000000..5de437ef6 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py @@ -0,0 +1,390 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +from functools import partial +from typing import Dict, Optional, Union + +import numpy as np +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.configuration_utils import PretrainedConfig +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import extract_text_token_output + +from fastdeploy.model_executor.forward_meta import ForwardMeta + + +@support_graph_optimization +class Qwen2_5_VLModel(nn.Layer): + def __init__( + self, + fd_config: FDConfig = None, + ): + """ + Initializer for the Ernie4_5_VLModel class. 
+
+        Args:
+
+        """
+        super().__init__()
+
+        self.num_layers = fd_config.model_config.num_hidden_layers
+        self.image_token_id = fd_config.model_config.image_token_id
+        self.video_token_id = fd_config.model_config.video_token_id
+        self._dtype = fd_config.model_config.dtype
+        fd_config.model_config.pretrained_config.prefix_name = "model"
+        self.fd_config = fd_config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            fd_config=fd_config,
+            num_embeddings=fd_config.model_config.vocab_size,
+            embedding_dim=fd_config.model_config.hidden_size,
+            params_dtype=paddle.get_default_dtype(),
+            prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
+        )
+
+        self.layers = nn.LayerList(
+            [
+                Qwen2DecoderLayer(
+                    fd_config=fd_config,
+                    prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+        self.norm = RMSNorm(
+            fd_config,
+            hidden_size=fd_config.model_config.hidden_size,
+            eps=fd_config.model_config.rms_norm_eps,
+            prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
+        )
+
+    def load_state_dict(self, state_dict):
+        """
+        Load model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        self.embed_tokens.load_state_dict(state_dict)
+        self.norm.load_state_dict(state_dict)
+        for i in range(self.num_layers):
+            logger.info(f"Start load layer {i}")
+            self.layers[i].load_state_dict(state_dict)
+
+    def forward(
+        self,
+        ids_remove_padding: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
+        forward_meta: ForwardMeta,
+    ):
+
+        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
+
+        # -----------------------
+        # Replace the image/video placeholder positions in the input embeddings with image_embeds
+        image_mask = ids_remove_padding == self.image_token_id
+        image_token_num = image_mask.sum()
+
+        video_mask = ids_remove_padding == self.video_token_id
+        video_token_num = video_mask.sum()
+
+        # The runner only provides image_features, so mixed image and video inputs are not supported yet.
+        # TODO(wangyafeng): consider supporting a separate video_features input later.
+        if image_token_num > 0:
+            hidden_states[image_mask] = image_features.cast(self._dtype)
+        if video_token_num > 0:
+            hidden_states[video_mask] = image_features.cast(self._dtype)
+
+        # -----------------------
+
+        residual = None
+        for i in range(self.num_layers):
+            hidden_states, residual = self.layers[i](
+                forward_meta,
+                hidden_states,
+                residual,
+            )
+
+        hidden_states = hidden_states + residual
+
+        # -----------------------
+        max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
+        hidden_states = extract_text_token_output(
+            max_seq_len,
+            max_seq_len_index.cast("int32"),
+            image_token_num.cast("int32"),
+            forward_meta.seq_lens_this_time,
+            forward_meta.cu_seqlens_q,
+            hidden_states.cast("float32"),
+        ).cast(self._dtype)
+        # -----------------------
+
+        out = self.norm(hidden_states)
+
+        return out
+
+
+class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
+    """
+    Qwen2_5_VLForConditionalGeneration
+    """
+
+    def __init__(self, fd_config: FDConfig):
+        """
+        Args:
+            fd_config (FDConfig): Configurations for the LLM model.
+ """ + super(Qwen2_5_VLForConditionalGeneration, self).__init__(fd_config) + # ----------- vision model ------------ + self.visual = self._init_vision_model(fd_config.model_config) + # ----------- language model ------------- + self.model = Qwen2_5_VLModel(fd_config=fd_config) + + self.ori_vocab_size = fd_config.model_config.ori_vocab_size + + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + + def _init_vision_model(self, model_config) -> nn.Layer: + from fastdeploy.model_executor.models.qwen2_5_vl.dfnrope.modeling import ( + DFNRopeVisionTransformerPretrainedModel, + ) + + visual = DFNRopeVisionTransformerPretrainedModel(model_config, prefix_name="visual") + visual = paddle.amp.decorate(models=visual, level="O2", dtype="bfloat16") + visual.eval() + return visual + + @classmethod + def name(self): + return "Qwen2_5_VLForConditionalGeneration" + + @paddle.no_grad() + def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]): + """ + Load model parameters from a given state dictionary. + + Args: + state_dict (dict[str, np.ndarray | paddle.Tensor]): + A dictionary containing model parameters, where keys are parameter names + and values are NumPy arrays or PaddlePaddle tensors. + """ + self.model.load_state_dict(state_dict) + self.visual.load_state_dict(state_dict) + if self.tie_word_embeddings: + self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) + else: + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor): + logits = self.lm_head(hidden_states) + logits = paddle.cast(logits, paddle.float32) + logits[:, self.ori_vocab_size :] = -float("inf") + + return logits + + def empty_input_forward(self): + """ + empty_input_forward + """ + fake_hidden_states = paddle.empty( + shape=[0, self.fd_config.model_config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + for i in range( + self.fd_config.model_config.moe_layer_start_index, + self.fd_config.model_config.num_hidden_layers, + ): + self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states) + self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states) + + def forward( + self, + ids_remove_padding: paddle.Tensor, + image_features: Optional[paddle.Tensor], + forward_meta: ForwardMeta, + ): + + hidden_states = self.model( + ids_remove_padding=ids_remove_padding, + image_features=image_features, + forward_meta=forward_meta, + ) + + return hidden_states + + +class Qwen2_5_VLPretrainedModel(PretrainedModel): + """ + Qwen2_PretrainedModel + """ + + config_class = FDConfig + + def _init_weight(self, layer): + """ + _init_weight + """ + return None + + @classmethod + def arch_name(self): + return "Qwen2_5_VLForConditionalGeneration" + + from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm + from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid + from fastdeploy.model_executor.models.utils import WeightMeta + + weight_infos = [ + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.bias", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.bias", True), + 
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.bias", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False), + WeightMeta(".embed_tokens.weight", False), + WeightMeta("lm_head.weight", True), + ] + + weight_vison = [ + # vision + WeightMeta( + f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight", + False, + ), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.up_proj.bias", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.gate_proj.bias", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False), + WeightMeta( + f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight", + True, + tsm.GQA, + ), + WeightMeta( + f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias", + True, + tsm.GQA, + ), + ] + + @classmethod + def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True): + """ + get_tensor_parallel_mappings + """ + logger.info("qwen2_5_vl inference model _get_tensor_parallel_mappings") + from fastdeploy.model_executor.models.tp_utils import ( + build_expanded_keys, + has_prefix, + split_or_merge_func_v1, + ) + + fn = split_or_merge_func_v1( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + ) + + vision_fn = split_or_merge_func_v1( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.vision_config.get("num_heads"), + num_key_value_heads=config.vision_config.get("num_heads"), + head_dim=config.vision_config.get("hidden_size") // config.vision_config.get("num_heads"), + ) + + def get_tensor_parallel_split_mappings( + num_layers: int, + prefix_name: str, + ): + base_actions = {} + for weight_name, is_column, extra in cls.weight_infos: + params = { + "is_column": is_column, + **({extra.value: True} if extra else {}), + } + + if "lm_head.weight" or "" in weight_name: + key = weight_name + elif not has_prefix(prefix_name, weight_name): + key = f"{prefix_name}{weight_name}" + else: + key = weight_name + base_actions[key] = partial(fn, **params) + final_actions = {} + final_actions = build_expanded_keys( + base_actions, + num_layers, + ) + return final_actions + + def get_vison_parallel_split_mappings(num_layers: int): + base_actions = {} + for weight_name, is_column, extra in cls.weight_vison: + params = { + "is_column": is_column, + **({extra.value: True} if extra else {}), + } + base_actions[weight_name] = partial(vision_fn, **params) + final_actions = {} + final_actions = build_expanded_keys( + base_actions, + num_layers, + ) + return final_actions + + mappings = get_tensor_parallel_split_mappings( + config.num_hidden_layers, + config.prefix_name, + ) + vision_mappings = get_vison_parallel_split_mappings(config.vision_config.get("depth")) + + return {**mappings, **vision_mappings} diff --git 
a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py index 502082cb8..ab6c701dd 100644 --- a/fastdeploy/multimodal/registry.py +++ b/fastdeploy/multimodal/registry.py @@ -20,7 +20,11 @@ class MultimodalRegistry: A registry for multimodal models """ - mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration", "Ernie5MoeForCausalLM"} + mm_models: set[str] = { + "Ernie4_5_VLMoeForConditionalGeneration", + "Ernie5MoeForCausalLM", + "Qwen2_5_VLForConditionalGeneration", + } @classmethod def contains_model(cls, name: str) -> bool: diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index e3e3f4e38..f6c390120 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -33,6 +33,10 @@ from fastdeploy.model_executor.models.qwen2 import ( Qwen2ForCausalLM, Qwen2PretrainedModel, ) +from fastdeploy.model_executor.models.qwen2_5_vl.qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLPretrainedModel, +) from fastdeploy.model_executor.models.qwen3 import ( Qwen3ForCausalLM, Qwen3PretrainedModel, @@ -477,3 +481,51 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel): self._complete_missing_mappings() return self.infer_to_train_mapping + + +class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, BaseRLModel): + """ + Qwen2_5_VLForConditionalGenerationRL + """ + + _get_tensor_parallel_mappings = Qwen2_5_VLPretrainedModel._get_tensor_parallel_mappings + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. + """ + super(Qwen2_5_VLForConditionalGenerationRL, self).__init__(fd_config) + + @classmethod + def name(self) -> str: + """name""" + return "Qwen2_5_VLForConditionalGenerationRL" + + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True + # Prepare placeholders + place_holders = ["weight"] + + # Initialize mapping dictionary + self._update_base_mappings("model") + base_name = "model.layers" + + # Helper function to add layer mappings + def _add_layer_mappings(layer_idx): + # FFN mappings + for ph in place_holders: + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = ( + f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}" + ) + + for layer_idx in range(self.fd_config.model_config.num_hidden_layers): + _add_layer_mappings(layer_idx) + + self._complete_missing_mappings() + + return self.infer_to_train_mapping diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 95649f0a6..6b042fbbe 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -103,7 +103,8 @@ class GPUModelRunner(ModelRunnerBase): # VL model config: if self.enable_mm: - self._init_image_preprocess() + if "ernie" in self.fd_config.model_config.model_type: + self._init_image_preprocess() self.amp_black = [ "reduce_sum", @@ -242,7 +243,8 @@ class GPUModelRunner(ModelRunnerBase): dtype=paddle.int64, ) vision_inputs["images"] = paddle.to_tensor( - inputs["images"][request.image_start : request.image_end], dtype="uint8" + inputs["images"][request.image_start : request.image_end], + dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", ) vision_inputs["grid_thw"] = paddle.to_tensor( inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" @@ -797,6 +799,11 @@ 
class GPUModelRunner(ModelRunnerBase): if self.enable_mm: head_dim = self.model_config.head_dim + if "qwen" in self.model_config.model_type: # neox style = True + rope_head_dim = head_dim + else: # neox style = False + rope_head_dim = head_dim // 2 + self.share_inputs["rope_emb"] = paddle.full( shape=[ max_num_seqs, @@ -804,14 +811,16 @@ class GPUModelRunner(ModelRunnerBase): 1, self.parallel_config.max_model_len, 1, - head_dim // 2, + rope_head_dim, ], fill_value=0, dtype="float32", ) self.share_inputs["image_features"] = None self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + self.share_inputs["enable_thinking"] = paddle.full( + shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool" + ) self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") def _prepare_inputs(self) -> None: @@ -1186,7 +1195,7 @@ class GPUModelRunner(ModelRunnerBase): accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), + think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1), need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], @@ -1476,7 +1485,7 @@ class GPUModelRunner(ModelRunnerBase): accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), + think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1), need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None), reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], @@ -1720,7 +1729,7 @@ class GPUModelRunner(ModelRunnerBase): image_type_ids = one["image_type_ids"][np.newaxis, :] images = one["images"] image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64) - images = paddle.to_tensor(images, dtype="uint8") + images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16") grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64") else: image_type_ids = None @@ -1742,12 +1751,10 @@ class GPUModelRunner(ModelRunnerBase): ) return result - @paddle.no_grad() - def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: - """extract_vision_features""" + def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: assert inputs["images"] is not None grid_thw = inputs["grid_thw"] - + # ernie-vl has images norm images = inputs["images"].cast("float32") images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor images = images / self.image_preprocess.image_std_tensor @@ -1772,6 +1779,7 @@ class 
GPUModelRunner(ModelRunnerBase):
             image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
             image_features = ScatterOp.apply(image_features, axis=-1)  # mp 切 Fea
             image_features = image_features.reshape([S, -1])
+            # ernie-vl has resampler_model
             image_features = self.model.resampler_model(
                 image_features,
                 image_mask,
@@ -1781,6 +1789,31 @@ class GPUModelRunner(ModelRunnerBase):
         )
         return image_features
 
+    def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
+        assert inputs["images"] is not None
+        grid_thw = inputs["grid_thw"]
+        images = inputs["images"]
+        with paddle.amp.auto_cast(
+            True,
+            custom_black_list=self.amp_black,
+            custom_white_list=self.amp_white,
+            level="O2",
+            dtype=self.parallel_config.dtype,
+        ):
+            image_features = self.model.visual.extract_feature(images, grid_thw)
+
+        return image_features
+
+    @paddle.no_grad()
+    def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
+        """extract_vision_features"""
+        if "ernie" in self.model_config.model_type:
+            return self.extract_vision_features_ernie(inputs)
+        elif "qwen" in self.model_config.model_type:
+            return self.extract_vision_features_qwen(inputs)
+        else:
+            raise ValueError(f"multimodal model type {self.model_config.model_type} is not supported")
+
     @paddle.no_grad()
     def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
         """prepare_rope3d"""
@@ -1800,5 +1833,6 @@ class GPUModelRunner(ModelRunnerBase):
             base=self.model_config.rope_theta,
             max_position=self.parallel_config.max_model_len,
             freq_allocation=getattr(self.model_config, "freq_allocation", 20),
+            model_type=self.model_config.model_type,
         )
         return rope_emb
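
For reference, a minimal sketch (not part of the patch) of how the updated get_rope_3d helper can be invoked once model_type is threaded through. Only the keyword names come from the signature in this diff; the concrete values below are illustrative and would normally come from model_config / parallel_config.

import paddle

from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d

# Hypothetical per-token (t, h, w) position indices, shape [bsz, seq_len, 3].
position_ids_3d = paddle.zeros([1, 1024, 3], dtype="int64")

rope_emb = get_rope_3d(
    rotary_dim=128,               # language-model head_dim (illustrative)
    base=1000000.0,               # rope_theta (illustrative)
    position_ids=position_ids_3d,
    partial_rotary_factor=1.0,
    max_position=32768,
    freq_allocation=16,           # rotary dims (per half) reserved for the temporal axis
    model_type="qwen2_5_vl",      # dispatches to QwenVlRotaryEmbedding3D; other types fall back to Ernie
)
# On the Qwen path the result is neox-style with shape [2, 1, max_position, 1, rotary_dim]:
# index 0 holds cos, index 1 holds sin, with the t/h/w frequency sections packed as 16/24/24.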