fix custom op order rms_norm_eps (#3348)

Author: Ryan
Date: 2025-08-13 10:12:49 +08:00
Committed by: GitHub
Parent: 8224b21525
Commit: ed6bff215a

@@ -739,12 +739,13 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("out_linear_shifts"),
              paddle::Optional("out_linear_smooths"),
              paddle::Optional("kv_signal_data"),
-             paddle::Optional("q_norm_weight"),
-             paddle::Optional("k_norm_weight")})
+             paddle::Optional("q_norm_weight"),
+             paddle::Optional("k_norm_weight")})
     .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
-    .Attrs({"compute_type: std::string",
+    .Attrs({"rms_norm_eps: float",
+            "compute_type: std::string",
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "rope_3d: bool",
@@ -759,7 +760,7 @@ PD_BUILD_STATIC_OP(append_attention)
             "speculate_max_draft_token_num: int",
             "causal: bool",
             "speculate_decoder: bool",
-            "rms_norm_eps: float"})
+            })
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
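
For context, attributes declared in .Attrs(...) of a Paddle custom-op registration are forwarded to the kernel function positionally, after the tensor inputs, so the declared order has to match the order of the attribute parameters in the kernel signature; moving "rms_norm_eps: float" to the front of the list restores that alignment. Below is a minimal sketch of the pattern, with hypothetical names (demo_op, DemoKernel) that are not part of this commit, and with shape/dtype inference omitted for brevity.

// Minimal sketch of the Paddle custom-op attribute-order rule (hypothetical names).
#include "paddle/extension.h"

// Attribute arguments follow the tensor inputs, in the same order
// as they are declared in .Attrs() below.
std::vector<paddle::Tensor> DemoKernel(const paddle::Tensor& x,
                                       float rms_norm_eps,                 // 1st attr
                                       const std::string& compute_type) {  // 2nd attr
  // Kernel body elided; just pass the input through for illustration.
  return {x};
}

PD_BUILD_OP(demo_op)
    .Inputs({"x"})
    .Outputs({"out"})
    // Must match DemoKernel's attribute parameters: rms_norm_eps first,
    // compute_type second, the same constraint the diff above restores.
    .Attrs({"rms_norm_eps: float",
            "compute_type: std::string"})
    .SetKernelFn(PD_KERNEL(DemoKernel));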