mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Bugfix] Fix model accuracy in some ops (#3231)
* fix noaux_tc op * fix * update * fix qk norm * fix linear for prequant loader * test * fix * fix * rm some print * fix noaux_tc op * test * Fix the confused enable_early_stop when only set early_stop_config (#3214) * fix the confused early_stop_config when only set early_stop_config * pre-commit * write a general method * Add ci case for min token and max token (#3229) Co-authored-by: xujing43 <xujing43@baidu.com> * add some evil cases (#3240) * add repitation early stop cases * add repitation early stop cases * add bad cases * add bad cases * add evil cases * qwen3_moe (#3084) * [Feature] support seed parameter (#3161) * support seed * fix * add SamplingMetadata seed test * The next_tokens values are inconsistent! * add air and rejection seed test * fix * add SamplingParams seed test * fix seed=0 * Default to defualt * fix * fix args_utils * fix review * fix review * fix * fix * add xpu,gcu,iluvatar support seed * fix * 【Fix Bug】 修复 fa3 支持集中式bug (#3235) * fix fa3 集中式bug * 增加qknorm参数 * fix qk norm * fix * update * fix linear for prequant loader * fix * fix * rm some print * fix * fix moe init weight&scale * fix moe init weight&scale --------- Co-authored-by: bukejiyu <395822456@qq.com> Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com> Co-authored-by: Zero Rains <linjunlu@zerorains.top> Co-authored-by: xjkmfa <108254620+xjkmfa@users.noreply.github.com> Co-authored-by: xujing43 <xujing43@baidu.com> Co-authored-by: Divano <dddivano@outlook.com> Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com> Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com> Co-authored-by: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com> Co-authored-by: qingqing01 <dangqingqing@baidu.com>
This commit is contained in:
@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
|||||||
LoadEmbT cos_emb_vec;
|
LoadEmbT cos_emb_vec;
|
||||||
LoadEmbT sin_emb_vec;
|
LoadEmbT sin_emb_vec;
|
||||||
|
|
||||||
int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||||
int64_t all_warp_num = gridDim.x * blockDim.x;
|
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||||
int64_t all_head_dim = elem_cnt / head_size;
|
int64_t all_head_dim = elem_cnt / head_size;
|
||||||
|
|
||||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||||
// const int64_t offset = 2 * hidden_size;
|
|
||||||
const int half_head_size = head_size / 2;
|
const int half_head_size = head_size / 2;
|
||||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
|
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
|
||||||
int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
|
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
|
||||||
const int ori_bi = linear_index / hidden_size;
|
const int ori_bi = linear_index / hidden_size;
|
||||||
const int bias = linear_index % hidden_size;
|
const int bias = linear_index % hidden_size;
|
||||||
const int hi = bias / head_size; // q + k + v
|
const int hi = bias / head_size; // q + k + v
|
||||||
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
|||||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||||
LoadT q_norm_vec, k_norm_vec;
|
LoadT q_norm_vec, k_norm_vec;
|
||||||
if (hi < num_heads) { // q
|
if (hi < num_heads) { // q
|
||||||
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < VecSize; i++) {
|
for (int i = 0; i < VecSize; i++) {
|
||||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||||
}
|
}
|
||||||
} else { // k
|
} else { // k
|
||||||
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||||
for (int i = 0; i < VecSize; i++) {
|
for (int i = 0; i < VecSize; i++) {
|
||||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||||
}
|
}
|
||||||
|
@@ -45,7 +45,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
|||||||
const uint32_t elem_nums =
|
const uint32_t elem_nums =
|
||||||
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||||
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||||
assert(dim_head == 128 && "dim_head must be 128");
|
|
||||||
constexpr int HEAD_DIM = 128;
|
constexpr int HEAD_DIM = 128;
|
||||||
|
|
||||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||||
@@ -53,7 +52,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
|||||||
const int blocksize = 128;
|
const int blocksize = 128;
|
||||||
int grid_size = 1;
|
int grid_size = 1;
|
||||||
GetNumBlocks<128>(pack_num, &grid_size);
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
|
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||||
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
|
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||||
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||||
key_cache,
|
key_cache,
|
||||||
|
@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
|||||||
LoadT src_vec;
|
LoadT src_vec;
|
||||||
LoadEmbT cos_emb_vec;
|
LoadEmbT cos_emb_vec;
|
||||||
LoadEmbT sin_emb_vec;
|
LoadEmbT sin_emb_vec;
|
||||||
int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||||
int64_t all_warp_num = gridDim.x * blockDim.x;
|
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||||
const int half_lastdim = last_dim / 2;
|
const int half_lastdim = last_dim / 2;
|
||||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||||
const int all_head_num = elem_cnt / last_dim;
|
const int all_head_num = elem_cnt / last_dim;
|
||||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
|
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||||
const int token_idx = linear_index / offset;
|
const int token_idx = linear_index / offset;
|
||||||
const int ori_bi = batch_id_per_token[token_idx];
|
const int ori_bi = batch_id_per_token[token_idx];
|
||||||
if (seq_lens[ori_bi] == 0) continue;
|
if (seq_lens[ori_bi] == 0) continue;
|
||||||
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
|||||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||||
LoadT q_norm_vec, k_norm_vec;
|
LoadT q_norm_vec, k_norm_vec;
|
||||||
if (hi < q_num_head) {
|
if (hi < q_num_head) {
|
||||||
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < VecSize; i++) {
|
for (int i = 0; i < VecSize; i++) {
|
||||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||||
for (int i = 0; i < VecSize; i++) {
|
for (int i = 0; i < VecSize; i++) {
|
||||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||||
}
|
}
|
||||||
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
|
|||||||
const int blocksize = 128;
|
const int blocksize = 128;
|
||||||
int grid_size = 1;
|
int grid_size = 1;
|
||||||
GetNumBlocks<128>(pack_num, &grid_size);
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
|
dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
|
||||||
|
|
||||||
const float *cos_emb = rotary_emb;
|
const float *cos_emb = rotary_emb;
|
||||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||||
|
|
||||||
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
|
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
|
||||||
<<<grid_size, Blocks, 0, stream>>>(
|
<<<grid_size, Block_Size, 0, stream>>>(
|
||||||
reinterpret_cast<const T *>(qkv_input),
|
reinterpret_cast<const T *>(qkv_input),
|
||||||
cos_emb,
|
cos_emb,
|
||||||
sin_emb,
|
sin_emb,
|
||||||
|
@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
|||||||
} else if (group_size == 12) { \
|
} else if (group_size == 12) { \
|
||||||
constexpr size_t GROUP_SIZE = 12; \
|
constexpr size_t GROUP_SIZE = 12; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
|
} else if (group_size == 14) { \
|
||||||
|
constexpr size_t GROUP_SIZE = 14; \
|
||||||
|
__VA_ARGS__ \
|
||||||
} else if (group_size == 16) { \
|
} else if (group_size == 16) { \
|
||||||
constexpr size_t GROUP_SIZE = 16; \
|
constexpr size_t GROUP_SIZE = 16; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
|
@@ -28,19 +28,20 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
|||||||
int topk,
|
int topk,
|
||||||
float routed_scaling_factor) {
|
float routed_scaling_factor) {
|
||||||
auto input_shape = scores_with_bias.shape();
|
auto input_shape = scores_with_bias.shape();
|
||||||
|
PD_CHECK(input_shape.size() == 2);
|
||||||
int64_t num_tokens = input_shape[0];
|
int64_t num_tokens = input_shape[0];
|
||||||
int64_t num_experts = input_shape[1];
|
int64_t num_experts = input_shape[1];
|
||||||
auto input_type = scores_with_bias.dtype();
|
auto input_type = scores_with_bias.dtype();
|
||||||
auto place = scores_with_bias.place();
|
auto place = scores_with_bias.place();
|
||||||
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
|
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
|
||||||
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
|
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
|
||||||
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
|
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
|
||||||
auto stream = scores_with_bias.stream();
|
auto stream = scores_with_bias.stream();
|
||||||
|
|
||||||
invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
|
invokeNoAuxTc<float, int64_t>(reinterpret_cast<float*>(scores.data<float>()),
|
||||||
reinterpret_cast<float*>(group_scores.data<float>()),
|
reinterpret_cast<float*>(group_scores.data<float>()),
|
||||||
reinterpret_cast<float*>(topk_values.data<float>()),
|
reinterpret_cast<float*>(topk_values.data<float>()),
|
||||||
reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
|
reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
|
||||||
reinterpret_cast<float*>(scores_with_bias.data<float>()),
|
reinterpret_cast<float*>(scores_with_bias.data<float>()),
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -56,7 +57,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
|||||||
std::vector<paddle::DataType> NoauxTcInferDtype(
|
std::vector<paddle::DataType> NoauxTcInferDtype(
|
||||||
const paddle::DataType& scores_dtype,
|
const paddle::DataType& scores_dtype,
|
||||||
const paddle::DataType& scores_with_bias_dtype) {
|
const paddle::DataType& scores_with_bias_dtype) {
|
||||||
return {scores_dtype, scores_dtype, paddle::DataType::INT32};
|
return {scores_dtype, scores_dtype, paddle::DataType::INT64};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
||||||
@@ -71,7 +72,7 @@ std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
|||||||
|
|
||||||
PD_BUILD_STATIC_OP(noaux_tc)
|
PD_BUILD_STATIC_OP(noaux_tc)
|
||||||
.Inputs({"scores", "scores_with_bias"})
|
.Inputs({"scores", "scores_with_bias"})
|
||||||
.Outputs({"output_tensor"})
|
.Outputs({"output_tensor", "topk_values", "topk_indices"})
|
||||||
.Attrs({"n_group: int",
|
.Attrs({"n_group: int",
|
||||||
"topk_group: int",
|
"topk_group: int",
|
||||||
"topk:int",
|
"topk:int",
|
||||||
|
@@ -49,7 +49,7 @@ def get_moe_scores(
|
|||||||
compute moe scores using e_score_correction_bias.
|
compute moe scores using e_score_correction_bias.
|
||||||
"""
|
"""
|
||||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
scores_with_bias = scores + e_score_correction_bias
|
||||||
scores, topk_values, topk_idx = noaux_tc(
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
scores,
|
scores,
|
||||||
scores_with_bias,
|
scores_with_bias,
|
||||||
|
@@ -312,6 +312,19 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
below is TP compute method.
|
below is TP compute method.
|
||||||
"""
|
"""
|
||||||
gate_out = gate(x.cast("float32"))
|
gate_out = gate(x.cast("float32"))
|
||||||
|
|
||||||
|
if layer.topk_method == "noaux_tc":
|
||||||
|
from .ep import get_moe_scores
|
||||||
|
|
||||||
|
_, topk_weights, topk_ids = get_moe_scores(
|
||||||
|
gate_out,
|
||||||
|
layer.n_group,
|
||||||
|
layer.topk_group,
|
||||||
|
layer.top_k,
|
||||||
|
layer.routed_scaling_factor,
|
||||||
|
layer.gate_correction_bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
gate_out,
|
gate_out,
|
||||||
layer.gate_correction_bias,
|
layer.gate_correction_bias,
|
||||||
|
@@ -285,7 +285,7 @@ class FusedMoE(nn.Layer):
|
|||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
up_gate_proj_output_dim = self.moe_intermediate_size * 2
|
up_gate_proj_output_dim = self.moe_intermediate_size * 2
|
||||||
if self.moe_quant_type in ["fp8", "wint8"]:
|
if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
|
||||||
up_gate_proj_weight_shape = [
|
up_gate_proj_weight_shape = [
|
||||||
self.num_local_experts,
|
self.num_local_experts,
|
||||||
up_gate_proj_output_dim,
|
up_gate_proj_output_dim,
|
||||||
@@ -309,9 +309,10 @@ class FusedMoE(nn.Layer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create parameters
|
# Create parameters
|
||||||
if self.moe_quant_type == "fp8":
|
if self.moe_quant_type == "block_wise_fp8":
|
||||||
# (TODO:gaoziyuan)
|
# (TODO:gaoziyuan)
|
||||||
pass
|
self.weight_dtype = "float8_e4m3fn"
|
||||||
|
self.init_block_wise_fp8_scale()
|
||||||
elif self.moe_quant_type == "wint8":
|
elif self.moe_quant_type == "wint8":
|
||||||
self.weight_dtype = "int8"
|
self.weight_dtype = "int8"
|
||||||
self.init_weight_only_scale()
|
self.init_weight_only_scale()
|
||||||
@@ -342,6 +343,21 @@ class FusedMoE(nn.Layer):
|
|||||||
dtype=self._dtype,
|
dtype=self._dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_block_wise_fp8_scale(self):
|
||||||
|
"""
|
||||||
|
Initialize the weight scale.
|
||||||
|
"""
|
||||||
|
self.up_gate_proj_weight_scale = self.create_parameter(
|
||||||
|
shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
|
||||||
|
dtype="float32",
|
||||||
|
is_bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj_weight_scale = self.create_parameter(
|
||||||
|
shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
|
||||||
|
dtype="float32",
|
||||||
|
is_bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
def load_experts_weight(
|
def load_experts_weight(
|
||||||
self,
|
self,
|
||||||
state_dict: dict,
|
state_dict: dict,
|
||||||
|
Reference in New Issue
Block a user