mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 20:32:52 +08:00
[XPU] xpu support neox style ROPE (#4723)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -291,46 +291,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
KV_BUF_TYPE,
|
||||
{total_enc_len, hidden_dim});
|
||||
// rope + cache
|
||||
int ret = infer_ops::
|
||||
split_rope_cache_kv_encoder<XPU_XType, float, XPU_CType, int, E_Scale>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
||||
int ret = 0;
|
||||
if (pos_emb_type == "NEOX") {
|
||||
ret = infer_ops::
|
||||
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
nullptr, // intx_k_pc_scale
|
||||
nullptr, // intx_v_pc_scale
|
||||
nullptr, // intx_k_pc_zero
|
||||
nullptr, // intx_v_pc_zero
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
|
||||
} else {
|
||||
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
int,
|
||||
E_Scale>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
|
||||
}
|
||||
|
||||
// pd split
|
||||
if (FLAGS_fmt_write_cache_completed_signal) {
|
||||
XPUEvent write_event = nullptr;
|
||||
@@ -496,49 +539,88 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
nullptr}; // use for split rope enc as lod in MTP
|
||||
|
||||
// rope + cache
|
||||
int ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
int,
|
||||
E_Scale>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
||||
decoder_seq_lod_vp, // seq_lod
|
||||
decoder_batch_map_vp, // real_batch
|
||||
decoder_context_len_cache_vp, // start_tokens (prefix len)
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
||||
int ret = 0;
|
||||
if (pos_emb_type == "NEOX") {
|
||||
ret = infer_ops::
|
||||
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
decoder_seq_lod_vp, // seq_lod
|
||||
decoder_batch_map_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
nullptr, // intx_k_pc_scale
|
||||
nullptr, // intx_v_pc_scale
|
||||
nullptr, // intx_k_pc_zero
|
||||
nullptr, // intx_v_pc_zero
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
|
||||
} else {
|
||||
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
int,
|
||||
E_Scale>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
k_buf.data<XPU_XType>(),
|
||||
v_buf.data<XPU_XType>(),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
decoder_seq_lod_vp, // seq_lod
|
||||
decoder_batch_map_vp, // real_batch
|
||||
decoder_context_len_cache_vp, // start_tokens (prefix len)
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
nullptr, // k_cache_scale_inv - use for per head
|
||||
nullptr, // v_cache_scale_inv - use for per head
|
||||
quant_k_scale, // intx_k_pc_scale
|
||||
quant_v_scale, // intx_v_pc_scale
|
||||
quant_k_zp, // intx_k_pc_zero
|
||||
quant_v_zp, // intx_v_pc_zero
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
|
||||
}
|
||||
|
||||
float* fake_perhead_scale = nullptr;
|
||||
if (is_cache_int8 && has_zp) {
|
||||
@@ -684,48 +766,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
param.page_attn.block_table = &block_table_tensor;
|
||||
|
||||
// rope + cache
|
||||
int ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
D_Scale,
|
||||
int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
||||
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||
is_cache_int8, // bool b_c8_pc
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
PD_CHECK(ret == api::SUCCESS,
|
||||
"infer_ops::split_rope_cache_kv_decoder failed.");
|
||||
int ret = 0;
|
||||
if (pos_emb_type == "NEOX") {
|
||||
ret = infer_ops::split_neox_cache_kv_decoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
D_Scale,
|
||||
int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
||||
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
|
||||
} else {
|
||||
ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
|
||||
float,
|
||||
XPU_CType,
|
||||
D_Scale,
|
||||
int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
||||
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
||||
reinterpret_cast<const float*>(
|
||||
rotary_embs.data<float>()), // rotary_pos_emb
|
||||
reinterpret_cast<const int*>(
|
||||
block_tables.data<int>()), // block_table
|
||||
q_buf.data<XPU_XType>(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
const_cast<XPU_CType*>(
|
||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||
value_cache.data<cdata_t>())),
|
||||
vsl.usual_lod_vp, // seq_lod
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num,
|
||||
param.kv_head_num,
|
||||
param.head_dim,
|
||||
param.max_batch_size,
|
||||
block_size,
|
||||
max_block_per_seq,
|
||||
"BLHD",
|
||||
"HLD",
|
||||
pos_emb_type,
|
||||
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||
is_cache_int8, // bool b_c8_pc
|
||||
rope_3d);
|
||||
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
|
||||
}
|
||||
|
||||
float* fake_perhead_scale = nullptr;
|
||||
if (is_cache_int8 && has_zp) {
|
||||
|
||||
@@ -847,12 +847,9 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
head_dim = self.model_config.head_dim
|
||||
if "paddleocr" in self.model_config.model_type: # neox style = True
|
||||
rope_head_dim = head_dim
|
||||
self.share_inputs["pos_emb_type"] = "NEOX"
|
||||
else: # neox style = False
|
||||
rope_head_dim = head_dim // 2
|
||||
|
||||
if rope_head_dim == self.model_config.head_dim:
|
||||
self.share_inputs["pos_emb_type"] = "NORMAL"
|
||||
else:
|
||||
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
|
||||
|
||||
self.share_inputs["rope_emb"] = paddle.full(
|
||||
|
||||
Reference in New Issue
Block a user