[Model]support qwen2_5_vl (#3557)

* adapt qwen_2_5_vl model

* adapt qwen_2_5_vl VIT model

* adapt qwen2_5_vl images_embeds

* adapt qwen2_5_vl 3D rope

* adapt qwen2_5_vl 3D rope v2

* adapt qwen2_5_vl processor

* adapt qwen2_5_vl bypass resampler_model

* adapt qwen2_5_vl 绕过部分ernie逻辑

* adapt qwen2_5_vl 绕过部分ernie逻辑 v2

* adapt qwen2_5_vl 权重加载与命名修改

* adapt qwen2_5_vl 非必须think_end_id

* adapt qwen2_5_vl 区分多种模型的extract_vision_features

* fix:adapt qwen2_5_vl model

* adapt qwen2_5_vl norm

* adapt qwen2_5_vl  processor 更新

* adapt qwen2_5_vl image and video success

* adapt qwen2_5_vl 部分整理代码

* adapt qwen2_5_vl 支持多卡

* adapt qwen2_5_vl on latest develop

* adapt qwen2_5_vl RL

* adapt qwen2_5_vl 整理代码

* support noex rope3d

* adapt qwen2_5_vl add init.py

* adapt qwen2_5_vl add init.py v2

* adapt qwen2_5_vl remove space

* adapt qwen2_5_vl remove space v2

* adapt qwen2_5_vl pre-commit

* adapt qwen2_5_vl update

* adapt qwen2_5_vl pre-commit v2

* adapt qwen2_5_vl modify comments

* adapt qwen2_5_vl fix indentation

* adapt qwen2_5_vl fix indentation v2

---------

Co-authored-by: wangyafeng <wangyafeng@baidu.com>
Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com>
Co-authored-by: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com>
This commit is contained in:
zhouchong
2025-08-29 18:28:39 +08:00
committed by GitHub
parent 65425bf858
commit ccd52b5596
10 changed files with 1718 additions and 17 deletions

View File

@@ -325,8 +325,6 @@ class ErnieVlRotaryEmbedding3D:
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
# import pdb;pdb.set_trace()
# position_ids: [bsz, seq_len]
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
@@ -383,6 +381,100 @@ class ErnieVlRotaryEmbedding3D:
return rot_emb
class QwenVlRotaryEmbedding3D:
def __init__(
self,
rotary_dim,
base,
partial_rotary_factor,
max_position,
freq_allocation,
):
self.rotary_dim = rotary_dim
self.base = base
self.paritial_rotary_factor = partial_rotary_factor
self.max_position = max_position
self.freq_allocation = freq_allocation
def __call__(self, position_ids):
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
# position_ids_3d: [bsz, seq_len, 3]
position_ids_3d = paddle.tile(
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
# position_ids: [bsz, seq_len]
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
position_ids = position_ids / self.paritial_rotary_factor
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
indices = 1 / self.base ** (indices / self.rotary_dim)
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
# pos_emb: [bsz, 1, seq_len, head_dim]
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = pos_emb.transpose([0, 2, 1, 3])
# sin: [bsz, seq_len, 1, head_dim // 2]
sin, cos = paddle.chunk(pos_emb, 2, axis=-1)
batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64")
# batch_indices: [[0]]
batch_indices = batch_indices[..., None]
# sin, cos: [3, seq_len, 1, head_dim // 2]
sin = sin.tile([position_ids.shape[0], 1, 1, 1])
cos = cos.tile([position_ids.shape[0], 1, 1, 1])
tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64")
tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64")
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
# sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
# sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
# sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[
# :, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
# ]
# sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
# :, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
# ]
# sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
# sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
section_t = self.freq_allocation # 16
section_h = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
section_w = (self.rotary_dim // 2 - self.freq_allocation) // 2 # 24
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, section_t + section_h : section_t + section_h + section_w
]
sin_thw = paddle.concat([sin_t, sin_h, sin_w], axis=-1)
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, :section_t]
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, section_t : section_t + section_h]
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, section_t + section_h : section_t + section_h + section_w
]
cos_thw = paddle.concat([cos_t, cos_h, cos_w], axis=-1)
rot_emb[0] = cos_thw
rot_emb[1] = sin_thw
# neox style need
rot_emb_neox = paddle.concat([rot_emb, rot_emb], axis=-1)
return rot_emb_neox
def get_rope_3d(
rotary_dim: int,
base: float,
@@ -390,6 +482,7 @@ def get_rope_3d(
partial_rotary_factor: float,
max_position: int,
freq_allocation: int,
model_type: str,
) -> paddle.Tensor:
"""
Pre-calculate rotary position embedding for position_ids.
@@ -407,9 +500,20 @@ def get_rope_3d(
Default: 1 (apply to all dimensions).
max_position: Maximum position index to precompute.
freq_allocation: Number of rotary dimensions allocated to temporal axis
model_type: Model type, such as 'ernie4_5_moe_vl' or 'qwen2_5_vl'.
"""
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
if "ernie" in model_type:
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
elif "qwen" in model_type:
rotary_emb3d_layer = QwenVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
else: # default ernie
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
rotary_emb_3d = rotary_emb3d_layer(position_ids)
return rotary_emb_3d