Files
FastDeploy/fastdeploy/model_executor/layers/rotary_embedding.py
Kane2011 b4fef2cf29 [MetaxGPU] Support FastDeploy on metax gpu (#3241)
* [MetaxGPU] Support FastDeploy on metax gpu

* Update metax_worker.py

1. change worker log;
2. remove custom allreduce, adapt it later;
3. remove cuda graph;

* Update __init__.py

1. remove metax's key word comment

* Update __init__.py

1. remove metax's key word comment;
2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
2025-08-13 11:11:54 +08:00

416 lines
16 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from typing import Optional, Tuple
import paddle
from paddle import nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
from .utils import CpuGuard
class ErnieRotaryEmbedding:
    """Rotary position table for ERNIE-style models.

    Calling an instance with ``position_ids`` (a 2-D tensor, [B, S], as required
    by the ``"ij,k->ijk"`` einsum) returns a cos/sin table whose layout depends
    on the backend Paddle was built for; see ``__call__``.
    """

    def __init__(self, rotary_dim, base, partial_rotary_factor):
        """
        Pre-calculate rotary position embedding for position_ids.

        Args:
            rotary_dim: number of channels that receive rotary encoding.
            base: base of the inverse-frequency geometric series.
            partial_rotary_factor: positions are divided by this value before
                being multiplied by the inverse frequencies.
        """
        self.rotary_dim = rotary_dim
        self.base = base
        self.partial_rotary_factor = partial_rotary_factor

    def __call__(self, position_ids):
        """Build the rotary table for ``position_ids`` ([B, S]).

        Returns:
            * GCU: ``concat([cos, sin], axis=-1)`` of shape [B, S, D].
            * XPU / Iluvatar / Metax builds: [2, B, S, 1, D] where every
              frequency appears twice in adjacent channels.
            * Otherwise: [2, B, S, 1, D // 2].
            For the non-GCU layouts, index 0 of the leading axis holds cos and
            index 1 holds sin. When Paddle is compiled with the "npu" custom
            device, the table is additionally duplicated along axis 3,
            transposed and reshaped to [2, B, S, 1, D].
        """
        bsz, max_seq_len = position_ids.shape[:2]
        exponent = -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim
        inv_freq = self.base ** exponent
        scaled_positions = (position_ids / self.partial_rotary_factor).cast("float32")
        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", scaled_positions, inv_freq)

        # Backend priority: XPU/Iluvatar first, then GCU (early return), then Metax.
        interleaved = paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu")
        if not interleaved:
            if current_platform.is_gcu():
                # shape: [B, S, D]
                return paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
            interleaved = paddle.is_compiled_with_custom_device("metax_gpu")

        if interleaved:
            # shape: [B, S, D]; each frequency repeated in two adjacent channels
            width = self.rotary_dim
            emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, width))
        else:
            # shape: [B, S, D/2]
            width = self.rotary_dim // 2
            emb = freqs.reshape((bsz, max_seq_len, width))
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, width), dtype="float32")

        # shape: [B, S, 1, width]
        emb = paddle.unsqueeze(emb, 2)
        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)

        if not paddle.is_compiled_with_custom_device("npu"):
            return rot_emb
        return (
            paddle.concat([rot_emb, rot_emb], axis=3)
            .transpose([0, 1, 2, 4, 3])
            .reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
        )
class QwenRotaryEmbedding:
    """Rotary position table for Qwen-style models.

    Calling an instance with ``position_ids`` (a 2-D tensor, [B, S]) returns
    either ``concat([cos, sin], -1)`` ([B, S, D]) on GCU, or a
    [2, B, S, 1, D] table (index 0 = cos, index 1 = sin) whose last axis holds
    the half-width frequencies twice, back to back.
    """

    def __init__(self, rotary_dim, base, partial_rotary_factor):
        """
        Pre-calculate rotary position embedding for position_ids.

        Args:
            rotary_dim: number of channels that receive rotary encoding.
            base: base of the inverse-frequency geometric series.
            partial_rotary_factor: stored for interface parity with
                ErnieRotaryEmbedding; ``__call__`` does not read it.
        """
        self.rotary_dim = rotary_dim
        self.base = base
        self.partial_rotary_factor = partial_rotary_factor

    def __call__(self, position_ids):
        """Build the rotary table for ``position_ids`` ([B, S])."""
        bsz, max_seq_len = position_ids.shape[:2]
        exponent = -paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim
        inv_freq = self.base ** exponent
        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)

        if current_platform.is_gcu():
            # shape: [B, S, D]
            return paddle.concat([freqs.cos(), freqs.sin()], axis=-1)

        # shape: [B, S, 1, D]; the two halves of the last axis are identical
        emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, self.rotary_dim))
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)
        return rot_emb
