qk norm for speculate decode C16 (#3637)

This commit is contained in:
Yuan Xiaolan
2025-09-03 14:53:56 +08:00
committed by GitHub
parent d22d3de256
commit fa58a9fa8f
6 changed files with 470 additions and 160 deletions

View File

@@ -277,7 +277,10 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -300,7 +303,10 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
if (qkv_out_scales) {

View File

@@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
@@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}

View File

@@ -18,6 +18,166 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const float*
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int output_inner_dim,
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadInT src_vec;
LoadFloat scale_vec;
LoadT bias_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec;
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
const int half_head_size = head_size / 2;
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
const int write_q_idx =
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
}
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (qkv_out_scales) {
input_left *= scale_vec[2 * i];
input_right *= scale_vec[2 * i + 1];
}
if (qkv_biases) {
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
}
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
}
}
if (hi < (num_heads + gqa_group_size)) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
} else {
// write k/v
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
kv_head_idx * block_size * head_size +
block_offset * head_size + h_bias);
// write
if (hi < num_heads + gqa_group_size) {
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
}
}
}
}
template <int VecSize = 4, int HeadDim = 128>
__global__ void append_clear_cache_int8_block(
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,

View File

@@ -15,6 +15,77 @@
#include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
// rope + write
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -317,7 +388,10 @@ void SpeculateWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out) {
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -342,142 +416,180 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
}
@@ -504,7 +616,10 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -530,7 +645,10 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -555,7 +673,10 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void
@@ -582,4 +703,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -39,4 +39,7 @@ void SpeculateWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

View File

@@ -5,12 +5,16 @@ import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.incubate.nn.functional import fused_rms_norm
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
)
np.random.seed(0)
paddle.seed(0)
class TestTreeMask(unittest.TestCase):
def setUp(self):
@@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase):
self.head_dim = 128
self.num_q_head = 20
self.num_kv_head = 4
self.use_qknorm = True
self.dtype = "bfloat16"
self.rope_3d = False
@@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase):
cu_seqlens_k[i + 1] = cum_seq_len_k
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
def ref_attention(self, q, k, v, mask):
def ref_attention(self, q, k, v, mask, use_qknorm=False):
if use_qknorm:
q = q.reshape([-1, self.head_dim])
q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim])
q = q.transpose([0, 2, 1, 3])
if len(k) > 1:
k = paddle.concat(k, axis=1)
else:
k = k[0]
if use_qknorm:
k = k.reshape([-1, self.head_dim])
k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim])
k = k.transpose([0, 2, 1, 3])
if len(v) > 1:
v = paddle.concat(v, axis=1)
@@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase):
.reshape([-1, self.num_q_head, self.head_dim])
)
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False):
if prefill:
seq_lens_enc = [
q_len,
@@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase):
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
q_norm_weight = np.ones([self.head_dim])
k_norm_weight = np.ones([self.head_dim])
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
paddle.device.synchronize()
(
encoder_batch_ids,
@@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase):
max_len_kv,
rotary_embs,
attn_mask,
None,
None,
None, # qkv_bias
None, # qkv_out_scales
cache_k_scale,
cache_v_scale,
cache_k_out_scale,
cache_v_out_scale,
None,
None,
None,
None,
None,
None,
None,
None,
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
None, # mask_offset
None, # kv_signal_data
self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight
self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight
1e-6,
"bf16",
"none",
@@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase):
paddle.device.synchronize()
e_time = time.time()
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
return out[0].reshape([token_num, self.num_q_head, self.head_dim])
return out.reshape([token_num, self.num_q_head, self.head_dim])
def test_naive_speculative_decoding(self):
prefill_len = 8192
@@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase):
total_len = prefill_len + dec_len_q
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
self.run_append_c16_attention(prefill_len, 0, True)
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm)
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm)
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm)
np.testing.assert_allclose(
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)