mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Bugfix] Fix model accuracy in some ops (#3231)
* fix noaux_tc op * fix * update * fix qk norm * fix linear for prequant loader * test * fix * fix * rm some print * fix noaux_tc op * test * Fix the confused enable_early_stop when only set early_stop_config (#3214) * fix the confused early_stop_config when only set early_stop_config * pre-commit * write a general method * Add ci case for min token and max token (#3229) Co-authored-by: xujing43 <xujing43@baidu.com> * add some evil cases (#3240) * add repitation early stop cases * add repitation early stop cases * add bad cases * add bad cases * add evil cases * qwen3_moe (#3084) * [Feature] support seed parameter (#3161) * support seed * fix * add SamplingMetadata seed test * The next_tokens values are inconsistent! * add air and rejection seed test * fix * add SamplingParams seed test * fix seed=0 * Default to defualt * fix * fix args_utils * fix review * fix review * fix * fix * add xpu,gcu,iluvatar support seed * fix * 【Fix Bug】 修复 fa3 支持集中式bug (#3235) * fix fa3 集中式bug * 增加qknorm参数 * fix qk norm * fix * update * fix linear for prequant loader * fix * fix * rm some print * fix * fix moe init weight&scale * fix moe init weight&scale --------- Co-authored-by: bukejiyu <395822456@qq.com> Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com> Co-authored-by: Zero Rains <linjunlu@zerorains.top> Co-authored-by: xjkmfa <108254620+xjkmfa@users.noreply.github.com> Co-authored-by: xujing43 <xujing43@baidu.com> Co-authored-by: Divano <dddivano@outlook.com> Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com> Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com> Co-authored-by: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com> Co-authored-by: qingqing01 <dangqingqing@baidu.com>
This commit is contained in:
@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.x;
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
|
||||
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadT q_norm_vec, k_norm_vec;
|
||||
if (hi < num_heads) { // q
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||
}
|
||||
} else { // k
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||
}
|
||||
|
||||
@@ -45,7 +45,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
assert(dim_head == 128 && "dim_head must be 128");
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
@@ -53,7 +52,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
|
||||
@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
LoadT src_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.x;
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||
const int all_head_num = elem_cnt / last_dim;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
|
||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadT q_norm_vec, k_norm_vec;
|
||||
if (hi < q_num_head) {
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||
}
|
||||
} else {
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||
}
|
||||
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
|
||||
dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
|
||||
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||
|
||||
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
|
||||
<<<grid_size, Blocks, 0, stream>>>(
|
||||
<<<grid_size, Block_Size, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
|
||||
@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (group_size == 12) { \
|
||||
constexpr size_t GROUP_SIZE = 12; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 14) { \
|
||||
constexpr size_t GROUP_SIZE = 14; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
|
||||
Reference in New Issue
Block a user