[XPU] fix VL multi-batch accuracy issue (#4394)

Lucas
2025-10-15 17:27:43 +08:00
committed by GitHub
parent d8841b7b40
commit bdc0207277
2 changed files with 37 additions and 12 deletions


@@ -88,7 +88,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::optional<paddle::Tensor>& shift,
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
- const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu) {
+ const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
+ const std::string &pos_emb_type,
+ bool rope_3d) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -130,6 +132,16 @@ std::vector<paddle::Tensor> BlockAttnKernel(
int max_enc_len = len_info_cpu.data<int32_t>()[3];
int max_kv_len = len_info_cpu.data<int32_t>()[4];
int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];
+ int rope_max_seqlen = 0;
+ int rope_3d_num_seqs = 1;
+ if (rope_3d) {
+   rope_max_seqlen = rotary_embs.dims()[3];
+   rope_3d_num_seqs = rotary_embs.dims()[0];
+ } else {
+   rope_max_seqlen = rotary_embs.dims()[2];
+ }
auto block_attn_out =
paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place());
@@ -299,7 +311,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
- rotary_embs.dims()[2], // max_seqlen
+ rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
@@ -308,13 +320,15 @@ std::vector<paddle::Tensor> BlockAttnKernel(
max_block_per_seq,
"BLHD",
"HLD",
"NORMAL",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
- quant_v_zp); // intx_v_pc_zero
+ quant_v_zp, // intx_v_pc_zero
+ rope_3d, // rope_3d
+ rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_encoder failed.");
// pd split
@@ -504,7 +518,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
decoder_context_len_cache_vp, // start_tokens (prefix len)
param.batch_size, // batch_size
1, // emb_batch_size
- rotary_embs.dims()[2], // max_seqlen
+ rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
@@ -513,13 +527,15 @@ std::vector<paddle::Tensor> BlockAttnKernel(
max_block_per_seq,
"BLHD",
"HLD",
"NORMAL",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
- quant_v_zp); // intx_v_pc_zero
+ quant_v_zp, // intx_v_pc_zero
+ rope_3d, // rope_3d
+ rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_encoder failed.");
@@ -690,7 +706,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
- rotary_embs.dims()[2], // TODO(lizan03)
+ rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
@@ -699,12 +715,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
max_block_per_seq,
"BLHD",
"HLD",
"NORMAL",
pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
- is_cache_int8); // bool b_c8_pc
+ is_cache_int8, // bool b_c8_pc
+ rope_3d, // rope_3d
+ rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_decoder failed.");
@@ -848,7 +866,9 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::optional<paddle::Tensor>& shift,
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
- const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu) {
+ const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
+ const std::string &pos_emb_type="NORMAL",
+ bool rope_3d=false) {
#define APPLY_KERNEL(TX, TC, TS) \
return BlockAttnKernel<TX, TC, TS>(qkv, \
key_cache, \
@@ -875,7 +895,9 @@ std::vector<paddle::Tensor> BlockAttn(
shift, \
smooth, \
kv_signal_data_cpu, \
- cachekv_signal_thread_cpu);
+ cachekv_signal_thread_cpu, \
+ pos_emb_type, \
+ rope_3d);
const auto cache_dtype = key_cache.dtype();
if (cache_dtype == paddle::DataType::BFLOAT16) {
@@ -940,6 +962,7 @@ PD_BUILD_STATIC_OP(block_attn)
paddle::Optional("smooth"),
paddle::Optional("kv_signal_data_cpu"),
paddle::Optional("cachekv_signal_thread_cpu")})
+ .Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
.Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttn))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
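
The net effect of the new rope_3d flag in the kernel above is only a different choice of rotary_embs dimensions at the three split_rope_cache_kv call sites: in the 3D case the rope max sequence length comes from dims()[3] and the number of sequences from dims()[0], otherwise dims()[2] is used as before. A minimal Python sketch of that selection, operating on a plain shape tuple; the helper name and the availability of a full shape tuple are illustrative assumptions, not part of the kernel:

def rope_seq_params(rotary_embs_shape, rope_3d):
    # Mirrors the dimension selection added in BlockAttnKernel; shape layout is assumed.
    if rope_3d:
        return rotary_embs_shape[3], rotary_embs_shape[0]  # rope_max_seqlen, rope_3d_num_seqs
    return rotary_embs_shape[2], 1  # rope_max_seqlen, single shared rope table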


@@ -198,5 +198,7 @@ class XPUAttentionBackend(AttentionBackend):
None, # smooth
None, # kv_signal_data
None, # kv_signal_sender
+ forward_meta.pos_emb_type,
+ self.rope_3d,
)
return res
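
On the Python side the backend simply appends the two new values as trailing positional arguments, which bind to the pos_emb_type and rope_3d attributes registered on block_attn. A hedged sketch of that call shape follows; the wrapper function and the existing_args bundle are placeholders for the unchanged arguments, not the backend's real signature:

def call_block_attn(block_attn_op, existing_args, forward_meta, rope_3d):
    # Hypothetical helper: existing_args stands for the unchanged qkv/cache/scale/signal
    # inputs already passed before this change; only the last two arguments are new.
    return block_attn_op(
        *existing_args,
        forward_meta.pos_emb_type,  # binds to the "pos_emb_type:std::string" attr
        rope_3d,                    # binds to the "rope_3d:bool" attr
    )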