fix mask_offset in append_attn (#3745)

* fix mask_offset in append_attn

* fix test
Author: lizhenyun01 (committed by GitHub)
Date: 2025-08-31 15:03:16 +08:00
Parent: 753772ace8
Commit: bed09ae8f8
8 changed files with 46 additions and 44 deletions

@@ -42,19 +42,22 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     const uint32_t elem_cnt,
     const int kv_num_heads,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadBiasT = AlignedVector<T, VecSize>;
   using LoadKVT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadBiasT out_vec;
   LoadKVT cache_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
@@ -105,10 +108,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
         float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
         float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
         thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-        out_vec[2 * i] =
-            static_cast<T>(tmp1);
-        out_vec[2 * i + 1] =
-            static_cast<T>(tmp2);
+        tmp_vec[2 * i] = tmp1;
+        tmp_vec[2 * i + 1] = tmp2;
       } else {
         out_vec[2 * i] = src_vec[2 * i];
         out_vec[2 * i + 1] = src_vec[2 * i + 1];
@@ -119,17 +120,17 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_variance =
         max(warp_m2 / head_size, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
   }
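
For context, the computation this kernel performs per query or key head is: rotate each (even, odd) pair with the rotary embedding, keep the rotated values in float, accumulate their sum of squares, then scale by rsqrt(mean square + eps) and the float norm weight before casting back to the storage type. A minimal single-threaded sketch of that math, ignoring the non-rotary passthrough branch (the helper name rope_rmsnorm_head and its flat layout are illustrative, not from the repo):

#include <cmath>
#include <vector>

// Reference for one head: RoPE in float, then RMSNorm with a float weight
// vector, then cast back to the storage type T.
template <typename T>
void rope_rmsnorm_head(T* head,
                       const float* cos_emb,       // head_dim / 2 entries
                       const float* sin_emb,       // head_dim / 2 entries
                       const float* norm_weight,   // head_dim entries, kept in float
                       int head_dim,
                       float eps) {
  std::vector<float> tmp(head_dim);  // plays the role of tmp_vec
  float sum_sq = 0.f;
  for (int i = 0; i < head_dim / 2; ++i) {
    float left = static_cast<float>(head[2 * i]);
    float right = static_cast<float>(head[2 * i + 1]);
    float r1 = left * cos_emb[i] - right * sin_emb[i];
    float r2 = right * cos_emb[i] + left * sin_emb[i];
    tmp[2 * i] = r1;
    tmp[2 * i + 1] = r2;
    sum_sq += r1 * r1 + r2 * r2;
  }
  // RMSNorm: scale by rsqrt(mean of squares + eps) and the float norm weight.
  float inv_rms = 1.f / std::sqrt(sum_sq / head_dim + eps);
  for (int i = 0; i < head_dim; ++i) {
    head[i] = static_cast<T>(tmp[i] * inv_rms * norm_weight[i]);
  }
}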

@@ -39,8 +39,8 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
                                       const cudaStream_t& stream,
                                       const bool use_neox_style,
                                       const bool rope_3d,
-                                      const T* q_norm_weight,
-                                      const T* k_norm_weight,
+                                      const float* q_norm_weight,
+                                      const float* k_norm_weight,
                                       const float rms_norm_eps) {
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -569,8 +569,8 @@ void DecoderWriteCacheWithRoPEKernel(
         stream,
         use_neox_rotary_style,
         rope_3d,
-        reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
-        reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(
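
The call-site change above swaps an unconditional reinterpret_cast of the optional norm-weight tensors for a null-guarded unwrap, so an absent weight reaches the kernel as nullptr. A toy sketch of that pattern (FloatTensor and weight_ptr_or_null are illustrative stand-ins, not repo types):

#include <optional>
#include <vector>

// Toy stand-in for a framework tensor, just to illustrate the guarded unwrap.
struct FloatTensor {
  std::vector<float> buf;
  const float* data() const { return buf.data(); }
};

// An empty optional becomes nullptr; the launch path should only take the
// QK-norm branch when both weight pointers are non-null.
const float* weight_ptr_or_null(const std::optional<FloatTensor>& w) {
  return w ? w->data() : nullptr;
}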

@@ -431,16 +431,19 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     const int seq_len,
     const int last_dim,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps
     ) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
@@ -477,25 +480,25 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
       const float sin_tmp = sin_emb_vec[i];
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
-      src_vec[2 * i] = static_cast<T>(tmp1);
-      src_vec[2 * i + 1] = static_cast<T>(tmp2);
+      tmp_vec[2 * i] = tmp1;
+      tmp_vec[2 * i + 1] = tmp2;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
     }
     WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
     float row_variance =
         max(warp_m2 / last_dim, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
      }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
     Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
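
In the kernels above, one head is spread across a warp and each lane's partial sum of squares (thread_m2) is combined with WelfordWarpAllReduce before the variance is formed. A minimal sketch of the warp-wide all-reduce idea using the standard shuffle butterfly (the repo's helper may carry full Welford state rather than a bare sum; warp_allreduce_sum is an illustrative name):

#include <cuda_runtime.h>

// Butterfly all-reduce: after log2(32) shuffle steps every lane of the warp
// holds the sum of all 32 partial values, so each lane can compute
// rsqrt(sum / last_dim + eps) without a shared-memory round trip.
__device__ __forceinline__ float warp_allreduce_sum(float v) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  }
  return v;
}
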
@@ -1695,8 +1698,8 @@ void gqa_rotary_qk_norm_variable(
     const cudaStream_t &stream,
     bool use_neox_style = false,
     bool rope_3d = false,
-    const T *q_norm_weight = nullptr,
-    const T *k_norm_weight = nullptr,
+    const float *q_norm_weight = nullptr,
+    const float *k_norm_weight = nullptr,
     const float rms_norm_eps = 1e-6) {
   int64_t elem_nums =
       qkv_out_scales

@@ -80,8 +80,8 @@ void EncoderWriteCacheWithRopeKernel(
       stream,
       use_neox_style,
       rope_3d,
-      q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
-      k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
+      q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+      k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
       rms_norm_eps);
   } else {
     PD_THROW(