[XPU] xpu support neox style ROPE (#4723)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

Co-authored-by: ddchenhao66 <dhaochen@163.com>
This commit is contained in:
ddchenhao66
2025-10-31 18:17:21 +08:00
committed by GitHub
parent 00d0da0c18
commit ce53cdccd2
2 changed files with 249 additions and 129 deletions

View File

@@ -291,46 +291,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
KV_BUF_TYPE,
{total_enc_len, hidden_dim});
// rope + cache
int ret = infer_ops::
split_rope_cache_kv_encoder<XPU_XType, float, XPU_CType, int, E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d, // rope_3d
rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_encoder failed.");
int ret = 0;
if (pos_emb_type == "NEOX") {
ret = infer_ops::
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
nullptr, // intx_k_pc_scale
nullptr, // intx_v_pc_scale
nullptr, // intx_k_pc_zero
nullptr, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
float,
XPU_CType,
int,
E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
}
// pd split
if (FLAGS_fmt_write_cache_completed_signal) {
XPUEvent write_event = nullptr;
@@ -496,49 +539,88 @@ std::vector<paddle::Tensor> BlockAttnKernel(
nullptr}; // use for split rope enc as lod in MTP
// rope + cache
int ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
float,
XPU_CType,
int,
E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
decoder_seq_lod_vp, // seq_lod
decoder_batch_map_vp, // real_batch
decoder_context_len_cache_vp, // start_tokens (prefix len)
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d, // rope_3d
rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_encoder failed.");
int ret = 0;
if (pos_emb_type == "NEOX") {
ret = infer_ops::
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
decoder_seq_lod_vp, // seq_lod
decoder_batch_map_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
nullptr, // intx_k_pc_scale
nullptr, // intx_v_pc_scale
nullptr, // intx_k_pc_zero
nullptr, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
float,
XPU_CType,
int,
E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
decoder_seq_lod_vp, // seq_lod
decoder_batch_map_vp, // real_batch
decoder_context_len_cache_vp, // start_tokens (prefix len)
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
}
float* fake_perhead_scale = nullptr;
if (is_cache_int8 && has_zp) {
@@ -684,48 +766,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.page_attn.block_table = &block_table_tensor;
// rope + cache
int ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
float,
XPU_CType,
D_Scale,
int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
nullptr,
nullptr,
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
is_cache_int8, // bool b_c8_pc
rope_3d, // rope_3d
rope_3d_num_seqs);
PD_CHECK(ret == api::SUCCESS,
"infer_ops::split_rope_cache_kv_decoder failed.");
int ret = 0;
if (pos_emb_type == "NEOX") {
ret = infer_ops::split_neox_cache_kv_decoder<XPU_XType,
float,
XPU_CType,
D_Scale,
int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
nullptr,
nullptr,
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
rope_3d);
      PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_decoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
float,
XPU_CType,
D_Scale,
int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
nullptr,
nullptr,
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
is_cache_int8, // bool b_c8_pc
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
}
float* fake_perhead_scale = nullptr;
if (is_cache_int8 && has_zp) {

View File

@@ -847,12 +847,9 @@ class XPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim
if "paddleocr" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim
self.share_inputs["pos_emb_type"] = "NEOX"
else: # neox style = False
rope_head_dim = head_dim // 2
if rope_head_dim == self.model_config.head_dim:
self.share_inputs["pos_emb_type"] = "NORMAL"
else:
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
self.share_inputs["rope_emb"] = paddle.full(