diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 75e336b76..d5ece4f53 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -277,7 +277,10 @@ void AppendAttentionKernel( exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } else { SpeculateWriteCacheWithRoPEKernel( meta_data, @@ -300,7 +303,10 @@ void AppendAttentionKernel( exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } } else { if (qkv_out_scales) { diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index c8273cd3c..75f9ebd8d 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( float row_variance = max(warp_m2 / head_size, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); - if (hi < num_heads) { // q Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); #pragma unroll @@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( } } else { // k Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); + #pragma unroll for (int i = 0; i < VecSize; i++) { out_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); } diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 57612c458..9c9816d3b 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -18,6 +18,166 @@ #include 
"mma_tensor_op.cuh" #include "utils.cuh" +// Speculative-decoding write path with fused RoPE and q/k RMS norm. +// For every (token, head) row of qkv this applies the optional dequant scale and +// bias, rotates q and k rows with cos_emb/sin_emb, RMS-normalizes each rotated row +// and multiplies it by q_norm_weight or k_norm_weight, then stores q into q_out and +// k/v into the paged key_cache/value_cache (v rows are only scaled and biased). +// Work split: the lanes along threadIdx.x share one head row, VecSize elements per +// lane; threadIdx.y selects the row inside the block and rows are strided over the +// grid. The squared-sum reduction is divided by head_size, so this assumes +// blockDim.x * VecSize == head_size -- TODO confirm at every call site. +// Tokens of batches with seq_lens_decoder == 0 and tokens whose write position is 0 +// are skipped. +template +__global__ void append_speculate_cache_T_rope_qk_norm_kernel( + const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ q_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens_decoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* + qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size] + const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int output_inner_dim, + const int head_size, + const int block_size, + const int elem_cnt, + const int gqa_group_size, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps) { + using LoadT = AlignedVector; + using LoadFloat = AlignedVector; + using LoadInT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadInT src_vec; + LoadFloat scale_vec; + LoadT bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + LoadFloat tmp_vec; + LoadFloat q_norm_vec; + LoadFloat k_norm_vec; + + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + int64_t all_head_dim = elem_cnt / head_size; + + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int half_head_size = head_size / 2; + for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) { + int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; + const int token_id = 
linear_index / hidden_size; + const int ori_bi = batch_id_per_token[token_id]; + if (seq_lens_decoder[ori_bi] == 0) continue; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + const int write_seq_id = + seq_lens_decoder[ori_bi] + token_id - start_token_idx; + if (write_seq_id == 0) continue; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + if (block_idx < 0) { + printf( + "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " + "%d %d %d %d\n", + block_idx, + write_seq_id, + ori_bi, + seq_lens_decoder[ori_bi], + token_id, + cu_seqlens_q[ori_bi]); + } + const int block_offset = write_seq_id % block_size; + + const int write_q_idx = + token_id * output_inner_dim * head_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&qkv[linear_index], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + if (qkv_out_scales) { + Load(&qkv_out_scales[bias_idx], &scale_vec); + } + if (hi < num_heads + gqa_group_size) { + // q k rope + const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + if (qkv_out_scales) { + input_left *= scale_vec[2 * i]; + input_right *= scale_vec[2 * i + 1]; + } + if (qkv_biases) { + input_left = input_left + static_cast(bias_vec[2 * i]); + input_right = input_right + static_cast(bias_vec[2 * i + 1]); + } + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = 
sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + if (hi < (num_heads + gqa_group_size)) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / head_size, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < num_heads) { + Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + if (hi < num_heads) { + // write q + Store(bias_vec, &q_out[write_q_idx]); + } else { + // write k/v + // block_idx < 0 was already reported by the printf above; skip the store so a + // negative index never reaches the cache pointers. + if (block_idx < 0) continue; + const int kv_head_idx = (hi - num_heads) % gqa_group_size; + // 64-bit offset: block_idx * gqa_group_size * block_size * head_size can exceed + // the int range when the cache holds many blocks. + const int64_t tgt_idx = (static_cast<int64_t>(block_idx) * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + + block_offset * head_size + h_bias); + // write + if (hi < num_heads + gqa_group_size) { + Store(bias_vec, &key_cache[tgt_idx]); + } else { + Store(bias_vec, &value_cache[tgt_idx]); + } + } + } +} + template __global__ void append_clear_cache_int8_block( uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index fb6a24fef..8e8195c30 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -15,6 +15,77 @@ #include 
"speculate_write_cache_with_rope_kernel.h" #include "utils.cuh" +// Host launcher for append_speculate_cache_T_rope_qk_norm_kernel: writes the +// speculative-decode tokens of qkv into key_cache/value_cache (element type T) and +// into qkv_out, with RoPE and q/k RMS norm applied by the kernel. +// Only the non-neox rotary layout with dim_head == 128 is handled; any other +// configuration throws via PD_THROW before a kernel is launched. +// seq_lens is forwarded as the kernel's seq_lens_decoder argument; +// seq_lens_encoder and bsz are accepted but not forwarded to the kernel. +// The launch uses blocks of (kWarpSize, 128 / kWarpSize) threads on `stream`. +template +void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const int token_num, + const cudaStream_t& stream, + const bool use_neox_style, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps) { + int output_inner_dim = num_heads + 2 * kv_num_heads; + const uint32_t elem_nums = + use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2 + : token_num * (num_heads + 2 * kv_num_heads) * dim_head; + constexpr int HEAD_DIM = 128; + // PackSize below gives each of the kWarpSize lanes HEAD_DIM / kWarpSize elements, + // i.e. one lane group covers exactly HEAD_DIM values per head row. Reject other + // head sizes instead of launching a kernel that would mix rows in its reduction + // and index the norm weights incorrectly. + if (dim_head != HEAD_DIM) { + PD_THROW( + "append_speculate_cache_rope_qk_norm only supports dim_head == 128"); + } + + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + + if (use_neox_style) { + PD_THROW( + "append_speculate_cache_rope_qk_norm not support neox rope yet"); + } else { + dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); + append_speculate_cache_T_rope_qk_norm_kernel + <<>>(qkv, + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + output_inner_dim, + dim_head, + block_size, + elem_nums, + kv_num_heads, + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } +} + // rope + write template void append_speculate_cache_rope(const QKV_TYPE* qkv, @@ -317,7 +388,10 @@ void SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* 
key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { typedef cascade_attn_type_traits traits_; typedef cascade_attn_type_traits qkt_nv_type_; typedef typename traits_::type DataType_; @@ -342,142 +416,180 @@ void SpeculateWriteCacheWithRoPEKernel( ? rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; } - if (cache_quant_type_str == "none") { - append_speculate_cache_rope( - reinterpret_cast(qkv_ptr), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? 
reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_fp8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int4_zp") { - append_speculate_cache_int4_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(const_cast(qkv_out->data())), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - cache_v_zp ? 
reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); + + if (q_norm_weight && k_norm_weight) { + if (cache_quant_type_str == "none") { + append_speculate_cache_rope_qk_norm( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + reinterpret_cast(q_norm_weight.get().data()), + reinterpret_cast(k_norm_weight.get().data()), + rms_norm_eps); + } else { + PD_THROW( + "append_decode_cache_rope_qk_norm not support cachekv quant yet"); + } } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, " - "cache_int4_zp]"); + if (cache_quant_type_str == "none") { + append_speculate_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? 
reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style); + } else if (cache_quant_type_str == "cache_int8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style); + } else if (cache_quant_type_str == "cache_fp8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? 
reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style); + } else if (cache_quant_type_str == "cache_int4_zp") { + append_speculate_cache_int4_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(const_cast(qkv_out->data())), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + cache_v_zp ? 
reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, " + "cache_int4_zp]"); + } } } @@ -504,7 +616,10 @@ template void SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( @@ -530,7 +645,10 @@ SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -555,7 +673,10 @@ template void SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void @@ -582,4 +703,7 @@ SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h index 40ab34e05..a44a9db15 100644 --- 
a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h @@ -39,4 +39,7 @@ void SpeculateWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 10e55a4b1..59c4b1d98 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -5,12 +5,16 @@ import unittest import numpy as np import paddle import paddle.nn.functional as F +from paddle.incubate.nn.functional import fused_rms_norm from fastdeploy.model_executor.layers.attention.ops import ( append_attention, get_block_shape_and_split_kv_block, ) +np.random.seed(0) +paddle.seed(0) + class TestTreeMask(unittest.TestCase): def setUp(self): @@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase): self.head_dim = 128 self.num_q_head = 20 self.num_kv_head = 4 + self.use_qknorm = True self.dtype = "bfloat16" self.rope_3d = False @@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase): cu_seqlens_k[i + 1] = cum_seq_len_k return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k - def ref_attention(self, q, k, v, mask): + def ref_attention(self, q, k, v, mask, use_qknorm=False): + if use_qknorm: + q = q.reshape([-1, self.head_dim]) + q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype) + q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim]) q = q.transpose([0, 2, 1, 3]) if len(k) > 1: k = paddle.concat(k, axis=1) else: k = k[0] + if use_qknorm: + k = k.reshape([-1, self.head_dim]) + k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype) + k = 
k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim]) k = k.transpose([0, 2, 1, 3]) if len(v) > 1: v = paddle.concat(v, axis=1) @@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase): .reshape([-1, self.num_q_head, self.head_dim]) ) - def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None): + def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False): if prefill: seq_lens_enc = [ q_len, @@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase): decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + q_norm_weight = np.ones([self.head_dim]) + k_norm_weight = np.ones([self.head_dim]) + self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32") + self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32") paddle.device.synchronize() ( encoder_batch_ids, @@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase): max_len_kv, rotary_embs, attn_mask, - None, - None, + None, # qkv_bias + None, # qkv_out_scales cache_k_scale, cache_v_scale, cache_k_out_scale, cache_v_out_scale, - None, - None, - None, - None, - None, - None, - None, - None, + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight + self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight 1e-6, "bf16", "none", @@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase): paddle.device.synchronize() e_time = time.time() print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}") - return out[0].reshape([token_num, self.num_q_head, self.head_dim]) + return out.reshape([token_num, self.num_q_head, self.head_dim]) def test_naive_speculative_decoding(self): 
prefill_len = 8192 @@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase): total_len = prefill_len + dec_len_q mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) - self.run_append_c16_attention(prefill_len, 0, True) - dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False) + self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm) + dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm) - ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask) + ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm) np.testing.assert_allclose( ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 )