[SOT][Cudagraph] Remove BreakGraph of #3302 && update CustomOp (#3694)

* rm inplace info && to(gpu)

* update append_attention

* unpin paddle version

* add full_cuda_graph=False

* add blank line

---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
This commit is contained in:
Ryan
2025-10-17 10:57:55 +08:00
committed by GitHub
parent a37c9416ac
commit 49cea8fb1c
5 changed files with 12 additions and 11 deletions

View File

@@ -593,7 +593,7 @@ std::vector<paddle::Tensor> AppendAttention(
return {paddle::Tensor{}};
}
void AppendAttentionWithOutput(
std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
@@ -756,6 +756,8 @@ void AppendAttentionWithOutput(
break;
}
}
return {fmha_out};
}
@@ -1112,10 +1114,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Outputs({"fmha_out_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",

View File

@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> AppendAttention(
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
void AppendAttentionWithOutput(
std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,

View File

@@ -262,15 +262,15 @@ class AppendAttentionBackend(AttentionBackend):
# 3. generate output tensor of different dtypes
if out_scale > 0.0:
if abs(quant_max_bound - 127) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8")
elif abs(quant_max_bound - 448) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn")
else:
raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
else:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type)
append_attention_with_output(
res = append_attention_with_output(
qkv,
cache_k,
cache_v,

View File

@@ -205,7 +205,7 @@ def append_attention_with_output(
append_attention
"""
if current_platform.is_cuda():
append_attention_with_output_gpu(
return append_attention_with_output_gpu(
qkv,
key_cache,
value_cache,

View File

@@ -6,3 +6,4 @@ use_cudagraph: True
graph_optimization_config:
graph_opt_level: 1
sot_warmup_sizes: [2,16,32,64]
full_cuda_graph: False