""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import math from typing import Optional, Tuple import paddle import paddle.nn as nn from fastdeploy.config import ModelConfig from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding from .utils import CpuGuard class ErnieRotaryEmbedding: def __init__(self, rotary_dim, base, partial_rotary_factor): """ Pre-calculate rotary position embedding for position_ids. """ self.rotary_dim = rotary_dim self.base = base self.partial_rotary_factor = partial_rotary_factor def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] inv_freq = self.base**( -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) partial_rotary_position_ids = position_ids / self.partial_rotary_factor freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq) if paddle.is_compiled_with_xpu(): # shape: [B, S, D] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") emb = paddle.stack([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, self.rotary_dim)) else: # shape: [B, S, D/2] rot_emb = paddle.zeros( (2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32") emb = paddle.stack([freqs], axis=-1).reshape( (bsz, max_seq_len, self.rotary_dim // 2)) # shape: [B, S, 1, D] emb = paddle.unsqueeze(emb, 2) rot_emb[0] = paddle.cos(emb) rot_emb[1] = paddle.sin(emb) if paddle.is_compiled_with_custom_device("npu"): return (paddle.concat([rot_emb, rot_emb], axis=3).transpose( [0, 1, 2, 4, 3]).reshape([2, bsz, max_seq_len, 1, self.rotary_dim])) else: return rot_emb class QwenRotaryEmbedding: def __init__(self, rotary_dim, base, partial_rotary_factor): """ Pre-calculate rotary position embedding for position_ids. """ self.rotary_dim = rotary_dim self.base = base self.partial_rotary_factor = partial_rotary_factor def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") inv_freq = self.base**( -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) # shape: [B, S, D/2] freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) # shape: [B, S, 1, D] emb = paddle.concat([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, 1, self.rotary_dim)) rot_emb[0] = paddle.cos(emb) rot_emb[1] = paddle.sin(emb) return rot_emb def yarn_get_mscale(scale=1, mscale=1): """ """ if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): """ """ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): """ """ low = math.floor( yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) # Clamp values just in case def yarn_linear_ramp_mask(min, max, dim): """ """ if min == max: max += 0.001 # Prevent singularity linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max - min) ramp_func = paddle.clip(linear_func, 0, 1) return ramp_func class DeepseekScalingRotaryEmbedding(nn.Layer): """RotaryEmbedding extended with YaRN method. Credits to Peng et al. github.com/jquesnelle/yarn Args: rotary_dim(int): Dimension of rotary embeddings (head dimension) max_position_embeddings(int): Original training context length base(float): Base value used to compute the inverse frequencies. scaling_factor(float): Context extension scaling ratio (target_len / original_len) extrapolation_factor(float): Weight for extrapolated frequencies (default=1) attn_factor(float): Attention magnitude scaling factor (default=1) beta_fast(int): High-frequency correction cutoff (default=32) beta_slow(int): Low-frequency correction cutoff (default=1) mscale(float): Primary magnitude scaling factor (default=1) mscale_all_dim(float): Alternate magnitude scaling factor (default=0) """ def __init__( self, rotary_dim: int, max_position_embeddings: int, base: int, scaling_factor: float, *, extrapolation_factor: float = 1, attn_factor: float = 1, beta_fast: int = 32, beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, ) -> None: super().__init__() self._dtype = paddle.get_default_dtype() self.rotary_dim = rotary_dim self.max_position_embeddings = max_position_embeddings self.base = base self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation. self.mscale = float( yarn_get_mscale(self.scaling_factor, float(mscale)) / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor) cache = self._compute_cos_sin_cache() self.cos_sin_cache: paddle.Tensor self.register_buffer("cos_sin_cache", cache, persistable=True) def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor: pos_freqs = self.base**( paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, self.rotary_dim, self.base, self.max_position_embeddings) # Get n-d rotational scaling corrected for extrapolation inv_freq_mask = (1 - yarn_linear_ramp_mask( low, high, self.rotary_dim // 2)) * self.extrapolation_factor inv_freq = inv_freq_interpolation * ( 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask return inv_freq def _compute_cos_sin_cache(self) -> paddle.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32) freqs = paddle.einsum("i,j->ij", t, inv_freq) cos = freqs.cos() * self.mscale sin = freqs.sin() * self.mscale cache = paddle.concat((cos, sin), axis=-1) return cache.cast(self._dtype) def forward( self, position_ids: paddle.Tensor, query: paddle.Tensor, key: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """ """ # In-place operations that update the query and key tensors. fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False) return query, key def get_rope_impl( rotary_dim: int, base: 10000.0, position_ids: paddle.Tensor, model_config: Optional[ModelConfig] = None, partial_rotary_factor=1, ) -> paddle.Tensor: """ The real implementation of get_rope """ architecture = model_config.architectures[0] if model_config is None or architecture.startswith("Qwen"): rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) else: rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) return rotary_emb def get_rope_xpu( rotary_dim: int, base: 10000.0, position_ids: paddle.Tensor, model_config: Optional[ModelConfig] = None, partial_rotary_factor=1, ) -> paddle.Tensor: """ In XPU, cos and sin compute must be done on cpu """ with CpuGuard(): position_ids = position_ids.cpu() rotary_emb = get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor) return rotary_emb.to('xpu') def get_rope( rotary_dim: int, base: 10000.0, position_ids: paddle.Tensor, model_config: Optional[ModelConfig] = None, partial_rotary_factor: int = 1, ) -> paddle.Tensor: """ Pre-calculate rotary position embedding for position_ids. Args: rotary_dim (int): Dimension of rotary embeddings (head dimension) base (float, optional): Base value used to compute the inverse frequencies. Default: 10000.0. position_ids (paddle.Tensor): Tensor containing position indices of input tokens. model_config (Optional[ModelConfig]): Model configuration object containing architecture information. If provided, determines RoPE implementation based on model architecture. partial_rotary_factor (int, optional): Factor controlling partial rotary application. Default: 1 (apply to all dimensions). """ if current_platform.is_xpu(): return get_rope_xpu(rotary_dim, base, position_ids, model_config, partial_rotary_factor) else: return get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor) class ErnieVlRotaryEmbedding3D: def __init__(self, rotary_dim, base, partial_rotary_factor, max_position, freq_allocation): self.rotary_dim = rotary_dim self.base = base self.paritial_rotary_factor = partial_rotary_factor self.max_position = max_position self.freq_allocation = freq_allocation def __call__(self, position_ids): rot_emb = paddle.zeros( (2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32") # position_ids_3d: [bsz, seq_len, 3] position_ids_3d = paddle.tile( paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3]) position_ids_3d[:, :position_ids.shape[1], :] = position_ids # import pdb;pdb.set_trace() # position_ids: [bsz, seq_len] position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1)) position_ids = position_ids / self.paritial_rotary_factor indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32") indices = 1 / self.base**(indices / self.rotary_dim) # sinusoid_inp: [bsz, seq_len, 1, head_dim // 2] sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0) # pos_emb: [bsz, seq_len, 1, head_dim] pos_emb = paddle.concat( [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) # pos_emb: [bsz, 1, seq_len, head_dim] pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim)) # pos_emb: [bsz, seq_len, 1, head_dim] pos_emb = pos_emb.transpose([0, 2, 1, 3]) # sin: [bsz, seq_len, 1, head_dim // 2] sin, cos = paddle.chunk(pos_emb, 2, axis=-1) batch_indices = paddle.arange(end=position_ids.shape[0]).cast("int64") # batch_indices: [[0]] batch_indices = batch_indices[..., None] # sin, cos: [3, seq_len, 1, head_dim // 2] sin = sin.tile([position_ids.shape[0], 1, 1, 1]) cos = cos.tile([position_ids.shape[0], 1, 1, 1]) tmp_pos_id_0 = position_ids_3d[..., 0].squeeze().astype("int64") tmp_pos_id_1 = position_ids_3d[..., 1].squeeze().astype("int64") tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64") sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0) sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation:] sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, :self.rotary_dim // 2 - self.freq_allocation:2] sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[:, :, :, 1:self.rotary_dim // 2 - self.freq_allocation:2] sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2]) sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) # noqa cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0) cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation:] cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[:, :, :, :self.rotary_dim // 2 - self.freq_allocation:2] cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[:, :, :, 1:self.rotary_dim // 2 - self.freq_allocation:2] cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2]) cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) # noqa rot_emb[0] = cos_thw # noqa rot_emb[1] = sin_thw # noqa return rot_emb def get_rope_3d( rotary_dim: int, base: 10000, position_ids, paritial_rotary_factor: 1, max_position: 131072, freq_allocation: 2, ) -> paddle.Tensor: """ Pre-calculate rotary position embedding for position_ids. Args: rotary_dim (int): Dimension of rotary embeddings (head dimension) base (float, optional): Base value used to compute the inverse frequencies. Default: 10000.0. position_ids (paddle.Tensor): Tensor containing position indices of input tokens. partial_rotary_factor (int, optional): Factor controlling partial rotary application. Default: 1 (apply to all dimensions). max_position: Maximum position index to precompute. freq_allocation: Number of rotary dimensions allocated to temporal axis """ rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base, paritial_rotary_factor, max_position, freq_allocation) rotary_emb_3d = rotary_emb3d_layer(position_ids) return rotary_emb_3d