[Bugfix] Fix model accuracy in some ops (#3231)

* fix noaux_tc op

* fix

* update

* fix qk norm

* fix linear for prequant loader

* test

* fix

* fix

* rm some print

* fix noaux_tc op

* test

* Fix the confusing enable_early_stop when only early_stop_config is set (#3214)

* fix the confusing early_stop_config when only early_stop_config is set

* pre-commit

* write a general method

* Add CI case for min token and max token (#3229)

Co-authored-by: xujing43 <xujing43@baidu.com>

* add some evil cases (#3240)

* add repetition early stop cases

* add repetition early stop cases

* add bad cases

* add bad cases

* add evil cases

* qwen3_moe (#3084)

* [Feature] support seed parameter (#3161)

* support seed

* fix

* add SamplingMetadata seed test

* The next_tokens values are inconsistent!

* add air and rejection seed test

* fix

* add SamplingParams seed test

* fix seed=0

* Default to default

* fix

* fix args_utils

* fix review

* fix review

* fix

* fix

* add seed support for xpu, gcu, iluvatar

* fix

* [Fix Bug] Fix bug in fa3 centralized-deployment support (#3235)

* fix fa3 centralized-deployment bug

* add qknorm parameter

* fix qk norm

* fix

* update

* fix linear for prequant loader

* fix

* fix

* rm some print

* fix

* fix moe init weight&scale

* fix moe init weight&scale

---------

Co-authored-by: bukejiyu <395822456@qq.com>
Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: Zero Rains <linjunlu@zerorains.top>
Co-authored-by: xjkmfa <108254620+xjkmfa@users.noreply.github.com>
Co-authored-by: xujing43 <xujing43@baidu.com>
Co-authored-by: Divano <dddivano@outlook.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com>
Co-authored-by: qingqing01 <dangqingqing@baidu.com>
Author: gaoziyuan
Date:   2025-08-08 17:30:37 +08:00 (committed by GitHub)
Commit: a799d14df1 (parent ce1f353c70)

8 changed files with 62 additions and 31 deletions

@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   int64_t all_head_dim = elem_cnt / head_size;
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
   // const int64_t offset = 2 * hidden_size;
   const int half_head_size = head_size / 2;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
     const int ori_bi = linear_index / hidden_size;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
       #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
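A note on the normalization these hunks touch: each warp normalizes one head, with every lane holding VecSize channels, so the per-head sum of squares must be reduced across the warp before `Rsqrt(row_variance + rms_norm_eps)` is applied. A minimal sketch of that pattern (illustrative names and a shuffle-based reduction assumed; this is not the kernel's own code):

    // Hedged sketch of warp-level RMSNorm. Assumes one warp per head,
    // VecSize channels per lane, and blockDim.x == 32 so that threadIdx.x
    // is the lane index -- the layout this commit establishes.
    constexpr int kWarpSize = 32;

    template <int VecSize, typename T>
    __device__ void warp_rms_norm(float* vals, const T* norm_weight,
                                  int head_size, float rms_norm_eps) {
      float sumsq = 0.f;
      #pragma unroll
      for (int i = 0; i < VecSize; i++) sumsq += vals[i] * vals[i];
      // Butterfly reduction over the 32 lanes: every lane ends up holding
      // the head's full sum of squares.
      for (int mask = kWarpSize / 2; mask > 0; mask >>= 1)
        sumsq += __shfl_xor_sync(0xffffffff, sumsq, mask);
      const float row_variance = sumsq / head_size;  // mean of squares
      const float row_inv_var = rsqrtf(row_variance + rms_norm_eps);
      #pragma unroll
      for (int i = 0; i < VecSize; i++) {
        // threadIdx.x * VecSize is this lane's channel offset, matching the
        // corrected q_norm_weight/k_norm_weight indexing above.
        vals[i] *= row_inv_var *
                   static_cast<float>(norm_weight[threadIdx.x * VecSize + i]);
      }
    }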

@@ -45,7 +45,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
                      : bsz * (num_heads + 2 * kv_num_heads) * dim_head;
-  assert(dim_head == 128 && "dim_head must be 128");
   constexpr int HEAD_DIM = 128;
   constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -53,7 +52,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
+  dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
   append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
       <<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                             key_cache,
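Why the `dim3` swap and the `threadIdx` swaps in the kernel must go together: CUDA linearizes thread ids with `threadIdx.x` varying fastest, so with `block_dim(kWarpSize, blocksize / kWarpSize, 1)` the 32 threads of one hardware warp share a `threadIdx.y` and differ only in `threadIdx.x`. Under the old `(blocksize / kWarpSize, kWarpSize, 1)` layout, threads that differed only in `threadIdx.y` spanned several hardware warps, so treating `threadIdx.y` as the lane index breaks any intra-warp cooperation on a single head. A standalone sketch of the fixed mapping (hypothetical demo kernel, not from this commit):

    #include <cstdio>

    constexpr int kWarpSize = 32;

    // With blockDim.x == kWarpSize, threadIdx.x is the lane within a warp
    // and threadIdx.y is the warp's index inside the block.
    __global__ void layout_demo() {
      int warp_id = blockIdx.x * blockDim.y + threadIdx.y;  // one head per warp
      int lane = threadIdx.x;                               // channel group
      if (lane == 0) printf("warp %d handles one head\n", warp_id);
    }

    int main() {
      const int blocksize = 128;
      dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);  // 32 lanes x 4 warps
      layout_demo<<<2, block_dim>>>();
      cudaDeviceSynchronize();
      return 0;
    }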

@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
   const int offset = (q_num_head + kv_num_head) * last_dim;
   const int all_head_num = elem_cnt / last_dim;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
     const int token_idx = linear_index / offset;
     const int ori_bi = batch_id_per_token[token_idx];
     if (seq_lens[ori_bi] == 0) continue;
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
       #pragma unroll
      for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 Blocks(grid_size / kWarpSize, kWarpSize, 1);
+  dim3 Block_Size(kWarpSize, blocksize / kWarpSize, 1);
   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
   GQAVariableLengthRotaryQKNormKernel<T, PackSize>
-      <<<grid_size, Blocks, 0, stream>>>(
+      <<<grid_size, Block_Size, 0, stream>>>(
           reinterpret_cast<const T *>(qkv_input),
           cos_emb,
           sin_emb,
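Note that the old launch line sized the block with `grid_size / kWarpSize` rather than `blocksize / kWarpSize`; the replacement fixes both the operand and the axis order. A hedged sketch of the resulting launch shape (assuming, since the hunk does not show it, that pack_num is the element count divided by PackSize and that GetNumBlocks<128> behaves like a capped ceiling division):

    #include <cstdint>

    constexpr int kWarpSize = 32;
    constexpr int HEAD_DIM = 128;
    constexpr int PackSize = HEAD_DIM / kWarpSize;  // 4 channels per lane

    // Sketch only: the real GetNumBlocks<128> may also cap the grid; the
    // kernel's grid-stride loop covers all heads either way.
    dim3 launch_block_shape(int64_t elem_cnt, int* grid_size) {
      const int blocksize = 128;
      const int64_t pack_num = elem_cnt / PackSize;
      *grid_size = static_cast<int>((pack_num + blocksize - 1) / blocksize);
      return dim3(kWarpSize, blocksize / kWarpSize, 1);  // 32 lanes x 4 warps
    }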

@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
   } else if (group_size == 12) {    \
     constexpr size_t GROUP_SIZE = 12; \
     __VA_ARGS__                     \
+  } else if (group_size == 14) {    \
+    constexpr size_t GROUP_SIZE = 14; \
+    __VA_ARGS__                     \
   } else if (group_size == 16) {    \
     constexpr size_t GROUP_SIZE = 16; \
     __VA_ARGS__                     \
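The branch added above extends a group-size dispatch macro: a runtime group_size is matched to a constexpr so the body can instantiate templates per size. A sketch of the pattern (the macro name here is hypothetical; only the group_size == 14 branch is what this commit adds):

    #include <cstddef>

    // Hypothetical reconstruction of the dispatch pattern the hunk extends.
    #define DISPATCH_GROUP_SIZE_SKETCH(group_size, GROUP_SIZE, ...) \
      if (group_size == 12) {                                       \
        constexpr size_t GROUP_SIZE = 12;                           \
        __VA_ARGS__                                                 \
      } else if (group_size == 14) { /* new in this commit */       \
        constexpr size_t GROUP_SIZE = 14;                           \
        __VA_ARGS__                                                 \
      } else if (group_size == 16) {                                \
        constexpr size_t GROUP_SIZE = 16;                           \
        __VA_ARGS__                                                 \
      }

    // Usage sketch: GROUP_SIZE is constexpr inside the body, so it can feed
    // a template parameter, e.g.
    //   DISPATCH_GROUP_SIZE_SKETCH(gs, GROUP_SIZE,
    //       gqa_kernel<GROUP_SIZE><<<grid, block>>>(args);)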