""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from typing import Any, Optional import paddle class ErnieRotaryEmbedding: def __init__(self, rotary_dim, base, partial_rotary_factor, rope_scaling=None): """ Pre-calculate rotary position embedding for position_ids. """ self.rotary_dim = rotary_dim self.base = base self.partial_rotary_factor = partial_rotary_factor self.rope_scaling = rope_scaling def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] inv_freq = self.base**( -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) partial_rotary_position_ids = position_ids / self.partial_rotary_factor freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq) if paddle.is_compiled_with_xpu(): # shape: [B, S, D] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") emb = paddle.stack([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, self.rotary_dim)) else: # shape: [B, S, D/2] rot_emb = paddle.zeros( (2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32") emb = paddle.stack([freqs], axis=-1).reshape( (bsz, max_seq_len, self.rotary_dim // 2)) # shape: [B, S, 1, D] emb = paddle.unsqueeze(emb, 2) rot_emb[0] = paddle.cos(emb) rot_emb[1] = paddle.sin(emb) if paddle.is_compiled_with_custom_device("npu"): return (paddle.concat([rot_emb, rot_emb], axis=3).transpose( [0, 1, 2, 4, 3]).reshape([2, bsz, max_seq_len, 1, self.rotary_dim])) else: return rot_emb class QwenRotaryEmbedding: def __init__(self, rotary_dim, base, partial_rotary_factor, rope_scaling=None): """ Pre-calculate rotary position embedding for position_ids. """ self.rotary_dim = rotary_dim self.base = base self.partial_rotary_factor = partial_rotary_factor self.rope_scaling = rope_scaling def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") inv_freq = self.base**( -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) # shape: [B, S, D/2] freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) # shape: [B, S, 1, D] emb = paddle.concat([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, 1, self.rotary_dim)) rot_emb[0] = paddle.cos(emb) rot_emb[1] = paddle.sin(emb) return rot_emb def get_rope( rotary_dim: int, base: 10000.0, position_ids, partial_rotary_factor=1, rope_scaling: Optional[dict[str, Any]] = None, ): rope_type = rope_scaling.get("architectures", None) if "Qwen2ForCausalLM" in rope_type: rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor, rope_scaling) rotary_emb = rotary_emb_layer(position_ids) else: rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor, rope_scaling) rotary_emb = rotary_emb_layer(position_ids) return rotary_emb class ErnieVlRotaryEmbedding3D: def __init__(self, rotary_dim, base, partial_rotary_factor, max_position, freq_allocation, rope_scaling): self.rotary_dim = rotary_dim self.base = base self.paritial_rotary_factor = partial_rotary_factor self.rope_scaling = rope_scaling self.max_position = max_position self.freq_allocation = freq_allocation def __call__(self, position_ids): rot_emb = paddle.zeros( (2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32") # position_ids_3d: [bsz, seq_len, 3] position_ids_3d = paddle.tile( paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3]) position_ids_3d[:, :position_ids.shape[1], :] = position_ids # import pdb;pdb.set_trace() # position_ids: [bsz, seq_len] position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1)) position_ids = position_ids / self.paritial_rotary_factor indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32") indices = 1 / self.base**(indices / self.rotary_dim) # sinusoid_inp: [bsz, seq_len, 1, head_dim // 2] sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0) # pos_emb: [bsz, seq_len, 1, head_dim] pos_emb = paddle.concat( [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) # pos_emb: [bsz, 1, seq_len, head_dim] pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim)) # pos_emb: [bsz, seq_len, 1, head_dim] pos_emb = pos_emb.transpose([0, 2, 1, 3]) # sin: [bsz, seq_len, 1, head_dim // 2] sin, cos = paddle.chunk(pos_emb, 2, axis=-1) batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64") # batch_indices: [[0]] batch_indices = batch_indices[..., None] # sin, cos: [3, seq_len, 1, head_dim // 2] sin = sin.tile([position_ids.shape[0], 1, 1, 1]) cos = cos.tile([position_ids.shape[0], 1, 1, 1]) tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64") tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64") tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64") sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0) sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation:] sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, :self.rotary_dim // 2 - self.freq_allocation:2] sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[:, :, :, 1:self.rotary_dim // 2 - self.freq_allocation:2] sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2]) sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) # noqa cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0) cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation:] cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, :self.rotary_dim // 2 - self.freq_allocation:2] cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[:, :, :, 1:self.rotary_dim // 2 - self.freq_allocation:2] cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2]) cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) # noqa rot_emb[0] = cos_thw # noqa rot_emb[1] = sin_thw # noqa return rot_emb def get_rope_3d( rotary_dim: int, base: 10000, position_ids, paritial_rotary_factor: 1, max_position: 131072, freq_allocation: 2, rope_scaling: Optional[dict[str, Any]] = None, ): rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base, paritial_rotary_factor, max_position, freq_allocation, rope_scaling) rotary_emb_3d = rotary_emb3d_layer(position_ids) return rotary_emb_3d