def yarn_get_mscale(scale=1, mscale=1):
    """Return the YaRN attention-magnitude correction for a context scale.

    No correction (``1.0``) is applied when ``scale <= 1``; otherwise the
    factor is ``0.1 * mscale * ln(scale) + 1.0``.

    Args:
        scale: context extension ratio (target_len / original_len).
        mscale: multiplier on the logarithmic term.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) rotary dimension index at which a frequency
    completes ``num_rotations`` full turns over ``max_position_embeddings``.

    Computed as ``dim * ln(L / (2*pi*num_rotations)) / (2 * ln(base))`` with
    ``L = max_position_embeddings``.
    """
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the integer ``(low, high)`` dimension range blended by YaRN.

    ``low`` is the floor of the correction dim for ``low_rot`` rotations and
    ``high`` the ceiling for ``high_rot``; both are clamped into
    ``[0, dim - 1]``.
    """
    low_dim = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high_dim = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    # Clamp in case the ideal range falls outside the valid index range.
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)
def yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim`` rising linearly from 0 to 1.

    Element ``i`` equals ``(i - min) / (max - min)`` clipped to ``[0, 1]``.
    When ``min == max`` the upper bound is nudged by 0.001 so the division is
    well defined. (Parameter names shadow builtins but are kept for callers.)
    """
    upper = max + 0.001 if min == max else max  # prevent singularity
    ramp = (paddle.arange(dim, dtype=paddle.float32) - min) / (upper - min)
    return paddle.clip(ramp, 0, 1)
