mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-27 02:20:31 +08:00
fix mask_offset in append_attn (#3745)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* fix mask_offset in append_attn * fix test
This commit is contained in:
@@ -42,19 +42,22 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d,
|
||||
const T* q_norm_weight,
|
||||
const T* k_norm_weight,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadKVT cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec, k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
@@ -105,10 +108,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
out_vec[2 * i] = src_vec[2 * i];
|
||||
out_vec[2 * i + 1] = src_vec[2 * i + 1];
|
||||
@@ -119,17 +120,17 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadT q_norm_vec, k_norm_vec;
|
||||
|
||||
if (hi < num_heads) { // q
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else { // k
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d,
|
||||
const T* q_norm_weight,
|
||||
const T* k_norm_weight,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
@@ -569,8 +569,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
|
||||
reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
|
||||
@@ -431,16 +431,19 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const int seq_len,
|
||||
const int last_dim,
|
||||
const bool rope_3d,
|
||||
const T* q_norm_weight,
|
||||
const T* k_norm_weight,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps
|
||||
) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
LoadT src_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec, k_norm_vec;
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
const int half_lastdim = last_dim / 2;
|
||||
@@ -477,25 +480,25 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
src_vec[2 * i] = static_cast<T>(tmp1);
|
||||
src_vec[2 * i + 1] = static_cast<T>(tmp2);
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
}
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / last_dim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadT q_norm_vec, k_norm_vec;
|
||||
|
||||
if (hi < q_num_head) {
|
||||
Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
|
||||
@@ -1695,8 +1698,8 @@ void gqa_rotary_qk_norm_variable(
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false,
|
||||
const T *q_norm_weight = nullptr,
|
||||
const T *k_norm_weight = nullptr,
|
||||
const float *q_norm_weight = nullptr,
|
||||
const float *k_norm_weight = nullptr,
|
||||
const float rms_norm_eps = 1e-6) {
|
||||
int64_t elem_nums =
|
||||
qkv_out_scales
|
||||
|
||||
@@ -80,8 +80,8 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d,
|
||||
q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
|
||||
Reference in New Issue
Block a user