mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 20:32:52 +08:00
[XPU] xpu support neox style ROPE (#4723)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -291,46 +291,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
KV_BUF_TYPE,
|
KV_BUF_TYPE,
|
||||||
{total_enc_len, hidden_dim});
|
{total_enc_len, hidden_dim});
|
||||||
// rope + cache
|
// rope + cache
|
||||||
int ret = infer_ops::
|
int ret = 0;
|
||||||
split_rope_cache_kv_encoder<XPU_XType, float, XPU_CType, int, E_Scale>(
|
if (pos_emb_type == "NEOX") {
|
||||||
xpu_ctx->x_context(),
|
ret = infer_ops::
|
||||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
|
||||||
reinterpret_cast<const float*>(
|
xpu_ctx->x_context(),
|
||||||
rotary_embs.data<float>()), // rotary_pos_emb
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||||
reinterpret_cast<const int*>(
|
reinterpret_cast<const float*>(
|
||||||
block_tables.data<int>()), // block_table
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
q_buf.data<XPU_XType>(),
|
reinterpret_cast<const int*>(
|
||||||
k_buf.data<XPU_XType>(),
|
block_tables.data<int>()), // block_table
|
||||||
v_buf.data<XPU_XType>(),
|
q_buf.data<XPU_XType>(),
|
||||||
const_cast<XPU_CType*>(
|
k_buf.data<XPU_XType>(),
|
||||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
v_buf.data<XPU_XType>(),
|
||||||
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
value_cache.data<cdata_t>())),
|
key_cache.data<cdata_t>())),
|
||||||
vsl.usual_lod_vp, // seq_lod
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
vsl.slot_mapping_vp, // real_batch
|
value_cache.data<cdata_t>())),
|
||||||
prefix_lens_vp, // start_tokens
|
vsl.usual_lod_vp, // seq_lod
|
||||||
param.batch_size, // batch_size
|
vsl.slot_mapping_vp, // real_batch
|
||||||
1, // emb_batch_size
|
param.batch_size, // batch_size
|
||||||
rope_max_seqlen, // max_seqlen
|
1, // emb_batch_size
|
||||||
param.head_num,
|
rope_max_seqlen, // max_seqlen
|
||||||
param.kv_head_num,
|
param.head_num,
|
||||||
param.head_dim,
|
param.kv_head_num,
|
||||||
param.max_batch_size,
|
param.head_dim,
|
||||||
block_size,
|
param.max_batch_size,
|
||||||
max_block_per_seq,
|
block_size,
|
||||||
"BLHD",
|
max_block_per_seq,
|
||||||
"HLD",
|
"BLHD",
|
||||||
pos_emb_type,
|
"HLD",
|
||||||
nullptr, // k_cache_scale_inv - use for per head
|
pos_emb_type,
|
||||||
nullptr, // v_cache_scale_inv - use for per head
|
nullptr, // k_cache_scale_inv - use for per head
|
||||||
quant_k_scale, // intx_k_pc_scale
|
nullptr, // v_cache_scale_inv - use for per head
|
||||||
quant_v_scale, // intx_v_pc_scale
|
nullptr, // intx_k_pc_scale
|
||||||
quant_k_zp, // intx_k_pc_zero
|
nullptr, // intx_v_pc_scale
|
||||||
quant_v_zp, // intx_v_pc_zero
|
nullptr, // intx_k_pc_zero
|
||||||
rope_3d, // rope_3d
|
nullptr, // intx_v_pc_zero
|
||||||
rope_3d_num_seqs);
|
rope_3d);
|
||||||
PD_CHECK(ret == api::SUCCESS,
|
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
|
||||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
} else {
|
||||||
|
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
||||||
|
float,
|
||||||
|
XPU_CType,
|
||||||
|
int,
|
||||||
|
E_Scale>(
|
||||||
|
xpu_ctx->x_context(),
|
||||||
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||||
|
reinterpret_cast<const float*>(
|
||||||
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
|
reinterpret_cast<const int*>(
|
||||||
|
block_tables.data<int>()), // block_table
|
||||||
|
q_buf.data<XPU_XType>(),
|
||||||
|
k_buf.data<XPU_XType>(),
|
||||||
|
v_buf.data<XPU_XType>(),
|
||||||
|
const_cast<XPU_CType*>(
|
||||||
|
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||||
|
const_cast<XPU_CType*>(
|
||||||
|
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
||||||
|
vsl.usual_lod_vp, // seq_lod
|
||||||
|
vsl.slot_mapping_vp, // real_batch
|
||||||
|
prefix_lens_vp, // start_tokens
|
||||||
|
param.batch_size, // batch_size
|
||||||
|
1, // emb_batch_size
|
||||||
|
rope_max_seqlen, // max_seqlen
|
||||||
|
param.head_num,
|
||||||
|
param.kv_head_num,
|
||||||
|
param.head_dim,
|
||||||
|
param.max_batch_size,
|
||||||
|
block_size,
|
||||||
|
max_block_per_seq,
|
||||||
|
"BLHD",
|
||||||
|
"HLD",
|
||||||
|
pos_emb_type,
|
||||||
|
nullptr, // k_cache_scale_inv - use for per head
|
||||||
|
nullptr, // v_cache_scale_inv - use for per head
|
||||||
|
quant_k_scale, // intx_k_pc_scale
|
||||||
|
quant_v_scale, // intx_v_pc_scale
|
||||||
|
quant_k_zp, // intx_k_pc_zero
|
||||||
|
quant_v_zp, // intx_v_pc_zero
|
||||||
|
rope_3d);
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
|
||||||
|
}
|
||||||
|
|
||||||
// pd split
|
// pd split
|
||||||
if (FLAGS_fmt_write_cache_completed_signal) {
|
if (FLAGS_fmt_write_cache_completed_signal) {
|
||||||
XPUEvent write_event = nullptr;
|
XPUEvent write_event = nullptr;
|
||||||
@@ -496,49 +539,88 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
nullptr}; // use for split rope enc as lod in MTP
|
nullptr}; // use for split rope enc as lod in MTP
|
||||||
|
|
||||||
// rope + cache
|
// rope + cache
|
||||||
int ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
int ret = 0;
|
||||||
float,
|
if (pos_emb_type == "NEOX") {
|
||||||
XPU_CType,
|
ret = infer_ops::
|
||||||
int,
|
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
|
||||||
E_Scale>(
|
xpu_ctx->x_context(),
|
||||||
xpu_ctx->x_context(),
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
reinterpret_cast<const float*>(
|
||||||
reinterpret_cast<const float*>(
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
rotary_embs.data<float>()), // rotary_pos_emb
|
reinterpret_cast<const int*>(
|
||||||
reinterpret_cast<const int*>(
|
block_tables.data<int>()), // block_table
|
||||||
block_tables.data<int>()), // block_table
|
q_buf.data<XPU_XType>(),
|
||||||
q_buf.data<XPU_XType>(),
|
k_buf.data<XPU_XType>(),
|
||||||
k_buf.data<XPU_XType>(),
|
v_buf.data<XPU_XType>(),
|
||||||
v_buf.data<XPU_XType>(),
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
const_cast<XPU_CType*>(
|
key_cache.data<cdata_t>())),
|
||||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
const_cast<XPU_CType*>(
|
value_cache.data<cdata_t>())),
|
||||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
decoder_seq_lod_vp, // seq_lod
|
||||||
decoder_seq_lod_vp, // seq_lod
|
decoder_batch_map_vp, // real_batch
|
||||||
decoder_batch_map_vp, // real_batch
|
param.batch_size, // batch_size
|
||||||
decoder_context_len_cache_vp, // start_tokens (prefix len)
|
1, // emb_batch_size
|
||||||
param.batch_size, // batch_size
|
rope_max_seqlen, // max_seqlen
|
||||||
1, // emb_batch_size
|
param.head_num,
|
||||||
rope_max_seqlen, // max_seqlen
|
param.kv_head_num,
|
||||||
param.head_num,
|
param.head_dim,
|
||||||
param.kv_head_num,
|
param.max_batch_size,
|
||||||
param.head_dim,
|
block_size,
|
||||||
param.max_batch_size,
|
max_block_per_seq,
|
||||||
block_size,
|
"BLHD",
|
||||||
max_block_per_seq,
|
"HLD",
|
||||||
"BLHD",
|
pos_emb_type,
|
||||||
"HLD",
|
nullptr, // k_cache_scale_inv - use for per head
|
||||||
pos_emb_type,
|
nullptr, // v_cache_scale_inv - use for per head
|
||||||
nullptr, // k_cache_scale_inv - use for per head
|
nullptr, // intx_k_pc_scale
|
||||||
nullptr, // v_cache_scale_inv - use for per head
|
nullptr, // intx_v_pc_scale
|
||||||
quant_k_scale, // intx_k_pc_scale
|
nullptr, // intx_k_pc_zero
|
||||||
quant_v_scale, // intx_v_pc_scale
|
nullptr, // intx_v_pc_zero
|
||||||
quant_k_zp, // intx_k_pc_zero
|
rope_3d);
|
||||||
quant_v_zp, // intx_v_pc_zero
|
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
|
||||||
rope_3d, // rope_3d
|
} else {
|
||||||
rope_3d_num_seqs);
|
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
|
||||||
PD_CHECK(ret == api::SUCCESS,
|
float,
|
||||||
"infer_ops::split_rope_cache_kv_encoder failed.");
|
XPU_CType,
|
||||||
|
int,
|
||||||
|
E_Scale>(
|
||||||
|
xpu_ctx->x_context(),
|
||||||
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
|
||||||
|
reinterpret_cast<const float*>(
|
||||||
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
|
reinterpret_cast<const int*>(
|
||||||
|
block_tables.data<int>()), // block_table
|
||||||
|
q_buf.data<XPU_XType>(),
|
||||||
|
k_buf.data<XPU_XType>(),
|
||||||
|
v_buf.data<XPU_XType>(),
|
||||||
|
const_cast<XPU_CType*>(
|
||||||
|
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||||
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
|
value_cache.data<cdata_t>())),
|
||||||
|
decoder_seq_lod_vp, // seq_lod
|
||||||
|
decoder_batch_map_vp, // real_batch
|
||||||
|
decoder_context_len_cache_vp, // start_tokens (prefix len)
|
||||||
|
param.batch_size, // batch_size
|
||||||
|
1, // emb_batch_size
|
||||||
|
rope_max_seqlen, // max_seqlen
|
||||||
|
param.head_num,
|
||||||
|
param.kv_head_num,
|
||||||
|
param.head_dim,
|
||||||
|
param.max_batch_size,
|
||||||
|
block_size,
|
||||||
|
max_block_per_seq,
|
||||||
|
"BLHD",
|
||||||
|
"HLD",
|
||||||
|
pos_emb_type,
|
||||||
|
nullptr, // k_cache_scale_inv - use for per head
|
||||||
|
nullptr, // v_cache_scale_inv - use for per head
|
||||||
|
quant_k_scale, // intx_k_pc_scale
|
||||||
|
quant_v_scale, // intx_v_pc_scale
|
||||||
|
quant_k_zp, // intx_k_pc_zero
|
||||||
|
quant_v_zp, // intx_v_pc_zero
|
||||||
|
rope_3d);
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
|
||||||
|
}
|
||||||
|
|
||||||
float* fake_perhead_scale = nullptr;
|
float* fake_perhead_scale = nullptr;
|
||||||
if (is_cache_int8 && has_zp) {
|
if (is_cache_int8 && has_zp) {
|
||||||
@@ -684,48 +766,89 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
param.page_attn.block_table = &block_table_tensor;
|
param.page_attn.block_table = &block_table_tensor;
|
||||||
|
|
||||||
// rope + cache
|
// rope + cache
|
||||||
int ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
|
int ret = 0;
|
||||||
float,
|
if (pos_emb_type == "NEOX") {
|
||||||
XPU_CType,
|
ret = infer_ops::split_neox_cache_kv_decoder<XPU_XType,
|
||||||
D_Scale,
|
float,
|
||||||
int>(
|
XPU_CType,
|
||||||
xpu_ctx->x_context(),
|
D_Scale,
|
||||||
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
int>(
|
||||||
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
xpu_ctx->x_context(),
|
||||||
reinterpret_cast<const float*>(
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
||||||
rotary_embs.data<float>()), // rotary_pos_emb
|
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
||||||
reinterpret_cast<const int*>(
|
reinterpret_cast<const float*>(
|
||||||
block_tables.data<int>()), // block_table
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
q_buf.data<XPU_XType>(),
|
reinterpret_cast<const int*>(
|
||||||
nullptr,
|
block_tables.data<int>()), // block_table
|
||||||
nullptr,
|
q_buf.data<XPU_XType>(),
|
||||||
const_cast<XPU_CType*>(
|
nullptr,
|
||||||
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
nullptr,
|
||||||
const_cast<XPU_CType*>(
|
const_cast<XPU_CType*>(
|
||||||
reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
|
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||||
vsl.usual_lod_vp, // seq_lod
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
vsl.slot_mapping_vp, // real_batch
|
value_cache.data<cdata_t>())),
|
||||||
param.batch_size, // batch_size
|
vsl.usual_lod_vp, // seq_lod
|
||||||
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
vsl.slot_mapping_vp, // real_batch
|
||||||
rope_max_seqlen, // max_seqlen
|
param.batch_size, // batch_size
|
||||||
param.head_num,
|
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||||
param.kv_head_num,
|
rope_max_seqlen, // max_seqlen
|
||||||
param.head_dim,
|
param.head_num,
|
||||||
param.max_batch_size,
|
param.kv_head_num,
|
||||||
block_size,
|
param.head_dim,
|
||||||
max_block_per_seq,
|
param.max_batch_size,
|
||||||
"BLHD",
|
block_size,
|
||||||
"HLD",
|
max_block_per_seq,
|
||||||
pos_emb_type,
|
"BLHD",
|
||||||
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
"HLD",
|
||||||
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
pos_emb_type,
|
||||||
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||||
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||||
is_cache_int8, // bool b_c8_pc
|
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||||
rope_3d, // rope_3d
|
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||||
rope_3d_num_seqs);
|
rope_3d);
|
||||||
PD_CHECK(ret == api::SUCCESS,
|
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
|
||||||
"infer_ops::split_rope_cache_kv_decoder failed.");
|
} else {
|
||||||
|
ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
|
||||||
|
float,
|
||||||
|
XPU_CType,
|
||||||
|
D_Scale,
|
||||||
|
int>(
|
||||||
|
xpu_ctx->x_context(),
|
||||||
|
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
|
||||||
|
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
|
||||||
|
reinterpret_cast<const float*>(
|
||||||
|
rotary_embs.data<float>()), // rotary_pos_emb
|
||||||
|
reinterpret_cast<const int*>(
|
||||||
|
block_tables.data<int>()), // block_table
|
||||||
|
q_buf.data<XPU_XType>(),
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
const_cast<XPU_CType*>(
|
||||||
|
reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
|
||||||
|
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
|
||||||
|
value_cache.data<cdata_t>())),
|
||||||
|
vsl.usual_lod_vp, // seq_lod
|
||||||
|
vsl.slot_mapping_vp, // real_batch
|
||||||
|
param.batch_size, // batch_size
|
||||||
|
1, // emb_batch_size = rotary_embs.dims()[1] = 1
|
||||||
|
rope_max_seqlen, // max_seqlen
|
||||||
|
param.head_num,
|
||||||
|
param.kv_head_num,
|
||||||
|
param.head_dim,
|
||||||
|
param.max_batch_size,
|
||||||
|
block_size,
|
||||||
|
max_block_per_seq,
|
||||||
|
"BLHD",
|
||||||
|
"HLD",
|
||||||
|
pos_emb_type,
|
||||||
|
reinterpret_cast<D_Scale*>(quant_k_scale), // k_cache_scale_inv
|
||||||
|
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
|
||||||
|
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
|
||||||
|
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
|
||||||
|
is_cache_int8, // bool b_c8_pc
|
||||||
|
rope_3d);
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
|
||||||
|
}
|
||||||
|
|
||||||
float* fake_perhead_scale = nullptr;
|
float* fake_perhead_scale = nullptr;
|
||||||
if (is_cache_int8 && has_zp) {
|
if (is_cache_int8 && has_zp) {
|
||||||
|
|||||||
@@ -847,12 +847,9 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
head_dim = self.model_config.head_dim
|
head_dim = self.model_config.head_dim
|
||||||
if "paddleocr" in self.model_config.model_type: # neox style = True
|
if "paddleocr" in self.model_config.model_type: # neox style = True
|
||||||
rope_head_dim = head_dim
|
rope_head_dim = head_dim
|
||||||
|
self.share_inputs["pos_emb_type"] = "NEOX"
|
||||||
else: # neox style = False
|
else: # neox style = False
|
||||||
rope_head_dim = head_dim // 2
|
rope_head_dim = head_dim // 2
|
||||||
|
|
||||||
if rope_head_dim == self.model_config.head_dim:
|
|
||||||
self.share_inputs["pos_emb_type"] = "NORMAL"
|
|
||||||
else:
|
|
||||||
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
|
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
|
||||||
|
|
||||||
self.share_inputs["rope_emb"] = paddle.full(
|
self.share_inputs["rope_emb"] = paddle.full(
|
||||||
|
|||||||
Reference in New Issue
Block a user