diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 2ba7555e7..1b7f3d531 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -73,6 +73,9 @@ std::vector AppendAttentionKernel( const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, const bool rope_3d, @@ -223,7 +226,10 @@ std::vector AppendAttentionKernel( main_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); }; if (qkv_out_scales) { @@ -339,7 +345,10 @@ std::vector AppendAttentionKernel( exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } else { DecoderWriteCacheWithRoPEKernel( meta_data, @@ -363,7 +372,10 @@ std::vector AppendAttentionKernel( exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } } @@ -430,6 +442,9 @@ std::vector AppendAttention( const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -500,6 +515,9 @@ std::vector AppendAttention( out_linear_shifts, out_linear_smooths, kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, cache_quant_type_str, use_neox_rotary_style, rope_3d, @@ -577,6 +595,9 @@ std::vector> AppendAttentionInferShape( const paddle::optional>& out_linear_shifts_shape, const paddle::optional>& out_linear_smooths_shape, const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -637,6 +658,9 @@ std::vector AppendAttentionInferDtype( const paddle::optional& out_linear_shifts_dtype, const paddle::optional& out_linear_smooths_dtype, const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention) paddle::Optional("cache_v_zp"), paddle::Optional("out_linear_shifts"), paddle::Optional("out_linear_smooths"), - paddle::Optional("kv_signal_data")}) + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"}) .SetInplaceMap({{"key_cache", "key_cache_out"}, {"value_cache", "value_cache_out"}}) @@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention) "encoder_max_partition_size: int", "speculate_max_draft_token_num: int", "causal: bool", - "speculate_decoder: bool"}) + "speculate_decoder: bool", + "rms_norm_eps: float"}) .SetKernelFn(PD_KERNEL(AppendAttention)) 
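For reference, a condensed sketch of how the new optional inputs and the trailing `rms_norm_eps` attribute fit into this registration; unchanged entries are elided, and the optional tensor type is assumed to be `paddle::optional<paddle::Tensor>` as in other Paddle custom ops (the angle-bracket template arguments do not show up in the hunks above).

```cpp
// Condensed sketch (not the full registration): the two norm weights are appended
// to the end of the input list as optionals, and rms_norm_eps is appended as the
// last attribute, so the positions of all pre-existing inputs/attrs are unchanged.
PD_BUILD_STATIC_OP(append_attention)
    .Inputs({/* ... existing inputs ... */
             paddle::Optional("kv_signal_data"),
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight")})
    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
    .SetInplaceMap({{"key_cache", "key_cache_out"},
                    {"value_cache", "value_cache_out"}})
    .Attrs({/* ... existing attrs ... */
            "speculate_decoder: bool",
            "rms_norm_eps: float"})
    .SetKernelFn(PD_KERNEL(AppendAttention))
    .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
```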
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 67066efc2..31c7bc061 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -18,6 +18,142 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" +template +__global__ void append_decode_cache_T_rope_qk_norm_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, + // head_size] + T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int kv_num_heads, + const bool rope_3d, + const T* q_norm_weight, + const T* k_norm_weight, + const float rms_norm_eps) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadBiasT out_vec; + LoadKVT cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t all_warp_num = gridDim.x * blockDim.x; + int64_t all_head_dim = elem_cnt / head_size; + + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; + // const int64_t offset = 2 * hidden_size; + const int half_head_size = head_size / 2; + for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) { + int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize; + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx = + start_token_idx * hidden_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&quant_qkv[ori_idx], &src_vec); + if (hi < num_heads + kv_num_heads) { + // q k rope + const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + uint32_t new_emb_idx = rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + if (hi < num_heads + kv_num_heads) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec[2 * i] = + static_cast(tmp1); + out_vec[2 * i + 1] = + static_cast(tmp2); + } else { + out_vec[2 * i] = src_vec[2 * i]; + out_vec[2 * i + 1] = src_vec[2 * i + 1]; + } + } + if (hi < (num_heads + kv_num_heads)) { // q k + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / head_size, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadT q_norm_vec, k_norm_vec; + if (hi < num_heads) { // q + Load(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * static_cast(q_norm_vec[i])); + } + } else { // k + Load(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec); + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * static_cast(k_norm_vec[i])); + } + } + } + if (hi < num_heads) { + // write q + Store(out_vec, &qkv_out[ori_idx]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + if (hi < num_heads + kv_num_heads) { + Store(out_vec, &key_cache[tgt_idx]); + } else { + Store(out_vec, &value_cache[tgt_idx]); + } + } + + } +} + template __global__ void append_decode_cache_T_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index fe72d120a..77cdfa300 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -15,6 +15,70 @@ #include "decoder_write_cache_with_rope_kernel.h" #include "utils.cuh" +template +void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const cudaStream_t& stream, + const bool use_neox_style, + const bool rope_3d, + const T* q_norm_weight, + const T* k_norm_weight, + const float rms_norm_eps) { + const uint32_t elem_nums = + use_neox_style ? 
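In math terms, for each query/key head of size d (= head_size), append_decode_cache_T_rope_qk_norm_kernel applies the rotary embedding and then an RMSNorm over the head dimension, with w the q_norm_weight or k_norm_weight vector of length d and eps = rms_norm_eps:

$$\tilde x_{2i} = x_{2i}\cos\theta_i - x_{2i+1}\sin\theta_i, \qquad \tilde x_{2i+1} = x_{2i+1}\cos\theta_i + x_{2i}\sin\theta_i,$$

$$y_j = \tilde x_j \cdot \frac{w_j}{\sqrt{\tfrac{1}{d}\sum_{k=0}^{d-1}\tilde x_k^{2} + \varepsilon}}, \qquad j = 0,\dots,d-1.$$

Value heads skip both the rotation and the normalization and are written to the KV cache unchanged; the encoder-side GQAVariableLengthRotaryQKNormKernel introduced later in this diff applies the same per-head transform.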
bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2 + : bsz * (num_heads + 2 * kv_num_heads) * dim_head; + assert(dim_head == 128 && "dim_head must be 128"); + constexpr int HEAD_DIM = 128; + + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1); + append_decode_cache_T_rope_qk_norm_kernel + <<>>(reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); +} + template void append_decode_cache_rope(const QKV_TYPE* qkv, T* key_cache, @@ -441,7 +505,10 @@ void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { typedef cascade_attn_type_traits traits_; typedef cascade_attn_type_traits qkt_nv_type_; typedef typename traits_::type DataType_; @@ -464,107 +531,77 @@ void DecoderWriteCacheWithRoPEKernel( ? rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; } - if (cache_quant_type_str == "none") { - append_decode_cache_rope( - reinterpret_cast(qkv_ptr), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); - } else if (cache_quant_type_str == "cache_int8") { - bool is_scale_channel_wise = false; - if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) { - is_scale_channel_wise = true; - } - if (is_scale_channel_wise) { - append_decode_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? 
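With the asserted dim_head == 128 and kWarpSize == 32, the launch parameters of append_decode_cache_rope_qk_norm work out to

$$\text{VecSize} = \frac{\text{HEAD\_DIM}}{\text{kWarpSize}} = \frac{128}{32} = 4, \qquad \text{block\_dim} = \Big(\frac{\text{blocksize}}{\text{kWarpSize}},\ \text{kWarpSize},\ 1\Big) = (4,\ 32,\ 1),$$

so each thread loads a 4-element vector, the 32 threads that share a threadIdx.x cover one 128-element head per iteration of the grid-stride loop, and a 128-thread block advances 4 heads per iteration; each head's partial sums of squares are then combined with WelfordWarpAllReduce before the Rsqrt scaling.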
reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); + + if (q_norm_weight && k_norm_weight) { + if (cache_quant_type_str == "none") { + append_decode_cache_rope_qk_norm( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d, + reinterpret_cast(q_norm_weight.get().data()), + reinterpret_cast(k_norm_weight.get().data()), + rms_norm_eps); } else { - append_decode_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); + PD_THROW( + "append_decode_cache_rope_qk_norm not support cachekv quant yet"); } - } else if (cache_quant_type_str == "cache_fp8") { - append_decode_cache_int8_rope( + } else { + if (cache_quant_type_str == "none") { + append_decode_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int8") { + bool is_scale_channel_wise = false; + if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) { + is_scale_channel_wise = true; + } + if (is_scale_channel_wise) { + append_decode_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), @@ -596,49 +633,117 @@ void DecoderWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); - } else if (cache_quant_type_str == "cache_int4_zp") { - append_decode_cache_int4_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(const_cast(qkv_out->data())), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? 
reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) + } else { + append_decode_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } + } else if (cache_quant_type_str == "cache_fp8") { + append_decode_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) : nullptr, - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); - } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 " - "cache_int4_zp]"); + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int4_zp") { + append_decode_cache_int4_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(const_cast(qkv_out->data())), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + cache_v_zp ? 
reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 " + "cache_int4_zp]"); + } } } @@ -667,7 +772,10 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( @@ -694,7 +802,10 @@ DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -720,7 +831,10 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -746,4 +860,7 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index b3fe75b2c..459f29448 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -40,4 +40,6 @@ void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 09f0f50a0..5215b933a 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -358,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx];; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; @@ -405,6 +405,94 @@ __global__ void GQAVariableLengthRotaryKernel( } } + +template +__global__ void GQAVariableLengthRotaryQKNormKernel( + const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + 
const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d, + const T* q_norm_weight, + const T* k_norm_weight, + const float rms_norm_eps +) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t all_warp_num = gridDim.x * blockDim.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + kv_num_head) * last_dim; + const int all_head_num = elem_cnt / last_dim; + for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) { + int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize; + const int token_idx = linear_index / offset; + const int ori_bi = batch_id_per_token[token_idx]; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / last_dim; + const int h_bias = bias % last_dim; + + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + const int64_t base_idx = + token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + h_bias; + Load(&qkv[base_idx], &src_vec); + + int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + src_vec[2 * i] = static_cast(tmp1); + src_vec[2 * i + 1] = static_cast(tmp2); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + } + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / last_dim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadT q_norm_vec, k_norm_vec; + if (hi < q_num_head) { + Load(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = static_cast(static_cast(src_vec[i]) * row_inv_var * static_cast(q_norm_vec[i])); + } + } else { + Load(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec); + for (int i = 0; i < VecSize; i++) { + src_vec[i] = static_cast(static_cast(src_vec[i]) * row_inv_var * static_cast(k_norm_vec[i])); + } + } + Store(src_vec, &qkv_out[base_idx]); + } +} + template __global__ void GQAVariableLengthRotaryKernel( const T *qkv, @@ -1568,6 +1656,66 @@ void rotary_qk_variable( } } +template +void gqa_rotary_qk_norm_variable( + T *qkv_out, // [token_num, 3, num_head, dim_head] + const QKV_TYPE *qkv_input, // qkv + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_bias, + const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const int token_num, + const int num_heads, + const int kv_num_heads, + const int seq_len, + const int input_output_len, + const int dim_head, + const cudaStream_t &stream, + bool use_neox_style = false, + bool rope_3d = false, + const T 
*q_norm_weight = nullptr, + const T *k_norm_weight = nullptr, + const float rms_norm_eps = 1e-6) { + int64_t elem_nums = + qkv_out_scales + ? token_num * (num_heads + 2 * kv_num_heads) * dim_head + : token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v + assert(dim_head == 128 && "dim_head must be 128"); + constexpr int HEAD_DIM = 128; + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1); + + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + + GQAVariableLengthRotaryQKNormKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); +} + template void gqa_rotary_qk_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 5eb238216..1e5d79878 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -46,7 +46,10 @@ void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { auto token_num = meta_data.token_nums; auto num_heads = meta_data.q_num_heads; auto kv_num_heads = meta_data.kv_num_heads; @@ -56,28 +59,9 @@ void EncoderWriteCacheWithRopeKernel( is_scale_channel_wise = true; } - if (num_heads == kv_num_heads) { - rotary_qk_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - max_seq_len, - rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); - } else { - if (!is_scale_channel_wise) { - gqa_rotary_qk_variable( + if (q_norm_weight && k_norm_weight) { + if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) { + gqa_rotary_qk_norm_variable( qkv_out->data(), qkv.data(), qkv_out_scales ? qkv_out_scales.get().data() : nullptr, @@ -95,31 +79,80 @@ void EncoderWriteCacheWithRopeKernel( head_dim, stream, use_neox_style, - rope_3d); + rope_3d, + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, + rms_norm_eps); } else { - gqa_rotary_qk_quant_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - cache_k_scale ? cache_k_scale.get().data() : nullptr, - cache_v_scale ? 
cache_v_scale.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); + PD_THROW( + "gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported"); } + } else { + if (num_heads == kv_num_heads) { + rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); + } else { + if (!is_scale_channel_wise) { + gqa_rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); + } else { + gqa_rotary_qk_quant_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + cache_k_scale ? cache_k_scale.get().data() : nullptr, + cache_v_scale ? cache_v_scale.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); + } + } } const uint32_t block_size = meta_data.block_size; if (cache_quant_type_str == "none") { diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu index 8d786ce58..915039908 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu @@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu index a34da8258..3f3539b8a 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - 
paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu index 42f07ee8b..a559ec77f 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu index ef3d3832e..3318a3647 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 05f500126..9efbab433 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -559,3 +559,37 @@ template inline __device__ static void convert_c8(T * re convert_int8(result, source); } } + +constexpr int kWarpSize = 32; + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 000820688..b0e6e332b 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -78,6 +78,9 @@ std::vector AppendAttention( const paddle::optional &out_linear_shifts, const paddle::optional &out_linear_smooths, const paddle::optional &kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, const std::string &compute_dtype, const std::string &cache_quant_type_str, const bool use_neox_rotary_style, const bool rope_3d, const int max_input_length, const float quant_max_bound, diff --git 
a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu index 46cc60bef..d677b360c 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu @@ -43,6 +43,11 @@ __VA_ARGS__ \ break; \ } \ + case 48: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 48; \ + __VA_ARGS__ \ + break; \ + } \ case 64: { \ constexpr size_t NUM_EXPERTS_PER_RANK = 64; \ __VA_ARGS__ \ diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index cffc4adf7..405c6c118 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -262,6 +262,9 @@ class AppendAttentionBackend(AttentionBackend): layer.linear_shift, layer.linear_smooth, metadata.kv_signal_data_list[layer.layer_id], + getattr(layer, "q_norm_weight", None), + getattr(layer, "k_norm_weight", None), + getattr(layer, "rms_norm_eps", 1e-6), metadata._fuse_kernel_compute_dtype, getattr(layer, "cache_quant_type_str", "none"), layer.use_neox_rotary_style, diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index f1bc434bc..98527571a 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.layers.utils import get_tensor class Attention(nn.Layer): @@ -49,6 +50,7 @@ class Attention(nn.Layer): linear_smooth: paddle.Tensor = None, use_neox_rotary_style: bool = False, use_qk_norm: bool = False, + rms_norm_eps: float = 1e-6, ) -> None: """ Initializes `LMLayer` with the given parameters. @@ -63,6 +65,8 @@ class Attention(nn.Layer): prefix (str, optional): The name of current layer. Defaults to "". linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None. linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None. + use_qk_norm (bool, optional): Whether to apply rmsnorm on QA after rope. Defaults to False. + rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6. Raises: ValueError: If the `v_head_dim` is less than 0. 
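The ep_moe_prefill_func.cu change above only adds a new case to the expert-count dispatch macro. A minimal sketch of the pattern that macro expands to (launch_moe_kernel and the default branch are placeholders, not the real names):

```cpp
// Sketch only: the dispatch binds the runtime expert count to a constexpr so the
// kernel can be instantiated per value; supporting 48 experts per rank therefore
// needs its own case label.
switch (num_experts_per_rank) {
  case 48: {
    constexpr size_t NUM_EXPERTS_PER_RANK = 48;
    launch_moe_kernel<NUM_EXPERTS_PER_RANK>(args);  // stands in for __VA_ARGS__
    break;
  }
  case 64: {
    constexpr size_t NUM_EXPERTS_PER_RANK = 64;
    launch_moe_kernel<NUM_EXPERTS_PER_RANK>(args);
    break;
  }
  default:
    // placeholder: the real macro's fallback is not shown in this hunk
    break;
}
```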
@@ -102,6 +106,27 @@ class Attention(nn.Layer): logger.info( f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode" ) + self.use_qk_norm = use_qk_norm + self.rms_norm_eps = rms_norm_eps + if self.use_qk_norm: + self.q_norm_key = f"{self.prefix}.q_norm" + self.k_norm_key = f"{self.prefix}.k_norm" + self.init_weight() + + def init_weight(self): + self.q_norm_weight = self.create_parameter( + shape=[self.qk_head_dim], + dtype=self._dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + self.k_norm_weight = self.create_parameter( + shape=[self.qk_head_dim], + dtype=self._dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): """ @@ -109,6 +134,11 @@ class Attention(nn.Layer): """ if self.kvcache_quant_method is not None: self.kvcache_quant_method.create_weights(self, state_dict) + if self.use_qk_norm: + q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight"))) + k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight"))) + self.q_norm_weight.set_value(q_norm_weight_tensor) + self.k_norm_weight.set_value(k_norm_weight_tensor) def forward( self, diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index de538ad69..38322ab9e 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -60,6 +60,9 @@ def append_attention( linear_shift: Optional[paddle.Tensor] = None, linear_smooth: Optional[paddle.Tensor] = None, kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, compute_type: str = "bf16", cache_quant_type: str = "none", use_neox_rotary_style: bool = False, @@ -114,6 +117,9 @@ def append_attention( linear_shift, linear_smooth, kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, compute_type, cache_quant_type, use_neox_rotary_style, diff --git a/test/layers/test_append_attention.py b/test/layers/test_append_attention.py index 764191a7b..1c2ac0bbf 100644 --- a/test/layers/test_append_attention.py +++ b/test/layers/test_append_attention.py @@ -17,6 +17,7 @@ import unittest import numpy as np import paddle +from paddle.incubate.nn.functional import fused_rms_norm paddle.seed(10) @@ -157,6 +158,8 @@ def naive_attention_impl( cache_k_dequant_scales=None, cache_v_dequant_scales=None, use_cachekv_int8="None", + q_norm_weight=None, + k_norm_weight=None, ): batch = query.shape[0] heads = query.shape[1] @@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head return q, k, v, qkv +def apply_qk_norm(head_dim, dtype, q, k): + q_norm_weight = np.random.random([head_dim]) / 10 + k_norm_weight = np.random.random([head_dim]) / 10 + q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype) + k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype) + print("q:", q.shape) + print("k:", k.shape) + bs, q_num_head, seq_len, dim_head = q.shape + _, kv_num_head, _, _ = k.shape + + q = q.reshape([-1, head_dim]) + k = k.reshape([-1, head_dim]) + print("q:", q) + q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0] + print("q after norm:", q) + k = fused_rms_norm(k, 
k_norm_weight_tensor, None, 1e-5)[0] + q = q.reshape([-1, q_num_head, seq_len, dim_head]) + k = k.reshape([-1, kv_num_head, seq_len, dim_head]) + return q, k, q_norm_weight_tensor, k_norm_weight_tensor + + def split_query_by_phase( query, seq_lens_encoder, @@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.softmax_scale = self.dim_head**-0.5 self.rope_theta = 10000 self.dtype = "float16" + self.use_qk_norm = True self.init_tensor() def init_tensor(self): @@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): ) q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True) + if self.use_qk_norm: + q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k) + else: + q_norm_weight = None + k_norm_weight = None out_ = naive_attention_impl( q, k, @@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): None, # linear_shift None, # linear_smooth None, # kv_signal_data + q_norm_weight, # q_norm_weight + k_norm_weight, # k_norm_weight + 1e-6, "fp16", "none", # cache_quant_type self.use_neox_rotary_style, @@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope): self.softmax_scale = self.dim_head**-0.5 self.rope_theta = 10000 self.dtype = "float16" + self.use_qk_norm = False self.init_tensor()
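Finally, a minimal usage sketch of the warp-reduction helpers this PR adds to custom_ops/gpu_ops/append_attn/utils.cuh (kWarpSize, WelfordWarpAllReduce, Rsqrt). The kernel below is illustrative only, not part of the diff; it assumes utils.cuh is included and that the kernel is launched with 32 threads per block, one row per block.

```cuda
// Illustrative sketch: RMS-normalize one `dim`-wide row per 32-thread block,
// using the helpers added to utils.cuh in this PR.
#include "utils.cuh"  // kWarpSize, WelfordWarpAllReduce, Rsqrt

template <typename T>
__global__ void rms_norm_row_kernel(const T* __restrict__ in,
                                    const T* __restrict__ weight,
                                    T* __restrict__ out,
                                    const int dim,      // e.g. 128 (head_size)
                                    const float eps) {  // e.g. rms_norm_eps
  const int row = blockIdx.x;
  const int lane = threadIdx.x;  // 0..31
  float thread_m2 = 0.0f;
  // Each lane accumulates the squared values of its strided slice of the row.
  for (int i = lane; i < dim; i += kWarpSize) {
    const float v = static_cast<float>(in[row * dim + i]);
    thread_m2 += v * v;
  }
  float warp_m2 = 0.0f;
  WelfordWarpAllReduce<float, kWarpSize>(thread_m2, &warp_m2);  // sum over all 32 lanes
  const float inv_rms = Rsqrt(warp_m2 / dim + eps);
  // Scale by 1/RMS and the per-channel norm weight, mirroring the qk-norm kernels.
  for (int i = lane; i < dim; i += kWarpSize) {
    const float v = static_cast<float>(in[row * dim + i]);
    out[row * dim + i] = static_cast<T>(v * inv_rms * static_cast<float>(weight[i]));
  }
}
```

A call such as `rms_norm_row_kernel<<<num_rows, 32, 0, stream>>>(x, w, y, 128, 1e-6f)` would then normalize num_rows independent 128-wide rows, one per block.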