[Bugfix] Fix model accuracy in some ops (#3231)

* fix noaux_tc op

* fix

* update

* fix qk norm

* fix linear for prequant loader

* test

* fix

* fix

* rm some print

* fix noaux_tc op

* test

* Fix confusing enable_early_stop behavior when only early_stop_config is set (#3214)

* fix the confusing early_stop_config handling when only early_stop_config is set

* pre-commit

* write a general method

* Add CI case for min token and max token (#3229)

Co-authored-by: xujing43 <xujing43@baidu.com>

* add some evil cases (#3240)

* add repetition early stop cases

* add repetition early stop cases

* add bad cases

* add bad cases

* add evil cases

* qwen3_moe (#3084)

* [Feature] support seed parameter (#3161)

* support seed

* fix

* add SamplingMetadata seed test

* The next_tokens values are inconsistent!

* add air and rejection seed test

* fix

* add SamplingParams seed test

* fix seed=0

* Default to default

* fix

* fix args_utils

* fix review

* fix review

* fix

* fix

* add xpu, gcu, iluvatar seed support

* fix

* [Fix Bug] Fix bug in fa3 centralized deployment support (#3235)

* fix fa3 centralized deployment bug

* add qknorm parameter

* fix qk norm

* fix

* update

* fix linear for prequant loader

* fix

* fix

* rm some print

* fix

* fix moe init weight&scale

* fix moe init weight&scale

---------

Co-authored-by: bukejiyu <395822456@qq.com>
Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: Zero Rains <linjunlu@zerorains.top>
Co-authored-by: xjkmfa <108254620+xjkmfa@users.noreply.github.com>
Co-authored-by: xujing43 <xujing43@baidu.com>
Co-authored-by: Divano <dddivano@outlook.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com>
Co-authored-by: qingqing01 <dangqingqing@baidu.com>
gaoziyuan committed on 2025-08-08 17:30:37 +08:00 (committed by GitHub)
commit a799d14df1 (parent ce1f353c70)
8 changed files with 62 additions and 31 deletions


```diff
@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   int64_t all_head_dim = elem_cnt / head_size;
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
-  // const int64_t offset = 2 * hidden_size;
   const int half_head_size = head_size / 2;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
     const int ori_bi = linear_index / hidden_size;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
```


```diff
@@ -45,7 +45,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
                      : bsz * (num_heads + 2 * kv_num_heads) * dim_head;
-  assert(dim_head == 128 && "dim_head must be 128");
   constexpr int HEAD_DIM = 128;
   constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -53,7 +52,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
+  dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
   append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
       <<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                             key_cache,
```


```diff
@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
   const int offset = (q_num_head + kv_num_head) * last_dim;
   const int all_head_num = elem_cnt / last_dim;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
     const int token_idx = linear_index / offset;
     const int ori_bi = batch_id_per_token[token_idx];
     if (seq_lens[ori_bi] == 0) continue;
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
+  dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
   GQAVariableLengthRotaryQKNormKernel<T, PackSize>
-      <<<grid_size, Blocks, 0, stream>>>(
+      <<<grid_size, Block_Size, 0, stream>>>(
           reinterpret_cast<const T *>(qkv_input),
           cos_emb,
           sin_emb,
```
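The thread-layout swap above is the heart of the qk-norm accuracy fix. The per-head RMSNorm in these kernels reduces across the lanes of a warp, so the 32 lanes (threadIdx.x after this patch) must cooperatively cover exactly one 128-element head, 4 elements per lane, while threadIdx.y enumerates warps. The old layout put lanes on threadIdx.y, so each "warp" mixed elements from different heads and loaded norm weights at the wrong offsets; the last hunk also fixes the launch config, which sized the block from grid_size instead of blocksize. A minimal host-side sketch of the new mapping (plain Python, not part of the patch; constants mirror the launcher):

```python
# Sanity check of the new (threadIdx.x = lane, threadIdx.y = warp) mapping:
# each warp must cover exactly one 128-element head so that the warp-level
# RMSNorm reduction sees one head and nothing else.
HEAD_SIZE = 128                      # HEAD_DIM in the launcher
WARP_SIZE = 32                       # kWarpSize
VEC_SIZE = HEAD_SIZE // WARP_SIZE    # PackSize: 4 elements per lane

def warp_coverage(global_warp_idx: int) -> set:
    """Linear indices one warp touches per loop iteration (new layout)."""
    return {
        global_warp_idx * HEAD_SIZE + lane * VEC_SIZE + i  # lane = threadIdx.x
        for lane in range(WARP_SIZE)
        for i in range(VEC_SIZE)
    }

# Warp 0 covers head 0 and warp 3 covers head 3 -- contiguous and disjoint.
assert warp_coverage(0) == set(range(0, HEAD_SIZE))
assert warp_coverage(3) == set(range(3 * HEAD_SIZE, 4 * HEAD_SIZE))
```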


```diff
@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
   } else if (group_size == 12) {       \
     constexpr size_t GROUP_SIZE = 12;  \
     __VA_ARGS__                        \
+  } else if (group_size == 14) {       \
+    constexpr size_t GROUP_SIZE = 14;  \
+    __VA_ARGS__                        \
   } else if (group_size == 16) {       \
     constexpr size_t GROUP_SIZE = 16;  \
     __VA_ARGS__                        \
```
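This macro maps a runtime group_size onto a compile-time GROUP_SIZE constant, one branch per supported size; before this change, group_size == 14 fell through the chain and never reached a kernel instantiation. A rough Python analogue of the dispatch pattern (only the sizes visible in this hunk are shown; the real macro lists more):

```python
# Each supported group_size selects a separately "instantiated" kernel;
# unlisted sizes must fail loudly rather than silently mis-compute.
def dispatch_group_size(group_size: int, run):
    if group_size == 12:
        return run(GROUP_SIZE=12)
    elif group_size == 14:   # the newly added case
        return run(GROUP_SIZE=14)
    elif group_size == 16:
        return run(GROUP_SIZE=16)
    raise ValueError(f"unsupported group_size: {group_size}")

print(dispatch_group_size(14, lambda GROUP_SIZE: GROUP_SIZE * 2))  # 28
```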


```diff
@@ -28,19 +28,20 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     int topk,
                                     float routed_scaling_factor) {
   auto input_shape = scores_with_bias.shape();
+  PD_CHECK(input_shape.size() == 2);
   int64_t num_tokens = input_shape[0];
   int64_t num_experts = input_shape[1];
   auto input_type = scores_with_bias.dtype();
   auto place = scores_with_bias.place();
   auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
   auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
-  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
+  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
   auto stream = scores_with_bias.stream();
-  invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
+  invokeNoAuxTc<float, int64_t>(reinterpret_cast<float*>(scores.data<float>()),
                                 reinterpret_cast<float*>(group_scores.data<float>()),
                                 reinterpret_cast<float*>(topk_values.data<float>()),
-                                reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
+                                reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
                                 reinterpret_cast<float*>(scores_with_bias.data<float>()),
                                 num_tokens,
                                 num_experts,
@@ -56,7 +57,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
 std::vector<paddle::DataType> NoauxTcInferDtype(
     const paddle::DataType& scores_dtype,
     const paddle::DataType& scores_with_bias_dtype) {
-  return {scores_dtype, scores_dtype, paddle::DataType::INT32};
+  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
 }
 std::vector<std::vector<int64_t>> NoauxTcInferShape(
@@ -71,7 +72,7 @@ std::vector<std::vector<int64_t>> NoauxTcInferShape(
 PD_BUILD_STATIC_OP(noaux_tc)
     .Inputs({"scores", "scores_with_bias"})
-    .Outputs({"output_tensor"})
+    .Outputs({"output_tensor", "topk_values", "topk_indices"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
```


```diff
@@ -49,7 +49,7 @@ def get_moe_scores(
     compute moe scores using e_score_correction_bias.
     """
     scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
+    scores_with_bias = scores + e_score_correction_bias
     scores, topk_values, topk_idx = noaux_tc(
         scores,
         scores_with_bias,
```
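Dropping the explicit .unsqueeze(0) leans on trailing-dimension broadcasting. The hunk does not show how the bias is stored, so both plausible cases are sketched below: with a 1-D bias the unsqueeze was merely redundant, but with a bias already stored as [1, num_experts] the extra unsqueeze silently produced a 3-D result of the wrong rank, which would explain the accuracy fix:

```python
import paddle

scores = paddle.rand([4, 64])   # [num_tokens, num_experts]

# Case 1: 1-D bias -- broadcasting makes the unsqueeze redundant.
bias_1d = paddle.rand([64])
assert (scores + bias_1d).shape == (scores + bias_1d.unsqueeze(0)).shape

# Case 2: bias stored as [1, num_experts] -- the old unsqueeze(0) was harmful.
bias_2d = paddle.rand([1, 64])
print((scores + bias_2d).shape)               # [4, 64]
print((scores + bias_2d.unsqueeze(0)).shape)  # [1, 4, 64] -- wrong rank
```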


```diff
@@ -312,6 +312,19 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         below is TP compute method.
         """
         gate_out = gate(x.cast("float32"))
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
+        if layer.topk_method == "noaux_tc":
+            from .ep import get_moe_scores
+
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
```


```diff
@@ -285,7 +285,7 @@ class FusedMoE(nn.Layer):
             dtype="float32",
         )
         up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["fp8", "wint8"]:
+        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
             up_gate_proj_weight_shape = [
                 self.num_local_experts,
                 up_gate_proj_output_dim,
@@ -309,9 +309,10 @@ class FusedMoE(nn.Layer):
         ]
         # Create parameters
-        if self.moe_quant_type == "fp8":
+        if self.moe_quant_type == "block_wise_fp8":
             # (TODO:gaoziyuan)
-            pass
+            self.weight_dtype = "float8_e4m3fn"
+            self.init_block_wise_fp8_scale()
         elif self.moe_quant_type == "wint8":
             self.weight_dtype = "int8"
             self.init_weight_only_scale()
@@ -342,6 +343,21 @@ class FusedMoE(nn.Layer):
             dtype=self._dtype,
         )
+    def init_block_wise_fp8_scale(self):
+        """
+        Initialize the weight scale.
+        """
+        self.up_gate_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+        self.down_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
     def load_experts_weight(
         self,
         state_dict: dict,
```
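The new scale parameters assume block-wise FP8 with 128 x 128 quantization tiles: one float32 scale per tile, so each scale tensor is the corresponding weight shape with its last two dimensions divided by 128. A quick sketch of the arithmetic with illustrative sizes (real values come from the model config):

```python
# One float32 scale per 128x128 weight tile under block-wise FP8.
BLOCK = 128
num_local_experts, hidden_size, moe_intermediate_size = 8, 4096, 1536

def block_scale_shape(experts: int, rows: int, cols: int, block: int = BLOCK):
    """Scale shape for an [experts, rows, cols] weight quantized in tiles."""
    return [experts, rows // block, cols // block]

# Mirrors up_gate_proj_weight_scale and down_proj_weight_scale above:
print(block_scale_shape(num_local_experts, moe_intermediate_size * 2, hidden_size))
# [8, 24, 32]
print(block_scale_shape(num_local_experts, hidden_size, moe_intermediate_size))
# [8, 32, 12]
```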