mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -18,7 +18,7 @@ import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -30,7 +30,6 @@ from .utils import CpuGuard
|
||||
|
||||
|
||||
class ErnieRotaryEmbedding:
|
||||
|
||||
def __init__(self, rotary_dim, base, partial_rotary_factor):
|
||||
"""
|
||||
Pre-calculate rotary position embedding for position_ids.
|
||||
@@ -41,45 +40,36 @@ class ErnieRotaryEmbedding:
|
||||
|
||||
def __call__(self, position_ids):
|
||||
bsz, max_seq_len = position_ids.shape[:2]
|
||||
inv_freq = self.base**(
|
||||
-paddle.arange(0, self.rotary_dim, 2, dtype="float32") /
|
||||
self.rotary_dim)
|
||||
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
|
||||
partial_rotary_position_ids = position_ids / self.partial_rotary_factor
|
||||
freqs = paddle.einsum("ij,k->ijk",
|
||||
partial_rotary_position_ids.cast("float32"),
|
||||
inv_freq)
|
||||
if paddle.is_compiled_with_xpu(
|
||||
) or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
|
||||
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
|
||||
dtype="float32")
|
||||
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
|
||||
(bsz, max_seq_len, self.rotary_dim))
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
|
||||
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
|
||||
elif current_platform.is_gcu():
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
|
||||
return rot_emb
|
||||
else:
|
||||
# shape: [B, S, D/2]
|
||||
rot_emb = paddle.zeros(
|
||||
(2, bsz, max_seq_len, 1, self.rotary_dim // 2),
|
||||
dtype="float32")
|
||||
emb = paddle.stack([freqs], axis=-1).reshape(
|
||||
(bsz, max_seq_len, self.rotary_dim // 2))
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
|
||||
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
|
||||
# shape: [B, S, 1, D]
|
||||
emb = paddle.unsqueeze(emb, 2)
|
||||
rot_emb[0] = paddle.cos(emb)
|
||||
rot_emb[1] = paddle.sin(emb)
|
||||
if paddle.is_compiled_with_custom_device("npu"):
|
||||
return (paddle.concat([rot_emb, rot_emb], axis=3).transpose(
|
||||
[0, 1, 2, 4,
|
||||
3]).reshape([2, bsz, max_seq_len, 1, self.rotary_dim]))
|
||||
return (
|
||||
paddle.concat([rot_emb, rot_emb], axis=3)
|
||||
.transpose([0, 1, 2, 4, 3])
|
||||
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
|
||||
)
|
||||
else:
|
||||
return rot_emb
|
||||
|
||||
|
||||
class QwenRotaryEmbedding:
|
||||
|
||||
def __init__(self, rotary_dim, base, partial_rotary_factor):
|
||||
"""
|
||||
Pre-calculate rotary position embedding for position_ids.
|
||||
@@ -90,22 +80,17 @@ class QwenRotaryEmbedding:
|
||||
|
||||
def __call__(self, position_ids):
|
||||
bsz, max_seq_len = position_ids.shape[:2]
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
|
||||
dtype="float32")
|
||||
inv_freq = self.base**(
|
||||
-paddle.arange(0, self.rotary_dim, 2, dtype="float32") /
|
||||
self.rotary_dim)
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
|
||||
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
|
||||
|
||||
# shape: [B, S, D/2]
|
||||
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
|
||||
inv_freq)
|
||||
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
|
||||
if current_platform.is_gcu():
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
|
||||
return rot_emb
|
||||
# shape: [B, S, 1, D]
|
||||
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
|
||||
(bsz, max_seq_len, 1, self.rotary_dim))
|
||||
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, self.rotary_dim))
|
||||
|
||||
rot_emb[0] = paddle.cos(emb)
|
||||
rot_emb[1] = paddle.sin(emb)
|
||||
@@ -114,46 +99,30 @@ class QwenRotaryEmbedding:
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
"""
|
||||
"""
|
||||
""" """
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
"""
|
||||
"""
|
||||
return (dim * math.log(max_position_embeddings /
|
||||
(num_rotations * 2 * math.pi))) / (2 *
|
||||
math.log(base))
|
||||
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||
""" """
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
||||
|
||||
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
"""
|
||||
"""
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||
""" """
|
||||
low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min, max, dim):
|
||||
"""
|
||||
"""
|
||||
""" """
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
|
||||
min)
|
||||
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max - min)
|
||||
ramp_func = paddle.clip(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
@@ -205,9 +174,10 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale))
|
||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||
* attn_factor
|
||||
)
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
|
||||
@@ -215,27 +185,29 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
|
||||
self.register_buffer("cos_sin_cache", cache, persistable=True)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
|
||||
pos_freqs = self.base**(
|
||||
paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
|
||||
self.rotary_dim)
|
||||
pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
|
||||
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
||||
self.rotary_dim, self.base,
|
||||
self.max_position_embeddings)
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
self.rotary_dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (1 - yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (
|
||||
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
inv_freq_mask = (1 - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> paddle.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
dtype=paddle.float32)
|
||||
t = paddle.arange(
|
||||
self.max_position_embeddings * self.scaling_factor,
|
||||
dtype=paddle.float32,
|
||||
)
|
||||
freqs = paddle.einsum("i,j->ij", t, inv_freq)
|
||||
cos = freqs.cos() * self.mscale
|
||||
sin = freqs.sin() * self.mscale
|
||||
@@ -248,12 +220,9 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
|
||||
query: paddle.Tensor,
|
||||
key: paddle.Tensor,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
"""
|
||||
""" """
|
||||
# In-place operations that update the query and key tensors.
|
||||
fused_rotary_position_encoding(query, key, position_ids,
|
||||
self.cos_sin_cache, self.rotary_dim,
|
||||
False)
|
||||
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
|
||||
|
||||
return query, key
|
||||
|
||||
@@ -271,12 +240,10 @@ def get_rope_impl(
|
||||
|
||||
architecture = model_config.architectures[0]
|
||||
if model_config is None or architecture.startswith("Qwen"):
|
||||
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
|
||||
partial_rotary_factor)
|
||||
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
|
||||
rotary_emb = rotary_emb_layer(position_ids)
|
||||
else:
|
||||
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base,
|
||||
partial_rotary_factor)
|
||||
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
|
||||
rotary_emb = rotary_emb_layer(position_ids)
|
||||
return rotary_emb
|
||||
|
||||
@@ -293,9 +260,8 @@ def get_rope_xpu(
|
||||
"""
|
||||
with CpuGuard():
|
||||
position_ids = position_ids.cpu()
|
||||
rotary_emb = get_rope_impl(rotary_dim, base, position_ids,
|
||||
model_config, partial_rotary_factor)
|
||||
return rotary_emb.to('xpu')
|
||||
rotary_emb = get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
|
||||
return rotary_emb.to("xpu")
|
||||
|
||||
|
||||
def get_rope(
|
||||
@@ -324,17 +290,20 @@ def get_rope(
|
||||
Default: 1 (apply to all dimensions).
|
||||
"""
|
||||
if current_platform.is_xpu():
|
||||
return get_rope_xpu(rotary_dim, base, position_ids, model_config,
|
||||
partial_rotary_factor)
|
||||
return get_rope_xpu(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
|
||||
else:
|
||||
return get_rope_impl(rotary_dim, base, position_ids, model_config,
|
||||
partial_rotary_factor)
|
||||
return get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
|
||||
|
||||
|
||||
class ErnieVlRotaryEmbedding3D:
|
||||
|
||||
def __init__(self, rotary_dim, base, partial_rotary_factor, max_position,
|
||||
freq_allocation):
|
||||
def __init__(
|
||||
self,
|
||||
rotary_dim,
|
||||
base,
|
||||
partial_rotary_factor,
|
||||
max_position,
|
||||
freq_allocation,
|
||||
):
|
||||
self.rotary_dim = rotary_dim
|
||||
self.base = base
|
||||
self.paritial_rotary_factor = partial_rotary_factor
|
||||
@@ -342,36 +311,31 @@ class ErnieVlRotaryEmbedding3D:
|
||||
self.freq_allocation = freq_allocation
|
||||
|
||||
def __call__(self, position_ids):
|
||||
rot_emb = paddle.zeros(
|
||||
(2, 1, self.max_position, 1, self.rotary_dim // 2),
|
||||
dtype="float32")
|
||||
rot_emb = paddle.zeros((2, 1, self.max_position, 1, self.rotary_dim // 2), dtype="float32")
|
||||
|
||||
# position_ids_3d: [bsz, seq_len, 3]
|
||||
position_ids_3d = paddle.tile(
|
||||
paddle.arange(self.max_position,
|
||||
dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3])
|
||||
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
|
||||
[1, 1, 3],
|
||||
)
|
||||
|
||||
position_ids_3d[:, :position_ids.shape[1], :] = position_ids
|
||||
position_ids_3d[:, : position_ids.shape[1], :] = position_ids
|
||||
|
||||
# import pdb;pdb.set_trace()
|
||||
|
||||
# position_ids: [bsz, seq_len]
|
||||
position_ids = paddle.arange(0, self.max_position, 1,
|
||||
dtype="float32").reshape((1, -1))
|
||||
position_ids = paddle.arange(0, self.max_position, 1, dtype="float32").reshape((1, -1))
|
||||
|
||||
position_ids = position_ids / self.paritial_rotary_factor
|
||||
|
||||
indices = paddle.arange(0, self.rotary_dim, 2, dtype="float32")
|
||||
indices = 1 / self.base**(indices / self.rotary_dim)
|
||||
indices = 1 / self.base ** (indices / self.rotary_dim)
|
||||
# sinusoid_inp: [bsz, seq_len, 1, head_dim // 2]
|
||||
sinusoid_inp = position_ids.unsqueeze(-1) * indices.unsqueeze(0)
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
pos_emb = paddle.concat(
|
||||
[paddle.sin(sinusoid_inp),
|
||||
paddle.cos(sinusoid_inp)], axis=-1)
|
||||
pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1)
|
||||
# pos_emb: [bsz, 1, seq_len, head_dim]
|
||||
pos_emb = paddle.reshape(pos_emb,
|
||||
(-1, 1, self.max_position, self.rotary_dim))
|
||||
pos_emb = paddle.reshape(pos_emb, (-1, 1, self.max_position, self.rotary_dim))
|
||||
# pos_emb: [bsz, seq_len, 1, head_dim]
|
||||
pos_emb = pos_emb.transpose([0, 2, 1, 3])
|
||||
# sin: [bsz, seq_len, 1, head_dim // 2]
|
||||
@@ -388,35 +352,29 @@ class ErnieVlRotaryEmbedding3D:
|
||||
tmp_pos_id_2 = position_ids_3d[..., 2].squeeze().astype("int64")
|
||||
|
||||
sin_bsz = paddle.index_select(sin, index=batch_indices, axis=0)
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0,
|
||||
axis=1)[:, :, :, -self.freq_allocation:]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1,
|
||||
axis=1)[:, :, :, :self.rotary_dim // 2 -
|
||||
self.freq_allocation:2]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2,
|
||||
axis=1)[:, :, :, 1:self.rotary_dim // 2 -
|
||||
self.freq_allocation:2]
|
||||
sin_hw = paddle.stack([sin_h, sin_w],
|
||||
axis=-1).reshape(sin_h.shape[:-1] +
|
||||
[sin_h.shape[-1] * 2])
|
||||
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) # noqa
|
||||
sin_t = paddle.index_select(sin_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||
sin_h = paddle.index_select(sin_bsz, index=tmp_pos_id_1, axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_w = paddle.index_select(sin_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2])
|
||||
sin_thw = paddle.concat([sin_hw, sin_t], axis=-1)
|
||||
|
||||
cos_bsz = paddle.index_select(cos, index=batch_indices, axis=0)
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0,
|
||||
axis=1)[:, :, :, -self.freq_allocation:]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1,
|
||||
axis=1)[:, :, :, :self.rotary_dim // 2 -
|
||||
self.freq_allocation:2]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2,
|
||||
axis=1)[:, :, :, 1:self.rotary_dim // 2 -
|
||||
self.freq_allocation:2]
|
||||
cos_hw = paddle.stack([cos_h, cos_w],
|
||||
axis=-1).reshape(cos_h.shape[:-1] +
|
||||
[cos_h.shape[-1] * 2])
|
||||
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) # noqa
|
||||
cos_t = paddle.index_select(cos_bsz, index=tmp_pos_id_0, axis=1)[:, :, :, -self.freq_allocation :]
|
||||
cos_h = paddle.index_select(cos_bsz, index=tmp_pos_id_1, axis=1)[
|
||||
:, :, :, : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_w = paddle.index_select(cos_bsz, index=tmp_pos_id_2, axis=1)[
|
||||
:, :, :, 1 : self.rotary_dim // 2 - self.freq_allocation : 2
|
||||
]
|
||||
cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2])
|
||||
cos_thw = paddle.concat([cos_hw, cos_t], axis=-1)
|
||||
|
||||
rot_emb[0] = cos_thw # noqa
|
||||
rot_emb[1] = sin_thw # noqa
|
||||
rot_emb[0] = cos_thw
|
||||
rot_emb[1] = sin_thw
|
||||
|
||||
return rot_emb
|
||||
|
||||
@@ -446,9 +404,8 @@ def get_rope_3d(
|
||||
max_position: Maximum position index to precompute.
|
||||
freq_allocation: Number of rotary dimensions allocated to temporal axis
|
||||
"""
|
||||
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
|
||||
partial_rotary_factor,
|
||||
max_position,
|
||||
freq_allocation)
|
||||
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(
|
||||
rotary_dim, base, partial_rotary_factor, max_position, freq_allocation
|
||||
)
|
||||
rotary_emb_3d = rotary_emb3d_layer(position_ids)
|
||||
return rotary_emb_3d
|
||||
|
Reference in New Issue
Block a user