support qk norm (#3145)

This commit is contained in:
Yuan Xiaolan
2025-08-05 16:46:14 +08:00
committed by GitHub
parent 4a10e29804
commit 7ce00e597c
17 changed files with 791 additions and 201 deletions

View File

@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
};
if (qkv_out_scales) {
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("kv_signal_data")})
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
"speculate_decoder: bool",
"rms_norm_eps: float"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));