[XPU] refactor of block_attn param 'pos_emb_type' (#5511)

This commit is contained in:
Lucas
2025-12-12 14:30:09 +08:00
committed by GitHub
parent 4eb55332f6
commit 888c4b992d
6 changed files with 25 additions and 19 deletions

View File

@@ -89,8 +89,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::optional<paddle::Tensor>& smooth, const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu, const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu, const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string& pos_emb_type, const bool use_neox_rotary_style,
bool rope_3d) { const bool rope_3d) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx); auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -134,12 +134,25 @@ std::vector<paddle::Tensor> BlockAttnKernel(
int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5]; int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];
int rope_max_seqlen = 0; int rope_max_seqlen = 0;
int rope_3d_num_seqs = 1; int rope_head_dim = 0;
if (rope_3d) { if (rope_3d) {
PD_CHECK(rotary_embs.dims().size() == 6,
"rotary_embs dim size should be 6 in multi-modal model");
rope_max_seqlen = rotary_embs.dims()[3]; rope_max_seqlen = rotary_embs.dims()[3];
rope_3d_num_seqs = rotary_embs.dims()[0]; rope_head_dim = rotary_embs.dims()[5];
} else { } else {
PD_CHECK(rotary_embs.dims().size() == 5,
"rotary_embs dim size should be 5 in language model");
rope_max_seqlen = rotary_embs.dims()[2]; rope_max_seqlen = rotary_embs.dims()[2];
rope_head_dim = rotary_embs.dims()[4];
}
std::string pos_emb_type;
if (use_neox_rotary_style == true) {
pos_emb_type = "NEOX";
} else if (rope_head_dim == head_dim / 2) {
pos_emb_type = "HALF_HEAD_DIM";
} else {
pos_emb_type = "NORMAL";
} }
auto block_attn_out = auto block_attn_out =
@@ -992,8 +1005,8 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::optional<paddle::Tensor>& smooth, const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu, const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu, const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string& pos_emb_type = "NORMAL", const bool use_neox_rotary_style,
bool rope_3d = false) { const bool rope_3d = false) {
#define APPLY_KERNEL(TX, TC, TS) \ #define APPLY_KERNEL(TX, TC, TS) \
return BlockAttnKernel<TX, TC, TS>(qkv, \ return BlockAttnKernel<TX, TC, TS>(qkv, \
key_cache, \ key_cache, \
@@ -1021,7 +1034,7 @@ std::vector<paddle::Tensor> BlockAttn(
smooth, \ smooth, \
kv_signal_data_cpu, \ kv_signal_data_cpu, \
cachekv_signal_thread_cpu, \ cachekv_signal_thread_cpu, \
pos_emb_type, \ use_neox_rotary_style, \
rope_3d); rope_3d);
const auto cache_dtype = key_cache.dtype(); const auto cache_dtype = key_cache.dtype();
@@ -1087,7 +1100,7 @@ PD_BUILD_STATIC_OP(block_attn)
paddle::Optional("smooth"), paddle::Optional("smooth"),
paddle::Optional("kv_signal_data_cpu"), paddle::Optional("kv_signal_data_cpu"),
paddle::Optional("cachekv_signal_thread_cpu")}) paddle::Optional("cachekv_signal_thread_cpu")})
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"}) .Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
.Outputs({"block_attn_out"}) .Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttn)) .SetKernelFn(PD_KERNEL(BlockAttn))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))

View File

@@ -85,8 +85,8 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::optional<paddle::Tensor>& smooth, const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu, const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu, const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const std::string& pos_emb_type = "NORMAL", const bool use_neox_rotary_style,
bool rope_3d = false); const bool rope_3d = false);
std::vector<paddle::Tensor> MoeLayer( std::vector<paddle::Tensor> MoeLayer(
const paddle::Tensor& x, const paddle::Tensor& x,
@@ -616,7 +616,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("smooth"), py::arg("smooth"),
py::arg("kv_signal_data_cpu"), py::arg("kv_signal_data_cpu"),
py::arg("cachekv_signal_thread_cpu"), py::arg("cachekv_signal_thread_cpu"),
py::arg("pos_emb_type") = "NORMAL", py::arg("use_neox_rotary_style"),
py::arg("rope_3d") = false, py::arg("rope_3d") = false,
"block attention in XPU"); "block attention in XPU");

View File

@@ -253,8 +253,6 @@ class XPUForwardMeta(ForwardMeta):
dec_batch: Optional[paddle.Tensor] = None dec_batch: Optional[paddle.Tensor] = None
# #
total_enc_len: Optional[paddle.Tensor] = None total_enc_len: Optional[paddle.Tensor] = None
# position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
pos_emb_type: Optional[str] = "NORMAL"
# for pd_disaggregation # for pd_disaggregation
kv_signal_sender: Optional[paddle.Tensor] = None kv_signal_sender: Optional[paddle.Tensor] = None

View File

@@ -213,7 +213,7 @@ class XPUAttentionBackend(AttentionBackend):
None, # smooth None, # smooth
metadata.kv_signal_data_list[layer.layer_id], # kv_signal_data metadata.kv_signal_data_list[layer.layer_id], # kv_signal_data
forward_meta.kv_signal_sender, # kv_signal_sender forward_meta.kv_signal_sender, # kv_signal_sender
forward_meta.pos_emb_type, layer.use_neox_rotary_style,
self.rope_3d, self.rope_3d,
) )

View File

@@ -726,7 +726,6 @@ class MTPProposer(Proposer):
self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],) self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],)
self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],) self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],)
self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],) self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],)
self.forward_meta.pos_emb_type = "NORMAL"
self.forward_meta.attn_backend = self.attn_backends[0] self.forward_meta.attn_backend = self.attn_backends[0]
# Initialize attention meta data # Initialize attention meta data

View File

@@ -822,10 +822,8 @@ class XPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
if "paddleocr" in self.model_config.model_type: # neox style = True if "paddleocr" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim rope_head_dim = head_dim
self.share_inputs["pos_emb_type"] = "NEOX"
else: # neox style = False else: # neox style = False
rope_head_dim = head_dim // 2 rope_head_dim = head_dim // 2
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
self.share_inputs["rope_emb"] = paddle.full( self.share_inputs["rope_emb"] = paddle.full(
shape=[ shape=[
@@ -918,8 +916,6 @@ class XPUModelRunner(ModelRunnerBase):
# Update bad tokens len # Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"]) max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
if self.enable_mm:
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
self.forward_meta.attn_backend = self.attn_backends[0] self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend() self.initialize_attention_backend()