class DeepseekScalingRotaryEmbedding(nn.Layer):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn

    At construction a cos/sin cache covering
    ``max_position_embeddings * scaling_factor`` positions is precomputed and
    registered as the buffer ``cos_sin_cache``; ``forward`` applies it to
    query/key through the fused custom op ``fused_rotary_position_encoding``.

    NOTE(review): ``fused_rotary_position_encoding`` is imported at module level
    only when ``current_platform.is_cuda()`` is true, so on other platforms
    ``forward`` will presumably fail with a NameError.

    Args:
        rotary_dim(int): Dimension of rotary embeddings (head dimension)
        max_position_embeddings(int): Original training context length
        base(float): Base value used to compute the inverse frequencies.
            (The signature annotates it as ``int``.)
        scaling_factor(float): Context extension scaling ratio (target_len / original_len)
        extrapolation_factor(float): Weight for extrapolated frequencies (default=1)
        attn_factor(float): Attention magnitude scaling factor (default=1)
        beta_fast(int): High-frequency correction cutoff (default=32)
        beta_slow(int): Low-frequency correction cutoff (default=1)
        mscale(float): Primary magnitude scaling factor (default=1)
        mscale_all_dim(float): Alternate magnitude scaling factor (default=0)
    """

    def __init__(
        self,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        super().__init__()
        # Dtype the cos/sin cache is cast to at the end of _compute_cos_sin_cache.
        self._dtype = paddle.get_default_dtype()
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        # Ratio of the YaRN mscale computed with `mscale` to the one computed with
        # `mscale_all_dim`, times attn_factor; it multiplies both cos and sin in
        # _compute_cos_sin_cache.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale))
            / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
            * attn_factor
        )
        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: paddle.Tensor
        # persistable=True: per Paddle's register_buffer, the buffer is included
        # in the layer's state_dict.
        self.register_buffer("cos_sin_cache", cache, persistable=True)

    def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
        """Return YaRN-blended inverse frequencies (length rotary_dim // 2 for even rotary_dim).

        Dimensions below the correction range keep the original (extrapolated)
        inverse frequency, dimensions above it use the inverse frequency divided
        by ``scaling_factor`` (interpolated), and a linear ramp blends the two in
        between; the blend mask is additionally scaled by extrapolation_factor.
        """
        # base ** (2i / rotary_dim) for i = 0, 1, ...
        pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        # mask == 1 (times extrapolation_factor) below `low`, 0 above `high`.
        inv_freq_mask = (1 - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> paddle.Tensor:
        """Precompute ``[cos | sin] * mscale`` for every cached position.

        Positions are ``0 .. max_position_embeddings * scaling_factor`` (exclusive).
        The result has shape [num_positions, 2 * len(inv_freq)] and is cast to
        the default dtype captured in __init__.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = paddle.arange(
            self.max_position_embeddings * self.scaling_factor,
            dtype=paddle.float32,
        )
        # Outer product: freqs[p, j] = t[p] * inv_freq[j]
        freqs = paddle.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = paddle.concat((cos, sin), axis=-1)
        return cache.cast(self._dtype)

    def forward(
        self,
        position_ids: paddle.Tensor,
        query: paddle.Tensor,
        key: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Apply the cached YaRN rotary embedding to ``query`` and ``key``.

        The fused op updates ``query`` and ``key`` in place; the same tensor
        objects are returned.
        """
        # In-place operations that update the query and key tensors.
        # The trailing `False` is a flag whose meaning is defined by the custom op
        # (not visible in this file).
        fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
        return query, key
def get_rope_impl(
    rotary_dim: int,
    base: float,
    position_ids: paddle.Tensor,
    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor=1,
) -> paddle.Tensor:
    """
    The real implementation of get_rope.

    Uses QwenRotaryEmbedding when ``model_config`` is None or its first listed
    architecture name starts with "Qwen"; otherwise uses ErnieRotaryEmbedding.

    Args:
        rotary_dim: number of channels that receive rotary encoding.
        base: base of the inverse-frequency geometric series (e.g. 10000.0).
        position_ids: position indices of the input tokens.
        model_config: optional model configuration used to pick the implementation.
        partial_rotary_factor: forwarded to the selected implementation.

    Returns:
        The table produced by the selected rotary-embedding implementation.
    """
    # Test for None *before* touching model_config.architectures so the None
    # fallback is actually reachable instead of raising AttributeError.
    if model_config is None or model_config.architectures[0].startswith("Qwen"):
        rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
    else:
        rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
    return rotary_emb_layer(position_ids)
def get_rope_xpu(
    rotary_dim: int,
    base: float,
    position_ids: paddle.Tensor,
    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor=1,
) -> paddle.Tensor:
    """
    Compute the rotary table on CPU and move the result to XPU.

    In XPU, cos and sin compute must be done on cpu, so ``position_ids`` is
    copied to CPU and ``get_rope_impl`` runs inside ``CpuGuard``; the finished
    table is then moved with ``.to("xpu")``.

    Args:
        rotary_dim: number of channels that receive rotary encoding.
        base: base of the inverse-frequency geometric series (e.g. 10000.0).
        position_ids: position indices of the input tokens.
        model_config: optional model configuration forwarded to get_rope_impl.
        partial_rotary_factor: forwarded to get_rope_impl.
    """
    with CpuGuard():
        position_ids = position_ids.cpu()
        rotary_emb = get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
    return rotary_emb.to("xpu")
def get_rope(
    rotary_dim: int,
    base: float,
    position_ids: paddle.Tensor,
    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor: int = 1,
) -> paddle.Tensor:
    """
    Pre-calculate rotary position embedding for position_ids.

    On XPU the computation is delegated to get_rope_xpu (which runs on CPU);
    on every other platform it is delegated to get_rope_impl.

    Args:
        rotary_dim (int):
            Dimension of rotary embeddings (head dimension)
        base (float):
            Base value used to compute the inverse frequencies (e.g. 10000.0).
            Required positional argument; there is no default.
        position_ids (paddle.Tensor):
            Tensor containing position indices of input tokens.
        model_config (Optional[ModelConfig]):
            Model configuration object containing architecture information.
            If provided, determines RoPE implementation based on model architecture.
        partial_rotary_factor (int, optional):
            Factor controlling partial rotary application.
            Default: 1 (apply to all dimensions).

    Returns:
        paddle.Tensor: the precomputed rotary table; its layout depends on the
        selected implementation and platform.
    """
    if current_platform.is_xpu():
        return get_rope_xpu(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
    return get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
class ErnieVlRotaryEmbedding3D:
    """3D (temporal / height / width) rotary position table for ERNIE-VL.

    Calling an instance returns a float32 tensor of shape
    [2, 1, max_position, 1, rotary_dim // 2]; index 0 of the leading axis holds
    cos values and index 1 holds sin values. Along the last axis, the final
    ``freq_allocation`` channels are looked up with the temporal ids, and the
    preceding ``rotary_dim // 2 - freq_allocation`` channels alternate between
    height ids (even channels) and width ids (odd channels).
    """

    def __init__(
        self,
        rotary_dim,
        base,
        partial_rotary_factor,
        max_position,
        freq_allocation,
    ):
        self.rotary_dim = rotary_dim
        self.base = base
        # NOTE: attribute keeps its historical spelling ("paritial"); renaming it
        # could break code that reads it.
        self.paritial_rotary_factor = partial_rotary_factor
        self.max_position = max_position
        self.freq_allocation = freq_allocation

    def _select_thw(self, table, batch_indices, t_ids, h_ids, w_ids):
        """Gather one sin/cos table by t/h/w ids and merge into [hw-interleaved | t] channels."""
        hw_end = self.rotary_dim // 2 - self.freq_allocation
        per_batch = paddle.index_select(table, index=batch_indices, axis=0)
        t_part = paddle.index_select(per_batch, index=t_ids, axis=1)[:, :, :, -self.freq_allocation :]
        h_part = paddle.index_select(per_batch, index=h_ids, axis=1)[:, :, :, :hw_end:2]
        w_part = paddle.index_select(per_batch, index=w_ids, axis=1)[:, :, :, 1:hw_end:2]
        # Interleave height (even channels) and width (odd channels).
        hw_part = paddle.stack([h_part, w_part], axis=-1).reshape(h_part.shape[:-1] + [h_part.shape[-1] * 2])
        return paddle.concat([hw_part, t_part], axis=-1)

    def __call__(self, position_ids):
        """Build the 3D rotary table.

        ``position_ids`` is written into the first ``position_ids.shape[1]``
        rows of a [1, max_position, 3] id grid (presumably it is
        [1, seq_len, 3] — TODO confirm against callers); remaining rows use
        their own sequence index on all three axes.
        """
        rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")

        # Default ids: every axis equals the sequence index; shape [1, max_position, 3].
        position_ids_3d = paddle.tile(
            paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
            [1, 1, 3],
        )
        position_ids_3d[:, : position_ids.shape[1], :] = position_ids

        # Plain 1-D sinusoid table over all positions; positions: [1, max_position].
        positions = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
        positions = positions / self.paritial_rotary_factor
        inv_freq = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
        inv_freq = 1 / self.base ** (inv_freq / self.rotary_dim)
        angles = positions.unsqueeze(-1) * inv_freq.unsqueeze(0)
        pos_emb = paddle.concat([paddle.sin(angles), paddle.cos(angles)], axis=-1)
        pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
        # [rows, max_position, 1, rotary_dim]
        pos_emb = pos_emb.transpose([0, 2, 1, 3])
        sin_table, cos_table = paddle.chunk(pos_emb, 2, axis=-1)

        num_rows = positions.shape[0]
        batch_indices = paddle.arange(end=num_rows).cast("int64")[..., None]
        sin_table = sin_table.tile([num_rows, 1, 1, 1])
        cos_table = cos_table.tile([num_rows, 1, 1, 1])

        t_ids, h_ids, w_ids = [position_ids_3d[..., axis].squeeze().astype("int64") for axis in range(3)]

        sin_thw = self._select_thw(sin_table, batch_indices, t_ids, h_ids, w_ids)
        cos_thw = self._select_thw(cos_table, batch_indices, t_ids, h_ids, w_ids)
        rot_emb[0] = cos_thw
        rot_emb[1] = sin_thw
        return rot_emb
def get_rope_3d(
    rotary_dim: int,
    base: float,
    position_ids: paddle.Tensor,
    partial_rotary_factor: float,
    max_position: int,
    freq_allocation: int,
) -> paddle.Tensor:
    """
    Pre-calculate the 3D (temporal/height/width) rotary position embedding.

    All arguments are required; none has a default.

    Args:
        rotary_dim (int):
            Dimension of rotary embeddings (head dimension).
        base (float):
            Base value used to compute the inverse frequencies (e.g. 10000.0).
        position_ids (paddle.Tensor):
            Tensor containing (t, h, w) position indices of input tokens.
        partial_rotary_factor (float):
            Positions are divided by this value before computing the angles.
        max_position (int):
            Maximum position index to precompute.
        freq_allocation (int):
            Number of rotary channels allocated to the temporal axis.

    Returns:
        paddle.Tensor: the table built by ErnieVlRotaryEmbedding3D.
    """
    layer = ErnieVlRotaryEmbedding3D(rotary_dim, base, partial_rotary_factor, max_position, freq_allocation)
    return layer(position_ids)