polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -18,7 +18,7 @@ import math
from typing import Optional, Tuple
import paddle
import paddle.nn as nn
from paddle import nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform
@@ -30,7 +30,6 @@ from .utils import CpuGuard
class ErnieRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
Pre-calculate rotary position embedding for position_ids.
@@ -41,45 +40,36 @@ class ErnieRotaryEmbedding:
def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base**(
-paddle.arange(0, self.rotary_dim, 2, dtype="float32") /
self.rotary_dim)
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
partial_rotary_position_ids = position_ids / self.partial_rotary_factor
freqs = paddle.einsum("ij,k->ijk",
partial_rotary_position_ids.cast("float32"),
inv_freq)
if paddle.is_compiled_with_xpu(
) or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, self.rotary_dim))
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
elif current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros(
(2, bsz, max_seq_len, 1, self.rotary_dim // 2),
dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape(
(bsz, max_seq_len, self.rotary_dim // 2))
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
if paddle.is_compiled_with_custom_device("npu"):
return (paddle.concat([rot_emb, rot_emb], axis=3).transpose(
[0, 1, 2, 4,
3]).reshape([2, bsz, max_seq_len, 1, self.rotary_dim]))
return (
paddle.concat([rot_emb, rot_emb], axis=3)
.transpose([0, 1, 2, 4, 3])
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
)
else:
return rot_emb
class QwenRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
Pre-calculate rotary position embedding for position_ids.
@@ -90,22 +80,17 @@ class QwenRotaryEmbedding:
def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
dtype="float32")
inv_freq = self.base**(
-paddle.arange(0, self.rotary_dim, 2, dtype="float32") /
self.rotary_dim)
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
if current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, self.rotary_dim))
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, self.rotary_dim))
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
@@ -114,46 +99,30 @@ class QwenRotaryEmbedding:
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
""" """
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
""" """
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
""" """
low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def yarn_linear_ramp_mask(min, max, dim):
"""
"""
""" """
if min == max:
max += 0.001 # Prevent singularity
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
min)
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max - min)
ramp_func = paddle.clip(linear_func, 0, 1)
return ramp_func
@@ -205,9 +174,10 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
yarn_get_mscale(self.scaling_factor, float(mscale))
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
cache = self._compute_cos_sin_cache()
@@ -215,27 +185,29 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
self.register_buffer("cos_sin_cache", cache, persistable=True)
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
pos_freqs = self.base**(
paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
self.rotary_dim)
pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
inv_freq_mask = (1 - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
dtype=paddle.float32)
t = paddle.arange(
self.max_position_embeddings * self.scaling_factor,
dtype=paddle.float32,
)
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
@@ -248,12 +220,9 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
"""
""" """
# In-place operations that update the query and key tensors.
fused_rotary_position_encoding(query, key, position_ids,
self.cos_sin_cache, self.rotary_dim,
False)
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
return query, key
@@ -271,12 +240,10 @@ def get_rope_impl(
architecture = model_config.architectures[0]
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
partial_rotary_factor)
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
else:
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base,
partial_rotary_factor)
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
return rotary_emb
@@ -293,9 +260,8 @@ def get_rope_xpu(
"""
with CpuGuard():
position_ids = position_ids.cpu()
rotary_emb = get_rope_impl(rotary_dim, base, position_ids,
model_config, partial_rotary_factor)
return rotary_emb.to('xpu')
rotary_emb = get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
return rotary_emb.to("xpu")
def get_rope(
@@ -324,17 +290,20 @@ def get_rope(
Default: 1 (apply to all dimensions).
"""
if current_platform.is_xpu():
return get_rope_xpu(rotary_dim, base, position_ids, model_config,
partial_rotary_factor)
return get_rope_xpu(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
else:
return get_rope_impl(rotary_dim, base, position_ids, model_config,
partial_rotary_factor)
return get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
class ErnieVlRotaryEmbedding3D:
def __init__(self, rotary_dim, base, partial_rotary_factor, max_position,
freq_allocation):
def __init__(
self,
rotary_dim,
base,
partial_rotary_factor,
max_position,
freq_allocation,
):
self.rotary_dim = rotary_dim
self.base = base
self.paritial_rotary_factor = partial_rotary_factor
@@ -342,36 +311,31 @@ class ErnieVlRotaryEmbedding3D:
self.freq_allocation = freq_allocation
def __call__(self, position_ids):
rot_emb = paddle.zeros(
(2, 1, self.max_position, 1, self.rotary_dim // 2),
dtype="float32")
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
# position_ids_3d: [bsz, seq_len, 3]
position_ids_3d = paddle.tile(
paddle.arange(self.max_position,
dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3])
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
position_ids_3d[:, :position_ids.shape[1], :] = position_ids
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
# import pdb;pdb.set_trace()
# position_ids: [bsz, seq_len]
position_ids = paddle.arange(0, self.max_position, 1,
dtype="float32").reshape((1, -1))
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
position_ids = position_ids / self.paritial_rotary_factor
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
indices = 1 / self.base**(indices / self.rotary_dim)
indices = 1 / self.base ** (indices / self.rotary_dim)
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = paddle.concat(
[paddle.sin(sinusoid_inp),
paddle.cos(sinusoid_inp)], axis=-1)
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
# pos_emb: [bsz, 1, seq_len, head_dim]
pos_emb = paddle.reshape(pos_emb,
(-1, 1, self.max_position, self.rotary_dim))
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
# pos_emb: [bsz, seq_len, 1, head_dim]
pos_emb = pos_emb.transpose([0, 2, 1, 3])
# sin: [bsz, seq_len, 1, head_dim // 2]
@@ -388,35 +352,29 @@ class ErnieVlRotaryEmbedding3D:
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0,
axis=1)[:, :, :, -self.freq_allocation:]
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1,
axis=1)[:, :, :, :self.rotary_dim // 2 -
self.freq_allocation:2]
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2,
axis=1)[:, :, :, 1:self.rotary_dim // 2 -
self.freq_allocation:2]
sin_hw = paddle.stack([sin_h, sin_w],
axis=-1).reshape(sin_h.shape[:-1] +
[sin_h.shape[-1] * 2])
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) # noqa
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
]
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
]
sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0,
axis=1)[:, :, :, -self.freq_allocation:]
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1,
axis=1)[:, :, :, :self.rotary_dim // 2 -
self.freq_allocation:2]
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2,
axis=1)[:, :, :, 1:self.rotary_dim // 2 -
self.freq_allocation:2]
cos_hw = paddle.stack([cos_h, cos_w],
axis=-1).reshape(cos_h.shape[:-1] +
[cos_h.shape[-1] * 2])
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) # noqa
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
]
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
]
cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2])
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1)
rot_emb[0] = cos_thw # noqa
rot_emb[1] = sin_thw # noqa
rot_emb[0] = cos_thw
rot_emb[1] = sin_thw
return rot_emb
@@ -446,9 +404,8 @@ def get_rope_3d(
max_position: Maximum position index to precompute.
freq_allocation: Number of rotary dimensions allocated to temporal axis
"""
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
partial_rotary_factor,
max_position,
freq_allocation)
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
)
rotary_emb_3d = rotary_emb3d_layer(position_ids)
return rotary_emb_3d