diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 449c44334..72ae24749 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -88,7 +88,9 @@ std::vector BlockAttnKernel( const paddle::optional& shift, const paddle::optional& smooth, const paddle::optional& kv_signal_data_cpu, - const paddle::optional& cachekv_signal_thread_cpu) { + const paddle::optional& cachekv_signal_thread_cpu, + const std::string &pos_emb_type, + bool rope_3d) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto xpu_ctx = static_cast(dev_ctx); @@ -130,6 +132,16 @@ std::vector BlockAttnKernel( int max_enc_len = len_info_cpu.data()[3]; int max_kv_len = len_info_cpu.data()[4]; int prefix_block_num_per_seq = len_info_cpu.data()[5]; + + int rope_max_seqlen = 0; + int rope_3d_num_seqs = 1; + if (rope_3d) { + rope_max_seqlen = rotary_embs.dims()[3]; + rope_3d_num_seqs = rotary_embs.dims()[0]; + } else { + rope_max_seqlen = rotary_embs.dims()[2]; + } + auto block_attn_out = paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place()); @@ -299,7 +311,7 @@ std::vector BlockAttnKernel( prefix_lens_vp, // start_tokens param.batch_size, // batch_size 1, // emb_batch_size - rotary_embs.dims()[2], // max_seqlen + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, @@ -308,13 +320,15 @@ std::vector BlockAttnKernel( max_block_per_seq, "BLHD", "HLD", - "NORMAL", + pos_emb_type, nullptr, // k_cache_scale_inv - use for per head nullptr, // v_cache_scale_inv - use for per head quant_k_scale, // intx_k_pc_scale quant_v_scale, // intx_v_pc_scale quant_k_zp, // intx_k_pc_zero - quant_v_zp); // intx_v_pc_zero + quant_v_zp, // intx_v_pc_zero + rope_3d, // rope_3d + rope_3d_num_seqs); PD_CHECK(ret == api::SUCCESS, "infer_ops::split_rope_cache_kv_encoder failed."); // pd split @@ -504,7 +518,7 @@ std::vector BlockAttnKernel( decoder_context_len_cache_vp, // start_tokens (prefix len) param.batch_size, // batch_size 1, // emb_batch_size - rotary_embs.dims()[2], // max_seqlen + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, @@ -513,13 +527,15 @@ std::vector BlockAttnKernel( max_block_per_seq, "BLHD", "HLD", - "NORMAL", + pos_emb_type, nullptr, // k_cache_scale_inv - use for per head nullptr, // v_cache_scale_inv - use for per head quant_k_scale, // intx_k_pc_scale quant_v_scale, // intx_v_pc_scale quant_k_zp, // intx_k_pc_zero - quant_v_zp); // intx_v_pc_zero + quant_v_zp, // intx_v_pc_zero + rope_3d, // rope_3d + rope_3d_num_seqs); PD_CHECK(ret == api::SUCCESS, "infer_ops::split_rope_cache_kv_encoder failed."); @@ -690,7 +706,7 @@ std::vector BlockAttnKernel( vsl.slot_mapping_vp, // real_batch param.batch_size, // batch_size 1, // emb_batch_size = rotary_embs.dims()[1] = 1 - rotary_embs.dims()[2], // TODO(lizan03) + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, @@ -699,12 +715,14 @@ std::vector BlockAttnKernel( max_block_per_seq, "BLHD", "HLD", - "NORMAL", + pos_emb_type, reinterpret_cast(quant_k_scale), // k_cache_scale_inv reinterpret_cast(quant_v_scale), // v_cache_scale_inv reinterpret_cast(quant_k_zp), // k_cache_zp reinterpret_cast(quant_v_zp), // v_cache_zp - is_cache_int8); // bool b_c8_pc + is_cache_int8, // bool b_c8_pc + rope_3d, // rope_3d + rope_3d_num_seqs); PD_CHECK(ret == api::SUCCESS, "infer_ops::split_rope_cache_kv_decoder failed."); @@ -848,7 +866,9 @@ std::vector BlockAttn( const paddle::optional& shift, const paddle::optional& smooth, const paddle::optional& kv_signal_data_cpu, - const paddle::optional& cachekv_signal_thread_cpu) { + const paddle::optional& cachekv_signal_thread_cpu, + const std::string &pos_emb_type="NORMAL", + bool rope_3d=false) { #define APPLY_KERNEL(TX, TC, TS) \ return BlockAttnKernel(qkv, \ key_cache, \ @@ -875,7 +895,9 @@ std::vector BlockAttn( shift, \ smooth, \ kv_signal_data_cpu, \ - cachekv_signal_thread_cpu); + cachekv_signal_thread_cpu, \ + pos_emb_type, \ + rope_3d); const auto cache_dtype = key_cache.dtype(); if (cache_dtype == paddle::DataType::BFLOAT16) { @@ -940,6 +962,7 @@ PD_BUILD_STATIC_OP(block_attn) paddle::Optional("smooth"), paddle::Optional("kv_signal_data_cpu"), paddle::Optional("cachekv_signal_thread_cpu")}) + .Attrs({"pos_emb_type:std::string", "rope_3d:bool"}) .Outputs({"block_attn_out"}) .SetKernelFn(PD_KERNEL(BlockAttn)) .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape)) diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 5735abf6f..3e48ab81f 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -198,5 +198,7 @@ class XPUAttentionBackend(AttentionBackend): None, # smooth None, # kv_signal_data None, # kv_signal_sender + forward_meta.pos_emb_type, + self.rope_3d, ) return res