[XPU] xpu support neox style ROPE (#4723)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

Co-authored-by: ddchenhao66 <dhaochen@163.com>
This commit is contained in:
ddchenhao66
2025-10-31 18:17:21 +08:00
committed by GitHub
parent 00d0da0c18
commit ce53cdccd2
2 changed files with 249 additions and 129 deletions

View File

@@ -291,46 +291,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
KV_BUF_TYPE, KV_BUF_TYPE,
{total_enc_len, hidden_dim}); {total_enc_len, hidden_dim});
// rope + cache // rope + cache
int ret = infer_ops:: int ret = 0;
split_rope_cache_kv_encoder<XPU_XType, float, XPU_CType, int, E_Scale>( if (pos_emb_type == "NEOX") {
xpu_ctx->x_context(), ret = infer_ops::
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
reinterpret_cast<const float*>( xpu_ctx->x_context(),
rotary_embs.data<float>()), // rotary_pos_emb reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const int*>( reinterpret_cast<const float*>(
block_tables.data<int>()), // block_table rotary_embs.data<float>()), // rotary_pos_emb
q_buf.data<XPU_XType>(), reinterpret_cast<const int*>(
k_buf.data<XPU_XType>(), block_tables.data<int>()), // block_table
v_buf.data<XPU_XType>(), q_buf.data<XPU_XType>(),
const_cast<XPU_CType*>( k_buf.data<XPU_XType>(),
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())), v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>( const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())), key_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
vsl.slot_mapping_vp, // real_batch value_cache.data<cdata_t>())),
prefix_lens_vp, // start_tokens vsl.usual_lod_vp, // seq_lod
param.batch_size, // batch_size vsl.slot_mapping_vp, // real_batch
1, // emb_batch_size param.batch_size, // batch_size
rope_max_seqlen, // max_seqlen 1, // emb_batch_size
param.head_num, rope_max_seqlen, // max_seqlen
param.kv_head_num, param.head_num,
param.head_dim, param.kv_head_num,
param.max_batch_size, param.head_dim,
block_size, param.max_batch_size,
max_block_per_seq, block_size,
"BLHD", max_block_per_seq,
"HLD", "BLHD",
pos_emb_type, "HLD",
nullptr, // k_cache_scale_inv - use for per head pos_emb_type,
nullptr, // v_cache_scale_inv - use for per head nullptr, // k_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale nullptr, // v_cache_scale_inv - use for per head
quant_v_scale, // intx_v_pc_scale nullptr, // intx_k_pc_scale
quant_k_zp, // intx_k_pc_zero nullptr, // intx_v_pc_scale
quant_v_zp, // intx_v_pc_zero nullptr, // intx_k_pc_zero
rope_3d, // rope_3d nullptr, // intx_v_pc_zero
rope_3d_num_seqs); rope_3d);
PD_CHECK(ret == api::SUCCESS, PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
"infer_ops::split_rope_cache_kv_encoder failed."); } else {
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
float,
XPU_CType,
int,
E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
}
// pd split // pd split
if (FLAGS_fmt_write_cache_completed_signal) { if (FLAGS_fmt_write_cache_completed_signal) {
XPUEvent write_event = nullptr; XPUEvent write_event = nullptr;
@@ -496,49 +539,88 @@ std::vector<paddle::Tensor> BlockAttnKernel(
nullptr}; // use for split rope enc as lod in MTP nullptr}; // use for split rope enc as lod in MTP
// rope + cache // rope + cache
int ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType, int ret = 0;
float, if (pos_emb_type == "NEOX") {
XPU_CType, ret = infer_ops::
int, split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
E_Scale>( xpu_ctx->x_context(),
xpu_ctx->x_context(), reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv reinterpret_cast<const float*>(
reinterpret_cast<const float*>( rotary_embs.data<float>()), // rotary_pos_emb
rotary_embs.data<float>()), // rotary_pos_emb reinterpret_cast<const int*>(
reinterpret_cast<const int*>( block_tables.data<int>()), // block_table
block_tables.data<int>()), // block_table q_buf.data<XPU_XType>(),
q_buf.data<XPU_XType>(), k_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(), v_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(), const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
const_cast<XPU_CType*>( key_cache.data<cdata_t>())),
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())), const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
const_cast<XPU_CType*>( value_cache.data<cdata_t>())),
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())), decoder_seq_lod_vp, // seq_lod
decoder_seq_lod_vp, // seq_lod decoder_batch_map_vp, // real_batch
decoder_batch_map_vp, // real_batch param.batch_size, // batch_size
decoder_context_len_cache_vp, // start_tokens (prefix len) 1, // emb_batch_size
param.batch_size, // batch_size rope_max_seqlen, // max_seqlen
1, // emb_batch_size param.head_num,
rope_max_seqlen, // max_seqlen param.kv_head_num,
param.head_num, param.head_dim,
param.kv_head_num, param.max_batch_size,
param.head_dim, block_size,
param.max_batch_size, max_block_per_seq,
block_size, "BLHD",
max_block_per_seq, "HLD",
"BLHD", pos_emb_type,
"HLD", nullptr, // k_cache_scale_inv - use for per head
pos_emb_type, nullptr, // v_cache_scale_inv - use for per head
nullptr, // k_cache_scale_inv - use for per head nullptr, // intx_k_pc_scale
nullptr, // v_cache_scale_inv - use for per head nullptr, // intx_v_pc_scale
quant_k_scale, // intx_k_pc_scale nullptr, // intx_k_pc_zero
quant_v_scale, // intx_v_pc_scale nullptr, // intx_v_pc_zero
quant_k_zp, // intx_k_pc_zero rope_3d);
quant_v_zp, // intx_v_pc_zero PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
rope_3d, // rope_3d } else {
rope_3d_num_seqs); ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
PD_CHECK(ret == api::SUCCESS, float,
"infer_ops::split_rope_cache_kv_encoder failed."); XPU_CType,
int,
E_Scale>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
k_buf.data<XPU_XType>(),
v_buf.data<XPU_XType>(),
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
decoder_seq_lod_vp, // seq_lod
decoder_batch_map_vp, // real_batch
decoder_context_len_cache_vp, // start_tokens (prefix len)
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
nullptr, // k_cache_scale_inv - use for per head
nullptr, // v_cache_scale_inv - use for per head
quant_k_scale, // intx_k_pc_scale
quant_v_scale, // intx_v_pc_scale
quant_k_zp, // intx_k_pc_zero
quant_v_zp, // intx_v_pc_zero
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
}
float* fake_perhead_scale = nullptr; float* fake_perhead_scale = nullptr;
if (is_cache_int8 && has_zp) { if (is_cache_int8 && has_zp) {
@@ -684,48 +766,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.page_attn.block_table = &block_table_tensor; param.page_attn.block_table = &block_table_tensor;
// rope + cache // rope + cache
int ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType, int ret = 0;
float, if (pos_emb_type == "NEOX") {
XPU_CType, ret = infer_ops::split_neox_cache_kv_decoder<XPU_XType,
D_Scale, float,
int>( XPU_CType,
xpu_ctx->x_context(), D_Scale,
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) + int>(
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv xpu_ctx->x_context(),
reinterpret_cast<const float*>( reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
rotary_embs.data<float>()), // rotary_pos_emb total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
reinterpret_cast<const int*>( reinterpret_cast<const float*>(
block_tables.data<int>()), // block_table rotary_embs.data<float>()), // rotary_pos_emb
q_buf.data<XPU_XType>(), reinterpret_cast<const int*>(
nullptr, block_tables.data<int>()), // block_table
nullptr, q_buf.data<XPU_XType>(),
const_cast<XPU_CType*>( nullptr,
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())), nullptr,
const_cast<XPU_CType*>( const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())), reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
vsl.slot_mapping_vp, // real_batch value_cache.data<cdata_t>())),
param.batch_size, // batch_size vsl.usual_lod_vp, // seq_lod
1, // emb_batch_size = rotary_embs.dims()[1] = 1 vsl.slot_mapping_vp, // real_batch
rope_max_seqlen, // max_seqlen param.batch_size, // batch_size
param.head_num, 1, // emb_batch_size = rotary_embs.dims()[1] = 1
param.kv_head_num, rope_max_seqlen, // max_seqlen
param.head_dim, param.head_num,
param.max_batch_size, param.kv_head_num,
block_size, param.head_dim,
max_block_per_seq, param.max_batch_size,
"BLHD", block_size,
"HLD", max_block_per_seq,
pos_emb_type, "BLHD",
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv "HLD",
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
is_cache_int8, // bool b_c8_pc reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
rope_3d, // rope_3d reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
rope_3d_num_seqs); rope_3d);
PD_CHECK(ret == api::SUCCESS, PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_decoder failed.");
"infer_ops::split_rope_cache_kv_decoder failed."); } else {
ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
float,
XPU_CType,
D_Scale,
int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
reinterpret_cast<const float*>(
rotary_embs.data<float>()), // rotary_pos_emb
reinterpret_cast<const int*>(
block_tables.data<int>()), // block_table
q_buf.data<XPU_XType>(),
nullptr,
nullptr,
const_cast<XPU_CType*>(
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size = rotary_embs.dims()[1] = 1
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
"BLHD",
"HLD",
pos_emb_type,
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
is_cache_int8, // bool b_c8_pc
rope_3d);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
}
float* fake_perhead_scale = nullptr; float* fake_perhead_scale = nullptr;
if (is_cache_int8 && has_zp) { if (is_cache_int8 && has_zp) {

View File

@@ -847,12 +847,9 @@ class XPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
if "paddleocr" in self.model_config.model_type: # neox style = True if "paddleocr" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim rope_head_dim = head_dim
self.share_inputs["pos_emb_type"] = "NEOX"
else: # neox style = False else: # neox style = False
rope_head_dim = head_dim // 2 rope_head_dim = head_dim // 2
if rope_head_dim == self.model_config.head_dim:
self.share_inputs["pos_emb_type"] = "NORMAL"
else:
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM" self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
self.share_inputs["rope_emb"] = paddle.full( self.share_inputs["rope_emb"] = paddle.full(