support qk norm (#3145)

Author: Yuan Xiaolan
Date: 2025-08-05 16:46:14 +08:00
Committed by: GitHub
Parent: 4a10e29804
Commit: 7ce00e597c
17 changed files with 791 additions and 201 deletions

View File

@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const bool rope_3d,
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         main_stream,
         &qkv_out,
         const_cast<paddle::Tensor*>(&key_cache),
-        const_cast<paddle::Tensor*>(&value_cache));
+        const_cast<paddle::Tensor*>(&value_cache),
+        q_norm_weight,
+        k_norm_weight,
+        rms_norm_eps);
   };
   if (qkv_out_scales) {
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           exec_stream,
           &qkv_out,
           const_cast<paddle::Tensor*>(&key_cache),
-          const_cast<paddle::Tensor*>(&value_cache));
+          const_cast<paddle::Tensor*>(&value_cache),
+          q_norm_weight,
+          k_norm_weight,
+          rms_norm_eps);
     } else {
       DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
           meta_data,
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           exec_stream,
           &qkv_out,
           const_cast<paddle::Tensor*>(&key_cache),
-          const_cast<paddle::Tensor*>(&value_cache));
+          const_cast<paddle::Tensor*>(&value_cache),
+          q_norm_weight,
+          k_norm_weight,
+          rms_norm_eps);
     }
   }
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
       out_linear_shifts,
       out_linear_smooths,
       kv_signal_data,
+      q_norm_weight,
+      k_norm_weight,
+      rms_norm_eps,
      cache_quant_type_str,
      use_neox_rotary_style,
      rope_3d,
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
     const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
     const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
     const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("cache_v_zp"),
              paddle::Optional("out_linear_shifts"),
              paddle::Optional("out_linear_smooths"),
-             paddle::Optional("kv_signal_data")})
+             paddle::Optional("kv_signal_data"),
+             paddle::Optional("q_norm_weight"),
+             paddle::Optional("k_norm_weight")})
    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
    .SetInplaceMap({{"key_cache", "key_cache_out"},
                    {"value_cache", "value_cache_out"}})
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
             "encoder_max_partition_size: int",
             "speculate_max_draft_token_num: int",
             "causal: bool",
-            "speculate_decoder: bool"})
+            "speculate_decoder: bool",
+            "rms_norm_eps: float"})
    .SetKernelFn(PD_KERNEL(AppendAttention))
    .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
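In plain terms, the two new optional inputs apply an RMSNorm (scaled by q_norm_weight / k_norm_weight, with epsilon rms_norm_eps) to each query and key head after rotary embedding. A minimal NumPy sketch of the per-head math the fused kernels implement below (names are illustrative, not from the op):

import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # x: [head_dim], one attention head after RoPE; weight: [head_dim]
    variance = np.mean(x.astype(np.float32) ** 2)
    return (x / np.sqrt(variance + eps) * weight).astype(x.dtype)

def rope_interleaved(x, cos, sin):
    # Interleaved (non-neox) RoPE on pairs (2i, 2i+1), matching the kernels.
    out = np.empty_like(x)
    out[0::2] = x[0::2] * cos - x[1::2] * sin
    out[1::2] = x[1::2] * cos + x[0::2] * sin
    return out

head_dim = 128
q_head = np.random.randn(head_dim).astype(np.float32)
cos = np.random.randn(head_dim // 2)
sin = np.random.randn(head_dim // 2)
q_norm_weight = np.ones(head_dim, dtype=np.float32)
out = rms_norm(rope_interleaved(q_head, cos, sin), q_norm_weight)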

View File

@@ -18,6 +18,142 @@
#include "mma_tensor_op.cuh" #include "mma_tensor_op.cuh"
#include "utils.cuh" #include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const T* q_norm_weight,
const T* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadKVT cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t all_warp_num = gridDim.x * blockDim.x;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
// const int64_t offset = 2 * hidden_size;
const int half_head_size = head_size / 2;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
const uint32_t ori_idx =
start_token_idx * hidden_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
}
if (hi < (num_heads + kv_num_heads)) { // q k
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadT q_norm_vec, k_norm_vec;
if (hi < num_heads) { // q
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
}
} else { // k
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
} else {
// quant + write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
if (hi < num_heads + kv_num_heads) {
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
}
}
}
}
template <typename T, int VecSize = 1> template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel( __global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
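The fused decode kernel above distinguishes Q, K, and V rows purely by the head index hi within the packed QKV layout. A small Python sketch of that routing decision (illustrative only):

def route_head(hi, num_heads, kv_num_heads):
    # Packed QKV layout: [q heads | k heads | v heads].
    if hi < num_heads:
        return "q"  # RoPE + RMSNorm(q_norm_weight), written to qkv_out
    elif hi < num_heads + kv_num_heads:
        return "k"  # RoPE + RMSNorm(k_norm_weight), written to key_cache
    else:
        return "v"  # copied unchanged into value_cache

assert route_head(0, 8, 2) == "q"
assert route_head(9, 8, 2) == "k"
assert route_head(10, 8, 2) == "v"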

View File

@@ -15,6 +15,70 @@
#include "decoder_write_cache_with_rope_kernel.h" #include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh" #include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const T* q_norm_weight,
const T* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
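A quick sanity check of this launch configuration, assuming head_dim 128 and 32-lane warps as asserted above (plain Python arithmetic, not library code):

kWarpSize = 32
HEAD_DIM = 128
PackSize = HEAD_DIM // kWarpSize                  # 4 elements per vectorized lane load
blocksize = 128                                   # threads per block
block_dim = (blocksize // kWarpSize, kWarpSize)   # (4, 32): x selects a head row, y is the lane

bsz, num_heads, kv_num_heads = 2, 8, 2
elem_nums = bsz * (num_heads + 2 * kv_num_heads) * HEAD_DIM
pack_num = elem_nums // PackSize                  # total vectorized loads to issue
assert PackSize * kWarpSize == HEAD_DIM           # 32 lanes x 4 elements cover one head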
 template <typename T, typename QKV_TYPE>
 void append_decode_cache_rope(const QKV_TYPE* qkv,
                               T* key_cache,
@@ -441,7 +505,10 @@ void DecoderWriteCacheWithRoPEKernel(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out) {
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps) {
   typedef cascade_attn_type_traits<T> traits_;
   typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
   typedef typename traits_::type DataType_;
@@ -464,107 +531,77 @@ void DecoderWriteCacheWithRoPEKernel(
             ? rotary_embs.get().data<float>() + max_seq_len * dim_head
             : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
   }
-  if (cache_quant_type_str == "none") {
-    append_decode_cache_rope(
-        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-        reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
-        reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
-        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
-        block_tables.data<int>(),
-        batch_id_per_token.data<int>(),
-        cu_seqlens_q.data<int>(),
-        seq_lens.data<int>(),
-        seq_lens_encoder.data<int>(),
-        cos_emb,
-        sin_emb,
-        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-        qkv_biases ? reinterpret_cast<DataType_*>(
-                         const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-        max_seq_len,
-        max_blocks_per_seq,
-        num_heads,
-        kv_num_heads,
-        dim_head,
-        block_size,
-        bsz,
-        stream,
-        use_neox_rotary_style,
-        rope_3d);
-  } else if (cache_quant_type_str == "cache_int8") {
-    bool is_scale_channel_wise = false;
-    if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
-      is_scale_channel_wise = true;
-    }
-    if (is_scale_channel_wise) {
-      append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<DataType_*>(
-                           const_cast<T*>(qkv_biases.get().data<T>()))
-                     : nullptr,
-          cache_k_scale ? reinterpret_cast<DataType_*>(
-                              const_cast<T*>(cache_k_scale.get().data<T>()))
-                        : nullptr,
-          cache_v_scale ? reinterpret_cast<DataType_*>(
-                              const_cast<T*>(cache_v_scale.get().data<T>()))
-                        : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
-    } else {
-      append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<DataType_*>(
-                           const_cast<T*>(qkv_biases.get().data<T>()))
-                     : nullptr,
-          cache_k_scale ? reinterpret_cast<DataType_*>(
-                              const_cast<T*>(cache_k_scale.get().data<T>()))
-                        : nullptr,
-          cache_v_scale ? reinterpret_cast<DataType_*>(
-                              const_cast<T*>(cache_v_scale.get().data<T>()))
-                        : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
-    }
-  } else if (cache_quant_type_str == "cache_fp8") {
-    append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
+  if (q_norm_weight && k_norm_weight) {
+    if (cache_quant_type_str == "none") {
+      append_decode_cache_rope_qk_norm(
+          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+          reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
+          reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
+          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+          qkv_biases ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(qkv_biases.get().data<T>()))
+                     : nullptr,
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          stream,
+          use_neox_rotary_style,
+          rope_3d,
+          reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
+          reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
+          rms_norm_eps);
+    } else {
+      PD_THROW(
+          "append_decode_cache_rope_qk_norm not support cachekv quant yet");
+    }
+  } else {
+    if (cache_quant_type_str == "none") {
+      append_decode_cache_rope(
+          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+          reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
+          reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
+          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+          qkv_biases ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(qkv_biases.get().data<T>()))
+                     : nullptr,
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          stream,
+          use_neox_rotary_style,
+          rope_3d);
+    } else if (cache_quant_type_str == "cache_int8") {
+      bool is_scale_channel_wise = false;
+      if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
+        is_scale_channel_wise = true;
+      }
+      if (is_scale_channel_wise) {
+        append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
           reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
           key_cache_out->data<uint8_t>(),
           value_cache_out->data<uint8_t>(),
@@ -596,49 +633,117 @@ void DecoderWriteCacheWithRoPEKernel(
           stream,
           use_neox_rotary_style,
           rope_3d);
-  } else if (cache_quant_type_str == "cache_int4_zp") {
-    append_decode_cache_int4_rope(
-        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-        key_cache_out->data<uint8_t>(),
-        value_cache_out->data<uint8_t>(),
-        reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
-        block_tables.data<int>(),
-        batch_id_per_token.data<int>(),
-        cu_seqlens_q.data<int>(),
-        seq_lens.data<int>(),
-        seq_lens_encoder.data<int>(),
-        cos_emb,
-        sin_emb,
-        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-        qkv_biases ? reinterpret_cast<DataType_*>(
-                         const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-        cache_k_scale ? reinterpret_cast<DataType_*>(
-                            const_cast<T*>(cache_k_scale.get().data<T>()))
-                      : nullptr,
-        cache_v_scale ? reinterpret_cast<DataType_*>(
-                            const_cast<T*>(cache_v_scale.get().data<T>()))
-                      : nullptr,
-        cache_k_zp ? reinterpret_cast<DataType_*>(
-                         const_cast<T*>(cache_k_zp.get().data<T>()))
-                   : nullptr,
-        cache_v_zp ? reinterpret_cast<DataType_*>(
-                         const_cast<T*>(cache_v_zp.get().data<T>()))
-                   : nullptr,
-        max_seq_len,
-        max_blocks_per_seq,
-        num_heads,
-        kv_num_heads,
-        dim_head,
-        block_size,
-        bsz,
-        stream,
-        use_neox_rotary_style,
-        rope_3d);
-  } else {
-    PD_THROW(
-        "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
-        "cache_int4_zp]");
-  }
+      } else {
+        append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
+            reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+            key_cache_out->data<uint8_t>(),
+            value_cache_out->data<uint8_t>(),
+            reinterpret_cast<DataType_*>(qkv_out->data<T>()),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? reinterpret_cast<DataType_*>(
+                             const_cast<T*>(qkv_biases.get().data<T>()))
+                       : nullptr,
+            cache_k_scale ? reinterpret_cast<DataType_*>(
+                                const_cast<T*>(cache_k_scale.get().data<T>()))
+                          : nullptr,
+            cache_v_scale ? reinterpret_cast<DataType_*>(
+                                const_cast<T*>(cache_v_scale.get().data<T>()))
+                          : nullptr,
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            stream,
+            use_neox_rotary_style,
+            rope_3d);
+      }
+    } else if (cache_quant_type_str == "cache_fp8") {
+      append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
+          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+          key_cache_out->data<uint8_t>(),
+          value_cache_out->data<uint8_t>(),
+          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+          qkv_biases ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(qkv_biases.get().data<T>()))
+                     : nullptr,
+          cache_k_scale ? reinterpret_cast<DataType_*>(
+                              const_cast<T*>(cache_k_scale.get().data<T>()))
+                        : nullptr,
+          cache_v_scale ? reinterpret_cast<DataType_*>(
+                              const_cast<T*>(cache_v_scale.get().data<T>()))
+                        : nullptr,
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          stream,
+          use_neox_rotary_style,
+          rope_3d);
+    } else if (cache_quant_type_str == "cache_int4_zp") {
+      append_decode_cache_int4_rope(
+          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+          key_cache_out->data<uint8_t>(),
+          value_cache_out->data<uint8_t>(),
+          reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+          qkv_biases ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(qkv_biases.get().data<T>()))
+                     : nullptr,
+          cache_k_scale ? reinterpret_cast<DataType_*>(
+                              const_cast<T*>(cache_k_scale.get().data<T>()))
+                        : nullptr,
+          cache_v_scale ? reinterpret_cast<DataType_*>(
+                              const_cast<T*>(cache_v_scale.get().data<T>()))
+                        : nullptr,
+          cache_k_zp ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(cache_k_zp.get().data<T>()))
+                     : nullptr,
+          cache_v_zp ? reinterpret_cast<DataType_*>(
+                           const_cast<T*>(cache_v_zp.get().data<T>()))
+                     : nullptr,
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          stream,
+          use_neox_rotary_style,
+          rope_3d);
+    } else {
+      PD_THROW(
+          "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
+          "cache_int4_zp]");
+    }
   }
 }
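The resulting decode-path dispatch is easy to misread in diff form; this sketch restates the branching above as a Python function (hypothetical names, mirroring the C++ control flow):

def decode_write_cache_path(has_qk_norm: bool, cache_quant_type: str) -> str:
    if has_qk_norm:
        if cache_quant_type == "none":
            return "append_decode_cache_rope_qk_norm"
        raise ValueError("qk norm does not support cache KV quantization yet")
    if cache_quant_type == "none":
        return "append_decode_cache_rope"
    if cache_quant_type in ("cache_int8", "cache_fp8"):
        return "append_decode_cache_int8_rope"
    if cache_quant_type == "cache_int4_zp":
        return "append_decode_cache_int4_rope"
    raise ValueError("cache_quant_type should be one of "
                     "[none, cache_int8, cache_fp8, cache_int4_zp]")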
@@ -667,7 +772,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);
 template void
 DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -694,7 +802,10 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);
 template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
     const AppendAttnMetaData& meta_data,
@@ -720,7 +831,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);
 template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     const AppendAttnMetaData& meta_data,
@@ -746,4 +860,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);

View File

@@ -40,4 +40,6 @@ void DecoderWriteCacheWithRoPEKernel(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);

View File

@@ -358,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel(
        linear_index < elem_cnt;
        linear_index += step) {
     const int token_idx = linear_index / offset;
-    const int ori_bi = batch_id_per_token[token_idx];;
+    const int ori_bi = batch_id_per_token[token_idx];
     if (seq_lens[ori_bi] == 0) continue;
     const int bias = linear_index % offset;
     const int hi = bias / last_dim;
@@ -405,6 +405,94 @@ __global__ void GQAVariableLengthRotaryKernel(
   }
 }
+
+template <typename T, int VecSize = 1>
+__global__ void GQAVariableLengthRotaryQKNormKernel(
+    const T *qkv,
+    const float *cos_emb,
+    const float *sin_emb,
+    const int *batch_id_per_token,
+    const int *cu_seqlens_q,
+    const int *seq_lens,
+    const int *seq_lens_decoder,
+    T *qkv_out,
+    const int64_t elem_cnt,
+    const int q_num_head,
+    const int kv_num_head,
+    const int seq_len,
+    const int last_dim,
+    const bool rope_3d,
+    const T *q_norm_weight,
+    const T *k_norm_weight,
+    const float rms_norm_eps) {
+  using LoadT = AlignedVector<T, VecSize>;
+  constexpr int HalfVecSize = VecSize / 2;
+  using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  LoadT src_vec;
+  LoadEmbT cos_emb_vec;
+  LoadEmbT sin_emb_vec;
+  int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int64_t all_warp_num = gridDim.x * blockDim.x;
+  const int half_lastdim = last_dim / 2;
+  const int offset = (q_num_head + kv_num_head) * last_dim;
+  const int all_head_num = elem_cnt / last_dim;
+  for (int global_hi = global_warp_idx; global_hi < all_head_num;
+       global_hi += all_warp_num) {
+    int64_t linear_index = global_hi * last_dim + threadIdx.y * VecSize;
+    const int token_idx = linear_index / offset;
+    const int ori_bi = batch_id_per_token[token_idx];
+    if (seq_lens[ori_bi] == 0) continue;
+    const int bias = linear_index % offset;
+    const int hi = bias / last_dim;
+    const int h_bias = bias % last_dim;
+    const int ori_seq_id =
+        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
+    const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+    const int64_t base_idx =
+        token_idx * (q_num_head + 2 * kv_num_head) * last_dim +
+        hi * last_dim + h_bias;
+    Load<T, VecSize>(&qkv[base_idx], &src_vec);
+    int64_t new_emb_idx =
+        rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
+    float thread_m2 = 0.0f;
+    float warp_m2 = 0.0f;
+#pragma unroll
+    for (int i = 0; i < HalfVecSize; i++) {
+      const float input_left = static_cast<float>(src_vec[2 * i]);
+      const float input_right = static_cast<float>(src_vec[2 * i + 1]);
+      const float cos_tmp = cos_emb_vec[i];
+      const float sin_tmp = sin_emb_vec[i];
+      float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
+      src_vec[2 * i] = static_cast<T>(tmp1);
+      src_vec[2 * i + 1] = static_cast<T>(tmp2);
+      thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
+    }
+    WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
+    float row_variance = max(warp_m2 / last_dim, 0.0f);
+    float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
+    LoadT q_norm_vec, k_norm_vec;
+    if (hi < q_num_head) {
+      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+#pragma unroll
+      for (int i = 0; i < VecSize; i++) {
+        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) *
+                                    row_inv_var *
+                                    static_cast<float>(q_norm_vec[i]));
+      }
+    } else {
+      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+#pragma unroll
+      for (int i = 0; i < VecSize; i++) {
+        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) *
+                                    row_inv_var *
+                                    static_cast<float>(k_norm_vec[i]));
+      }
+    }
+    Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
+  }
+}
+
 template <typename T, int VecSize = 1>
 __global__ void GQAVariableLengthRotaryKernel(
     const T *qkv,
@@ -1568,6 +1656,66 @@ void rotary_qk_variable(
   }
 }
+
+template <typename T, typename QKV_TYPE>
+void gqa_rotary_qk_norm_variable(
+    T *qkv_out,                   // [token_num, 3, num_head, dim_head]
+    const QKV_TYPE *qkv_input,    // qkv
+    const float *qkv_out_scales,  // [3, num_head, dim_head]
+    const T *qkv_bias,
+    const float *rotary_emb,  // [2, 1, 1, seq_len, dim_head / 2]
+    const int *batch_id_per_token,
+    const int *cu_seqlens_q,
+    const int *seq_lens,
+    const int *seq_lens_decoder,
+    const int token_num,
+    const int num_heads,
+    const int kv_num_heads,
+    const int seq_len,
+    const int input_output_len,
+    const int dim_head,
+    const cudaStream_t &stream,
+    bool use_neox_style = false,
+    bool rope_3d = false,
+    const T *q_norm_weight = nullptr,
+    const T *k_norm_weight = nullptr,
+    const float rms_norm_eps = 1e-6) {
+  int64_t elem_nums =
+      qkv_out_scales
+          ? token_num * (num_heads + 2 * kv_num_heads) * dim_head
+          : token_num * (num_heads + kv_num_heads) * dim_head;  // q and k
+  assert(dim_head == 128 && "dim_head must be 128");
+  constexpr int HEAD_DIM = 128;
+  constexpr int PackSize = HEAD_DIM / kWarpSize;
+  const int pack_num = elem_nums / PackSize;
+  const int blocksize = 128;
+  int grid_size = 1;
+  GetNumBlocks<128>(pack_num, &grid_size);
+  dim3 Blocks(blocksize / kWarpSize, kWarpSize, 1);
+  const float *cos_emb = rotary_emb;
+  const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
+  GQAVariableLengthRotaryQKNormKernel<T, PackSize>
+      <<<grid_size, Blocks, 0, stream>>>(
+          reinterpret_cast<const T *>(qkv_input),
+          cos_emb,
+          sin_emb,
+          batch_id_per_token,
+          cu_seqlens_q,
+          seq_lens,
+          seq_lens_decoder,
+          qkv_out,
+          elem_nums,
+          num_heads,
+          kv_num_heads,
+          seq_len,
+          dim_head,
+          rope_3d,
+          q_norm_weight,
+          k_norm_weight,
+          rms_norm_eps);
+}
+
 template <typename T, typename QKV_TYPE>
 void gqa_rotary_qk_variable(
     T *qkv_out,  // [token_num, 3, num_head, dim_head]
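To make the addressing in GQAVariableLengthRotaryQKNormKernel easier to follow, here is the same linear-index decomposition in plain Python (illustrative values only):

def decompose(linear_index, q_num_head, kv_num_head, last_dim):
    # Each index walks a flattened [token, q_heads + kv_heads, last_dim] view.
    offset = (q_num_head + kv_num_head) * last_dim
    token_idx = linear_index // offset
    bias = linear_index % offset
    hi = bias // last_dim     # head index; hi < q_num_head -> Q, else K
    h_bias = bias % last_dim  # position within the head
    return token_idx, hi, h_bias

token_idx, hi, h_bias = decompose(5 * 10 * 128 + 3 * 128 + 7, 8, 2, 128)
assert (token_idx, hi, h_bias) == (5, 3, 7)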

View File

@@ -46,7 +46,10 @@ void EncoderWriteCacheWithRopeKernel(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out) {
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps) {
   auto token_num = meta_data.token_nums;
   auto num_heads = meta_data.q_num_heads;
   auto kv_num_heads = meta_data.kv_num_heads;
@@ -56,28 +59,9 @@ void EncoderWriteCacheWithRopeKernel(
     is_scale_channel_wise = true;
   }
-  if (num_heads == kv_num_heads) {
-    rotary_qk_variable(
-        qkv_out->data<T>(),
-        qkv.data<QKV_TYPE>(),
-        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-        qkv_biases ? qkv_biases.get().data<T>() : nullptr,
-        rotary_embs.get().data<float>(),
-        batch_id_per_token.data<int>(),
-        cu_seqlens_q.data<int>(),
-        seq_lens_encoder.data<int>(),
-        seq_lens_decoder.data<int>(),
-        token_num,
-        num_heads,
-        max_seq_len,
-        rotary_embs.get().dims()[2],
-        head_dim,
-        stream,
-        use_neox_style,
-        rope_3d);
-  } else {
-    if (!is_scale_channel_wise) {
-      gqa_rotary_qk_variable(
+  if (q_norm_weight && k_norm_weight) {
+    if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
+      gqa_rotary_qk_norm_variable(
           qkv_out->data<T>(),
           qkv.data<QKV_TYPE>(),
           qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -95,31 +79,80 @@ void EncoderWriteCacheWithRopeKernel(
           head_dim,
           stream,
           use_neox_style,
-          rope_3d);
+          rope_3d,
+          q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
+          k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
+          rms_norm_eps);
     } else {
-      gqa_rotary_qk_quant_variable(
-          qkv_out->data<T>(),
-          qkv.data<QKV_TYPE>(),
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? qkv_biases.get().data<T>() : nullptr,
-          cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
-          cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
-          rotary_embs.get().data<float>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens_encoder.data<int>(),
-          seq_lens_decoder.data<int>(),
-          token_num,
-          num_heads,
-          kv_num_heads,
-          max_seq_len,
-          rotary_embs.get().dims()[2],
-          head_dim,
-          stream,
-          use_neox_style,
-          rope_3d);
+      PD_THROW(
+          "gqa_rotary_qk_norm_variable only support gqa mode. channel wise "
+          "scale and neox style are not supported");
     }
+  } else {
+    if (num_heads == kv_num_heads) {
+      rotary_qk_variable(
+          qkv_out->data<T>(),
+          qkv.data<QKV_TYPE>(),
+          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+          qkv_biases ? qkv_biases.get().data<T>() : nullptr,
+          rotary_embs.get().data<float>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens_encoder.data<int>(),
+          seq_lens_decoder.data<int>(),
+          token_num,
+          num_heads,
+          max_seq_len,
+          rotary_embs.get().dims()[2],
+          head_dim,
+          stream,
+          use_neox_style,
+          rope_3d);
+    } else {
+      if (!is_scale_channel_wise) {
+        gqa_rotary_qk_variable(
+            qkv_out->data<T>(),
+            qkv.data<QKV_TYPE>(),
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? qkv_biases.get().data<T>() : nullptr,
+            rotary_embs.get().data<float>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens_encoder.data<int>(),
+            seq_lens_decoder.data<int>(),
+            token_num,
+            num_heads,
+            kv_num_heads,
+            max_seq_len,
+            rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
+            head_dim,
+            stream,
+            use_neox_style,
+            rope_3d);
+      } else {
+        gqa_rotary_qk_quant_variable(
+            qkv_out->data<T>(),
+            qkv.data<QKV_TYPE>(),
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? qkv_biases.get().data<T>() : nullptr,
+            cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
+            cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
+            rotary_embs.get().data<float>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens_encoder.data<int>(),
+            seq_lens_decoder.data<int>(),
+            token_num,
+            num_heads,
+            kv_num_heads,
+            max_seq_len,
+            rotary_embs.get().dims()[2],
+            head_dim,
+            stream,
+            use_neox_style,
+            rope_3d);
+      }
+    }
   }
   const uint32_t block_size = meta_data.block_size;
   if (cache_quant_type_str == "none") {
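On the prefill/encoder side the qk-norm fast path is gated more narrowly than on the decode side; a hedged Python summary of the condition (not library code):

def encoder_rope_path(has_qk_norm, num_heads, kv_num_heads,
                      is_scale_channel_wise, use_neox_style):
    gqa = num_heads != kv_num_heads
    if has_qk_norm:
        if gqa and not is_scale_channel_wise and not use_neox_style:
            return "gqa_rotary_qk_norm_variable"
        raise ValueError("qk norm requires GQA without channel-wise "
                         "cache scales or neox-style rope")
    if not gqa:
        return "rotary_qk_variable"
    return ("gqa_rotary_qk_variable" if not is_scale_channel_wise
            else "gqa_rotary_qk_quant_variable")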

View File

@@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);

View File

@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps);

View File

@@ -559,3 +559,37 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
     convert_int8(result, source);
   }
 }
+
+constexpr int kWarpSize = 32;
+
+template <typename T>
+inline __device__ void WelfordCombine1(T b_m2, T* m2) {
+  *m2 += b_m2;
+}
+
+template <typename T, int thread_group_width = kWarpSize>
+__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
+  *m2 = thread_m2;
+  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+    T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
+    WelfordCombine1(b_m2, m2);
+  }
+}
+
+template <typename T, int thread_group_width = kWarpSize>
+__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
+  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
+}
+
+template <typename T>
+__inline__ __device__ T Rsqrt(T x);
+
+template <>
+__inline__ __device__ float Rsqrt<float>(float x) {
+  return rsqrt(x);
+}
+
+template <>
+__inline__ __device__ double Rsqrt<double>(double x) {
+  return rsqrt(x);
+}

View File

@@ -78,6 +78,9 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor> &out_linear_shifts,
     const paddle::optional<paddle::Tensor> &out_linear_smooths,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
+    const paddle::optional<paddle::Tensor> &q_norm_weight,
+    const paddle::optional<paddle::Tensor> &k_norm_weight,
+    const float rms_norm_eps,
     const std::string &compute_dtype, const std::string &cache_quant_type_str,
     const bool use_neox_rotary_style, const bool rope_3d,
     const int max_input_length, const float quant_max_bound,

View File

@@ -43,6 +43,11 @@
       __VA_ARGS__                                 \
       break;                                      \
     }                                             \
+    case 48: {                                    \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
+      __VA_ARGS__                                 \
+      break;                                      \
+    }                                             \
     case 64: {                                    \
       constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
       __VA_ARGS__                                 \

View File

@@ -262,6 +262,9 @@ class AppendAttentionBackend(AttentionBackend):
                 layer.linear_shift,
                 layer.linear_smooth,
                 metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
                 metadata._fuse_kernel_compute_dtype,
                 getattr(layer, "cache_quant_type_str", "none"),
                 layer.use_neox_rotary_style,

View File

@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.layers.utils import get_tensor
 
 class Attention(nn.Layer):
@@ -49,6 +50,7 @@ class Attention(nn.Layer):
         linear_smooth: paddle.Tensor = None,
         use_neox_rotary_style: bool = False,
         use_qk_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         """
         Initializes `LMLayer` with the given parameters.
@@ -63,6 +65,8 @@ class Attention(nn.Layer):
             prefix (str, optional): The name of current layer. Defaults to "".
             linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
             linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
+            use_qk_norm (bool, optional): Whether to apply RMSNorm to Q and K after RoPE. Defaults to False.
+            rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
 
         Raises:
             ValueError: If the `v_head_dim` is less than 0.
@@ -102,6 +106,27 @@ class Attention(nn.Layer):
             logger.info(
                 f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
             )
+        self.use_qk_norm = use_qk_norm
+        self.rms_norm_eps = rms_norm_eps
+        if self.use_qk_norm:
+            self.q_norm_key = f"{self.prefix}.q_norm"
+            self.k_norm_key = f"{self.prefix}.k_norm"
+            self.init_weight()
+
+    def init_weight(self):
+        self.q_norm_weight = self.create_parameter(
+            shape=[self.qk_head_dim],
+            dtype=self._dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        self.k_norm_weight = self.create_parameter(
+            shape=[self.qk_head_dim],
+            dtype=self._dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -109,6 +134,11 @@ class Attention(nn.Layer):
""" """
if self.kvcache_quant_method is not None: if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict) self.kvcache_quant_method.create_weights(self, state_dict)
if self.use_qk_norm:
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
self.q_norm_weight.set_value(q_norm_weight_tensor)
self.k_norm_weight.set_value(k_norm_weight_tensor)
def forward( def forward(
self, self,
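Putting the layer pieces together, usage looks roughly like the sketch below. Constructor arguments other than the qk-norm ones are elided and hypothetical (fd_config and the prefix are placeholders); consult the class definition for the real signature:

import numpy as np

# Hypothetical construction; required args besides the qk-norm ones are placeholders.
attn = Attention(
    fd_config,
    prefix="model.layers.0.self_attn",
    use_qk_norm=True,   # creates q_norm_weight / k_norm_weight parameters
    rms_norm_eps=1e-6,
)

# Checkpoint keys expected by load_state_dict when use_qk_norm is on:
state_dict = {
    "model.layers.0.self_attn.q_norm.weight": np.ones([attn.qk_head_dim], "float32"),
    "model.layers.0.self_attn.k_norm.weight": np.ones([attn.qk_head_dim], "float32"),
}
attn.load_state_dict(state_dict)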

View File

@@ -60,6 +60,9 @@ def append_attention(
     linear_shift: Optional[paddle.Tensor] = None,
     linear_smooth: Optional[paddle.Tensor] = None,
     kv_signal_data: Optional[paddle.Tensor] = None,
+    q_norm_weight: Optional[paddle.Tensor] = None,
+    k_norm_weight: Optional[paddle.Tensor] = None,
+    rms_norm_eps: float = 1e-6,
     compute_type: str = "bf16",
     cache_quant_type: str = "none",
     use_neox_rotary_style: bool = False,
@@ -114,6 +117,9 @@ def append_attention(
         linear_shift,
         linear_smooth,
         kv_signal_data,
+        q_norm_weight,
+        k_norm_weight,
+        rms_norm_eps,
         compute_type,
         cache_quant_type,
         use_neox_rotary_style,

View File

@@ -17,6 +17,7 @@ import unittest
 import numpy as np
 import paddle
+from paddle.incubate.nn.functional import fused_rms_norm
 
 paddle.seed(10)
@@ -157,6 +158,8 @@ def naive_attention_impl(
     cache_k_dequant_scales=None,
     cache_v_dequant_scales=None,
     use_cachekv_int8="None",
+    q_norm_weight=None,
+    k_norm_weight=None,
 ):
     batch = query.shape[0]
     heads = query.shape[1]
@@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
     return q, k, v, qkv
 
+
+def apply_qk_norm(head_dim, dtype, q, k):
+    q_norm_weight = np.random.random([head_dim]) / 10
+    k_norm_weight = np.random.random([head_dim]) / 10
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    bs, q_num_head, seq_len, dim_head = q.shape
+    _, kv_num_head, _, _ = k.shape
+    q = q.reshape([-1, head_dim])
+    k = k.reshape([-1, head_dim])
+    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-6)[0]
+    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-6)[0]
+    q = q.reshape([-1, q_num_head, seq_len, dim_head])
+    k = k.reshape([-1, kv_num_head, seq_len, dim_head])
+    return q, k, q_norm_weight_tensor, k_norm_weight_tensor
+
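For intuition, the fused_rms_norm call above matches this plain NumPy reference (same epsilon; float32 for clarity):

import numpy as np

def rms_norm_ref(x, weight, eps=1e-6):
    # x: [rows, head_dim]; normalize each row by its root mean square.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

rows = np.random.randn(4, 128).astype(np.float32)
w = (np.random.random(128) / 10).astype(np.float32)
out = rms_norm_ref(rows, w)
assert out.shape == (4, 128)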
 def split_query_by_phase(
     query,
     seq_lens_encoder,
@@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.softmax_scale = self.dim_head**-0.5
         self.rope_theta = 10000
         self.dtype = "float16"
+        self.use_qk_norm = True
         self.init_tensor()
 
     def init_tensor(self):
@@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         )
 
         q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
+        if self.use_qk_norm:
+            q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
+        else:
+            q_norm_weight = None
+            k_norm_weight = None
         out_ = naive_attention_impl(
             q,
             k,
@@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             None,  # linear_shift
             None,  # linear_smooth
             None,  # kv_signal_data
+            q_norm_weight,  # q_norm_weight
+            k_norm_weight,  # k_norm_weight
+            1e-6,  # rms_norm_eps
             "fp16",
             "none",  # cache_quant_type
             self.use_neox_rotary_style,
@@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
         self.softmax_scale = self.dim_head**-0.5
         self.rope_theta = 10000
         self.dtype = "float16"
+        self.use_qk_norm = False
         self.init_tensor()