mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] refactor of block_attn param 'pos_emb_type' (#5511)
This commit is contained in:
@@ -89,8 +89,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string& pos_emb_type,
|
||||
bool rope_3d) {
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
@@ -134,12 +134,25 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];
|
||||
|
||||
int rope_max_seqlen = 0;
|
||||
int rope_3d_num_seqs = 1;
|
||||
int rope_head_dim = 0;
|
||||
if (rope_3d) {
|
||||
PD_CHECK(rotary_embs.dims().size() == 6,
|
||||
"rotary_embs dim size should be 6 in multi-modal model");
|
||||
rope_max_seqlen = rotary_embs.dims()[3];
|
||||
rope_3d_num_seqs = rotary_embs.dims()[0];
|
||||
rope_head_dim = rotary_embs.dims()[5];
|
||||
} else {
|
||||
PD_CHECK(rotary_embs.dims().size() == 5,
|
||||
"rotary_embs dim size should be 5 in language model");
|
||||
rope_max_seqlen = rotary_embs.dims()[2];
|
||||
rope_head_dim = rotary_embs.dims()[4];
|
||||
}
|
||||
std::string pos_emb_type;
|
||||
if (use_neox_rotary_style == true) {
|
||||
pos_emb_type = "NEOX";
|
||||
} else if (rope_head_dim == head_dim / 2) {
|
||||
pos_emb_type = "HALF_HEAD_DIM";
|
||||
} else {
|
||||
pos_emb_type = "NORMAL";
|
||||
}
|
||||
|
||||
auto block_attn_out =
|
||||
@@ -992,8 +1005,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string& pos_emb_type = "NORMAL",
|
||||
bool rope_3d = false) {
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d = false) {
|
||||
#define APPLY_KERNEL(TX, TC, TS) \
|
||||
return BlockAttnKernel<TX, TC, TS>(qkv, \
|
||||
key_cache, \
|
||||
@@ -1021,7 +1034,7 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
smooth, \
|
||||
kv_signal_data_cpu, \
|
||||
cachekv_signal_thread_cpu, \
|
||||
pos_emb_type, \
|
||||
use_neox_rotary_style, \
|
||||
rope_3d);
|
||||
|
||||
const auto cache_dtype = key_cache.dtype();
|
||||
@@ -1087,7 +1100,7 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
paddle::Optional("smooth"),
|
||||
paddle::Optional("kv_signal_data_cpu"),
|
||||
paddle::Optional("cachekv_signal_thread_cpu")})
|
||||
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
|
||||
.Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttn))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
|
||||
@@ -85,8 +85,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string& pos_emb_type = "NORMAL",
|
||||
bool rope_3d = false);
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d = false);
|
||||
|
||||
std::vector<paddle::Tensor> MoeLayer(
|
||||
const paddle::Tensor& x,
|
||||
@@ -616,7 +616,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("smooth"),
|
||||
py::arg("kv_signal_data_cpu"),
|
||||
py::arg("cachekv_signal_thread_cpu"),
|
||||
py::arg("pos_emb_type") = "NORMAL",
|
||||
py::arg("use_neox_rotary_style"),
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention in XPU");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user