mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] fix VL multi-batch accuracy issue (#4394)
This commit is contained in:
@@ -88,7 +88,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::optional<paddle::Tensor>& shift,
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu) {
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type,
|
||||
bool rope_3d) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
@@ -130,6 +132,16 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
int max_enc_len = len_info_cpu.data<int32_t>()[3];
|
||||
int max_kv_len = len_info_cpu.data<int32_t>()[4];
|
||||
int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];
|
||||
|
||||
int rope_max_seqlen = 0;
|
||||
int rope_3d_num_seqs = 1;
|
||||
if (rope_3d) {
|
||||
rope_max_seqlen = rotary_embs.dims()[3];
|
||||
rope_3d_num_seqs = rotary_embs.dims()[0];
|
||||
} else {
|
||||
rope_max_seqlen = rotary_embs.dims()[2];
|
||||
}
|
||||
|
||||
auto block_attn_out =
|
||||
paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place());
|
||||
|
||||
@@ -299,7 +311,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
@@ -308,13 +320,15 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
"NORMAL",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp); // intx_v_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
||||
// pd split
|
||||
@@ -504,7 +518,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
decoder_context_len_cache_vp, // start_tokens (prefix len)
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
@@ -513,13 +527,15 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
"NORMAL",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp); // intx_v_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
||||
|
||||
@@ -690,7 +706,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rotary_embs.dims()[2], // TODO(lizan03)
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
@@ -699,12 +715,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
"NORMAL",
|
||||
pos_emb_type,
|
||||
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||
is_cache_int8); // bool b_c8_pc
|
||||
is_cache_int8, // bool b_c8_pc
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_decoder failed.");
|
||||
|
||||
@@ -848,7 +866,9 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::optional<paddle::Tensor>& shift,
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu) {
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false) {
|
||||
#define APPLY_KERNEL(TX, TC, TS) \
|
||||
return BlockAttnKernel<TX, TC, TS>(qkv, \
|
||||
key_cache, \
|
||||
@@ -875,7 +895,9 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
shift, \
|
||||
smooth, \
|
||||
kv_signal_data_cpu, \
|
||||
cachekv_signal_thread_cpu);
|
||||
cachekv_signal_thread_cpu, \
|
||||
pos_emb_type, \
|
||||
rope_3d);
|
||||
|
||||
const auto cache_dtype = key_cache.dtype();
|
||||
if (cache_dtype == paddle::DataType::BFLOAT16) {
|
||||
@@ -940,6 +962,7 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
paddle::Optional("smooth"),
|
||||
paddle::Optional("kv_signal_data_cpu"),
|
||||
paddle::Optional("cachekv_signal_thread_cpu")})
|
||||
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttn))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
|
||||
Reference in New Issue
Block a user