mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support qk norm (#3145)
This commit is contained in:
@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
main_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
};
|
||||
|
||||
if (qkv_out_scales) {
|
||||
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("kv_signal_data")})
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool"})
|
||||
"speculate_decoder: bool",
|
||||
"rms_norm_eps: float"})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
||||
|
||||
Reference in New Issue
Block a user