support qk norm (#3145)

This commit is contained in:
Yuan Xiaolan
2025-08-05 16:46:14 +08:00
committed by GitHub
parent 4a10e29804
commit 7ce00e597c
17 changed files with 791 additions and 201 deletions

View File

@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
};
if (qkv_out_scales) {
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("kv_signal_data")})
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
"speculate_decoder: bool",
"rms_norm_eps: float"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));

View File

@@ -18,6 +18,142 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const T* q_norm_weight,
const T* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadKVT cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t all_warp_num = gridDim.x * blockDim.x;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
// const int64_t offset = 2 * hidden_size;
const int half_head_size = head_size / 2;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
const uint32_t ori_idx =
start_token_idx * hidden_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
}
if (hi < (num_heads + kv_num_heads)) { // q k
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadT q_norm_vec, k_norm_vec;
if (hi < num_heads) { // q
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
}
} else { // k
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
} else {
// quant + write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
if (hi < num_heads + kv_num_heads) {
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
}
}
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,

View File

@@ -15,6 +15,70 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const T* q_norm_weight,
const T* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
@@ -441,7 +505,10 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out) {
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -464,6 +531,43 @@ void DecoderWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_decode_cache_rope_qk_norm(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d,
reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
} else {
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -641,6 +745,7 @@ void DecoderWriteCacheWithRoPEKernel(
"cache_int4_zp]");
}
}
}
template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
@@ -667,7 +772,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -694,7 +802,10 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -720,7 +831,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -746,4 +860,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -40,4 +40,6 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);

View File

@@ -358,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
@@ -405,6 +405,94 @@ __global__ void GQAVariableLengthRotaryKernel(
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQKNormKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d,
const T* q_norm_weight,
const T* k_norm_weight,
const float rms_norm_eps
) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t all_warp_num = gridDim.x * blockDim.x;
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head) * last_dim;
const int all_head_num = elem_cnt / last_dim;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
src_vec[2 * i] = static_cast<T>(tmp1);
src_vec[2 * i + 1] = static_cast<T>(tmp2);
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
}
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / last_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadT q_norm_vec, k_norm_vec;
if (hi < q_num_head) {
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
}
} else {
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
}
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryKernel(
const T *qkv,
@@ -1568,6 +1656,66 @@ void rotary_qk_variable(
}
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_norm_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int input_output_len,
const int dim_head,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false,
const T *q_norm_weight = nullptr,
const T *k_norm_weight = nullptr,
const float rms_norm_eps = 1e-6) {
int64_t elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
<<<grid_size, Blocks, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]

View File

@@ -46,7 +46,10 @@ void EncoderWriteCacheWithRopeKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out) {
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
auto token_num = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
@@ -56,6 +59,35 @@ void EncoderWriteCacheWithRopeKernel(
is_scale_channel_wise = true;
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
gqa_rotary_qk_norm_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
rms_norm_eps);
} else {
PD_THROW(
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
}
} else {
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
@@ -121,6 +153,7 @@ void EncoderWriteCacheWithRopeKernel(
}
}
}
const uint32_t block_size = meta_data.block_size;
if (cache_quant_type_str == "none") {
CascadeAppendWriteCacheKVQKV<T>(meta_data,

View File

@@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -559,3 +559,37 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
convert_int8(result, source);
}
}
constexpr int kWarpSize = 32;
template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
*m2 += b_m2;
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
WelfordCombine1(b_m2, m2);
}
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}

View File

@@ -78,6 +78,9 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,

View File

@@ -43,6 +43,11 @@
__VA_ARGS__ \
break; \
} \
case 48: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
__VA_ARGS__ \
break; \
} \
case 64: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
__VA_ARGS__ \

View File

@@ -262,6 +262,9 @@ class AppendAttentionBackend(AttentionBackend):
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,

View File

@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.utils import get_tensor
class Attention(nn.Layer):
@@ -49,6 +50,7 @@ class Attention(nn.Layer):
linear_smooth: paddle.Tensor = None,
use_neox_rotary_style: bool = False,
use_qk_norm: bool = False,
rms_norm_eps: float = 1e-6,
) -> None:
"""
Initializes `LMLayer` with the given parameters.
@@ -63,6 +65,8 @@ class Attention(nn.Layer):
prefix (str, optional): The name of current layer. Defaults to "".
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
use_qk_norm (bool, optional): Whether to apply rmsnorm on QA after rope. Defaults to False.
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
Raises:
ValueError: If the `v_head_dim` is less than 0.
@@ -102,6 +106,27 @@ class Attention(nn.Layer):
logger.info(
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
)
self.use_qk_norm = use_qk_norm
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"
self.k_norm_key = f"{self.prefix}.k_norm"
self.init_weight()
def init_weight(self):
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype=self._dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.k_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype=self._dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
@@ -109,6 +134,11 @@ class Attention(nn.Layer):
"""
if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict)
if self.use_qk_norm:
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
self.q_norm_weight.set_value(q_norm_weight_tensor)
self.k_norm_weight.set_value(k_norm_weight_tensor)
def forward(
self,

View File

@@ -60,6 +60,9 @@ def append_attention(
linear_shift: Optional[paddle.Tensor] = None,
linear_smooth: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
q_norm_weight: Optional[paddle.Tensor] = None,
k_norm_weight: Optional[paddle.Tensor] = None,
rms_norm_eps: float = 1e-6,
compute_type: str = "bf16",
cache_quant_type: str = "none",
use_neox_rotary_style: bool = False,
@@ -114,6 +117,9 @@ def append_attention(
linear_shift,
linear_smooth,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
compute_type,
cache_quant_type,
use_neox_rotary_style,

View File

@@ -17,6 +17,7 @@ import unittest
import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_rms_norm
paddle.seed(10)
@@ -157,6 +158,8 @@ def naive_attention_impl(
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
use_cachekv_int8="None",
q_norm_weight=None,
k_norm_weight=None,
):
batch = query.shape[0]
heads = query.shape[1]
@@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
return q, k, v, qkv
def apply_qk_norm(head_dim, dtype, q, k):
q_norm_weight = np.random.random([head_dim]) / 10
k_norm_weight = np.random.random([head_dim]) / 10
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
print("q:", q.shape)
print("k:", k.shape)
bs, q_num_head, seq_len, dim_head = q.shape
_, kv_num_head, _, _ = k.shape
q = q.reshape([-1, head_dim])
k = k.reshape([-1, head_dim])
print("q:", q)
q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
print("q after norm:", q)
k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
q = q.reshape([-1, q_num_head, seq_len, dim_head])
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
def split_query_by_phase(
query,
seq_lens_encoder,
@@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.use_qk_norm = True
self.init_tensor()
def init_tensor(self):
@@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
)
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
if self.use_qk_norm:
q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
else:
q_norm_weight = None
k_norm_weight = None
out_ = naive_attention_impl(
q,
k,
@@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
None, # linear_shift
None, # linear_smooth
None, # kv_signal_data
q_norm_weight, # q_norm_weight
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
"none", # cache_quant_type
self.use_neox_rotary_style,
@@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.use_qk_norm = False
self.init_tensor()