diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc
index f6ade82b1..6153a77dd 100644
--- a/custom_ops/xpu_ops/src/ops/block_attn.cc
+++ b/custom_ops/xpu_ops/src/ops/block_attn.cc
@@ -291,46 +291,89 @@ std::vector BlockAttnKernel(
                                KV_BUF_TYPE, {total_enc_len, hidden_dim});
     // rope + cache
-    int ret = infer_ops::
-        split_rope_cache_kv_encoder(
-            xpu_ctx->x_context(),
-            reinterpret_cast(qkv.data()),  // qkv
-            reinterpret_cast(
-                rotary_embs.data()),  // rotary_pos_emb
-            reinterpret_cast(
-                block_tables.data()),  // block_table
-            q_buf.data(),
-            k_buf.data(),
-            v_buf.data(),
-            const_cast(
-                reinterpret_cast(key_cache.data())),
-            const_cast(reinterpret_cast(
-                value_cache.data())),
-            vsl.usual_lod_vp,     // seq_lod
-            vsl.slot_mapping_vp,  // real_batch
-            prefix_lens_vp,       // start_tokens
-            param.batch_size,     // batch_size
-            1,                    // emb_batch_size
-            rope_max_seqlen,      // max_seqlen
-            param.head_num,
-            param.kv_head_num,
-            param.head_dim,
-            param.max_batch_size,
-            block_size,
-            max_block_per_seq,
-            "BLHD",
-            "HLD",
-            pos_emb_type,
-            nullptr,  // k_cache_scale_inv - use for per head
-            nullptr,  // v_cache_scale_inv - use for per head
-            quant_k_scale,  // intx_k_pc_scale
-            quant_v_scale,  // intx_v_pc_scale
-            quant_k_zp,     // intx_k_pc_zero
-            quant_v_zp,     // intx_v_pc_zero
-            rope_3d,        // rope_3d
-            rope_3d_num_seqs);
-    PD_CHECK(ret == api::SUCCESS,
-             "infer_ops::split_rope_cache_kv_encoder failed.");
+    int ret = 0;
+    if (pos_emb_type == "NEOX") {
+      ret = infer_ops::
+          split_neox_cache_kv_encoder(
+              xpu_ctx->x_context(),
+              reinterpret_cast(qkv.data()),  // qkv
+              reinterpret_cast(
+                  rotary_embs.data()),  // rotary_pos_emb
+              reinterpret_cast(
+                  block_tables.data()),  // block_table
+              q_buf.data(),
+              k_buf.data(),
+              v_buf.data(),
+              const_cast(reinterpret_cast(
+                  key_cache.data())),
+              const_cast(reinterpret_cast(
+                  value_cache.data())),
+              vsl.usual_lod_vp,     // seq_lod
+              vsl.slot_mapping_vp,  // real_batch
+              param.batch_size,     // batch_size
+              1,                    // emb_batch_size
+              rope_max_seqlen,      // max_seqlen
+              param.head_num,
+              param.kv_head_num,
+              param.head_dim,
+              param.max_batch_size,
+              block_size,
+              max_block_per_seq,
+              "BLHD",
+              "HLD",
+              pos_emb_type,
+              nullptr,  // k_cache_scale_inv - use for per head
+              nullptr,  // v_cache_scale_inv - use for per head
+              nullptr,  // intx_k_pc_scale
+              nullptr,  // intx_v_pc_scale
+              nullptr,  // intx_k_pc_zero
+              nullptr,  // intx_v_pc_zero
+              rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
+    } else {
+      ret = infer_ops::split_rope_cache_kv_encoder(
+          xpu_ctx->x_context(),
+          reinterpret_cast(qkv.data()),  // qkv
+          reinterpret_cast(
+              rotary_embs.data()),  // rotary_pos_emb
+          reinterpret_cast(
+              block_tables.data()),  // block_table
+          q_buf.data(),
+          k_buf.data(),
+          v_buf.data(),
+          const_cast(
+              reinterpret_cast(key_cache.data())),
+          const_cast(
+              reinterpret_cast(value_cache.data())),
+          vsl.usual_lod_vp,     // seq_lod
+          vsl.slot_mapping_vp,  // real_batch
+          prefix_lens_vp,       // start_tokens
+          param.batch_size,     // batch_size
+          1,                    // emb_batch_size
+          rope_max_seqlen,      // max_seqlen
+          param.head_num,
+          param.kv_head_num,
+          param.head_dim,
+          param.max_batch_size,
+          block_size,
+          max_block_per_seq,
+          "BLHD",
+          "HLD",
+          pos_emb_type,
+          nullptr,  // k_cache_scale_inv - use for per head
+          nullptr,  // v_cache_scale_inv - use for per head
+          quant_k_scale,  // intx_k_pc_scale
+          quant_v_scale,  // intx_v_pc_scale
+          quant_k_zp,     // intx_k_pc_zero
+          quant_v_zp,     // intx_v_pc_zero
+          rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
+    }
+
     // pd split
     if (FLAGS_fmt_write_cache_completed_signal) {
       XPUEvent write_event = nullptr;
@@ -496,49 +539,88 @@ std::vector BlockAttnKernel(
             nullptr};  // use for split rope enc as lod in MTP

     // rope + cache
-    int ret = infer_ops::split_rope_cache_kv_encoder(
-        xpu_ctx->x_context(),
-        reinterpret_cast(qkv.data()),  // qkv
-        reinterpret_cast(
-            rotary_embs.data()),  // rotary_pos_emb
-        reinterpret_cast(
-            block_tables.data()),  // block_table
-        q_buf.data(),
-        k_buf.data(),
-        v_buf.data(),
-        const_cast(
-            reinterpret_cast(key_cache.data())),
-        const_cast(
-            reinterpret_cast(value_cache.data())),
-        decoder_seq_lod_vp,            // seq_lod
-        decoder_batch_map_vp,          // real_batch
-        decoder_context_len_cache_vp,  // start_tokens (prefix len)
-        param.batch_size,              // batch_size
-        1,                             // emb_batch_size
-        rope_max_seqlen,               // max_seqlen
-        param.head_num,
-        param.kv_head_num,
-        param.head_dim,
-        param.max_batch_size,
-        block_size,
-        max_block_per_seq,
-        "BLHD",
-        "HLD",
-        pos_emb_type,
-        nullptr,  // k_cache_scale_inv - use for per head
-        nullptr,  // v_cache_scale_inv - use for per head
-        quant_k_scale,  // intx_k_pc_scale
-        quant_v_scale,  // intx_v_pc_scale
-        quant_k_zp,     // intx_k_pc_zero
-        quant_v_zp,     // intx_v_pc_zero
-        rope_3d,        // rope_3d
-        rope_3d_num_seqs);
-    PD_CHECK(ret == api::SUCCESS,
-             "infer_ops::split_rope_cache_kv_encoder failed.");
+    int ret = 0;
+    if (pos_emb_type == "NEOX") {
+      ret = infer_ops::
+          split_neox_cache_kv_encoder(
+              xpu_ctx->x_context(),
+              reinterpret_cast(qkv.data()),  // qkv
+              reinterpret_cast(
+                  rotary_embs.data()),  // rotary_pos_emb
+              reinterpret_cast(
+                  block_tables.data()),  // block_table
+              q_buf.data(),
+              k_buf.data(),
+              v_buf.data(),
+              const_cast(reinterpret_cast(
+                  key_cache.data())),
+              const_cast(reinterpret_cast(
+                  value_cache.data())),
+              decoder_seq_lod_vp,    // seq_lod
+              decoder_batch_map_vp,  // real_batch
+              param.batch_size,      // batch_size
+              1,                     // emb_batch_size
+              rope_max_seqlen,       // max_seqlen
+              param.head_num,
+              param.kv_head_num,
+              param.head_dim,
+              param.max_batch_size,
+              block_size,
+              max_block_per_seq,
+              "BLHD",
+              "HLD",
+              pos_emb_type,
+              nullptr,  // k_cache_scale_inv - use for per head
+              nullptr,  // v_cache_scale_inv - use for per head
+              nullptr,  // intx_k_pc_scale
+              nullptr,  // intx_v_pc_scale
+              nullptr,  // intx_k_pc_zero
+              nullptr,  // intx_v_pc_zero
+              rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
+    } else {
+      ret = infer_ops::split_rope_cache_kv_encoder(
+          xpu_ctx->x_context(),
+          reinterpret_cast(qkv.data()),  // qkv
+          reinterpret_cast(
+              rotary_embs.data()),  // rotary_pos_emb
+          reinterpret_cast(
+              block_tables.data()),  // block_table
+          q_buf.data(),
+          k_buf.data(),
+          v_buf.data(),
+          const_cast(
+              reinterpret_cast(key_cache.data())),
+          const_cast(reinterpret_cast(
+              value_cache.data())),
+          decoder_seq_lod_vp,            // seq_lod
+          decoder_batch_map_vp,          // real_batch
+          decoder_context_len_cache_vp,  // start_tokens (prefix len)
+          param.batch_size,              // batch_size
+          1,                             // emb_batch_size
+          rope_max_seqlen,               // max_seqlen
+          param.head_num,
+          param.kv_head_num,
+          param.head_dim,
+          param.max_batch_size,
+          block_size,
+          max_block_per_seq,
+          "BLHD",
+          "HLD",
+          pos_emb_type,
+          nullptr,  // k_cache_scale_inv - use for per head
+          nullptr,  // v_cache_scale_inv - use for per head
+          quant_k_scale,  // intx_k_pc_scale
+          quant_v_scale,  // intx_v_pc_scale
+          quant_k_zp,     // intx_k_pc_zero
+          quant_v_zp,     // intx_v_pc_zero
+          rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
+    }

     float* fake_perhead_scale = nullptr;
     if (is_cache_int8 && has_zp)
     {
@@ -684,48 +766,89 @@ std::vector BlockAttnKernel(
     param.page_attn.block_table = &block_table_tensor;

     // rope + cache
-    int ret = infer_ops::split_rope_cache_kv_decoder(
-        xpu_ctx->x_context(),
-        reinterpret_cast(qkv.data()) +
-            total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
-        reinterpret_cast(
-            rotary_embs.data()),  // rotary_pos_emb
-        reinterpret_cast(
-            block_tables.data()),  // block_table
-        q_buf.data(),
-        nullptr,
-        nullptr,
-        const_cast(
-            reinterpret_cast(key_cache.data())),
-        const_cast(
-            reinterpret_cast(value_cache.data())),
-        vsl.usual_lod_vp,     // seq_lod
-        vsl.slot_mapping_vp,  // real_batch
-        param.batch_size,     // batch_size
-        1,                    // emb_batch_size = rotary_embs.dims()[1] = 1
-        rope_max_seqlen,      // max_seqlen
-        param.head_num,
-        param.kv_head_num,
-        param.head_dim,
-        param.max_batch_size,
-        block_size,
-        max_block_per_seq,
-        "BLHD",
-        "HLD",
-        pos_emb_type,
-        reinterpret_cast(quant_k_scale),  // k_cache_scale_inv
-        reinterpret_cast(quant_v_scale),  // v_cache_scale_inv
-        reinterpret_cast(quant_k_zp),     // k_cache_zp
-        reinterpret_cast(quant_v_zp),     // v_cache_zp
-        is_cache_int8,  // bool b_c8_pc
-        rope_3d,        // rope_3d
-        rope_3d_num_seqs);
-    PD_CHECK(ret == api::SUCCESS,
-             "infer_ops::split_rope_cache_kv_decoder failed.");
+    int ret = 0;
+    if (pos_emb_type == "NEOX") {
+      ret = infer_ops::split_neox_cache_kv_decoder(
+          xpu_ctx->x_context(),
+          reinterpret_cast(qkv.data()) +
+              total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
+          reinterpret_cast(
+              rotary_embs.data()),  // rotary_pos_emb
+          reinterpret_cast(
+              block_tables.data()),  // block_table
+          q_buf.data(),
+          nullptr,
+          nullptr,
+          const_cast(
+              reinterpret_cast(key_cache.data())),
+          const_cast(reinterpret_cast(
+              value_cache.data())),
+          vsl.usual_lod_vp,     // seq_lod
+          vsl.slot_mapping_vp,  // real_batch
+          param.batch_size,     // batch_size
+          1,                    // emb_batch_size = rotary_embs.dims()[1] = 1
+          rope_max_seqlen,      // max_seqlen
+          param.head_num,
+          param.kv_head_num,
+          param.head_dim,
+          param.max_batch_size,
+          block_size,
+          max_block_per_seq,
+          "BLHD",
+          "HLD",
+          pos_emb_type,
+          reinterpret_cast(quant_k_scale),  // k_cache_scale_inv
+          reinterpret_cast(quant_v_scale),  // v_cache_scale_inv
+          reinterpret_cast(quant_k_zp),     // k_cache_zp
+          reinterpret_cast(quant_v_zp),     // v_cache_zp
+          rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_decoder failed.");
+    } else {
+      ret = infer_ops::split_rope_cache_kv_decoder(
+          xpu_ctx->x_context(),
+          reinterpret_cast(qkv.data()) +
+              total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
+          reinterpret_cast(
+              rotary_embs.data()),  // rotary_pos_emb
+          reinterpret_cast(
+              block_tables.data()),  // block_table
+          q_buf.data(),
+          nullptr,
+          nullptr,
+          const_cast(
+              reinterpret_cast(key_cache.data())),
+          const_cast(reinterpret_cast(
+              value_cache.data())),
+          vsl.usual_lod_vp,     // seq_lod
+          vsl.slot_mapping_vp,  // real_batch
+          param.batch_size,     // batch_size
+          1,                    // emb_batch_size = rotary_embs.dims()[1] = 1
+          rope_max_seqlen,      // max_seqlen
+          param.head_num,
+          param.kv_head_num,
+          param.head_dim,
+          param.max_batch_size,
+          block_size,
+          max_block_per_seq,
+          "BLHD",
+          "HLD",
+          pos_emb_type,
+          reinterpret_cast(quant_k_scale),  // k_cache_scale_inv
+          reinterpret_cast(quant_v_scale),  // v_cache_scale_inv
+          reinterpret_cast(quant_k_zp),     // k_cache_zp
+          reinterpret_cast(quant_v_zp),     // v_cache_zp
+          is_cache_int8,  // bool b_c8_pc
+          rope_3d);
+      PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
+    }

     float* fake_perhead_scale = nullptr;
    if (is_cache_int8 && has_zp) {
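Context for the C++ hunks above: the encoder and decoder paths now dispatch on pos_emb_type, calling split_neox_cache_kv_encoder/decoder for "NEOX" and the existing split_rope_cache_kv_encoder/decoder otherwise. The sketch below is a minimal, framework-free Python reference of the two rotary layouts this distinction conventionally corresponds to (rotate-half across the two halves of the head for NEOX style, adjacent-pair rotation for the half-head-dim style). It is an illustration only, not the infer_ops kernels; the function and argument names are invented for the example, and it assumes cos/sin tables with head_dim // 2 unique values per position, whereas the real kernels may consume tables laid out to the full head_dim for the NEOX path (see the rope_head_dim choice in the Python hunk below).

import math


def apply_rope_reference(head, cos, sin, pos_emb_type):
    """Rotate one head vector in place; `head` has length d, `cos`/`sin` have d // 2."""
    d = len(head)
    half = d // 2
    if pos_emb_type == "NEOX":
        # NEOX / rotate-half style: pair element i with element i + d/2.
        for i in range(half):
            x0, x1 = head[i], head[i + half]
            head[i] = x0 * cos[i] - x1 * sin[i]
            head[i + half] = x1 * cos[i] + x0 * sin[i]
    else:
        # Half-head-dim style: rotate adjacent pairs (2i, 2i + 1).
        for i in range(half):
            x0, x1 = head[2 * i], head[2 * i + 1]
            head[2 * i] = x0 * cos[i] - x1 * sin[i]
            head[2 * i + 1] = x1 * cos[i] + x0 * sin[i]
    return head


if __name__ == "__main__":
    d, pos, base = 8, 3, 10000.0
    freqs = [pos / base ** (2 * i / d) for i in range(d // 2)]
    cos, sin = [math.cos(f) for f in freqs], [math.sin(f) for f in freqs]
    print(apply_rope_reference(list(range(d)), cos, sin, "NEOX"))
    print(apply_rope_reference(list(range(d)), cos, sin, "HALF_HEAD_DIM"))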
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index f88a11477..f9dad431e 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -847,12 +847,9 @@ class XPUModelRunner(ModelRunnerBase):
         head_dim = self.model_config.head_dim
         if "paddleocr" in self.model_config.model_type:  # neox style = True
             rope_head_dim = head_dim
+            self.share_inputs["pos_emb_type"] = "NEOX"
         else:  # neox style = False
             rope_head_dim = head_dim // 2
-
-        if rope_head_dim == self.model_config.head_dim:
-            self.share_inputs["pos_emb_type"] = "NORMAL"
-        else:
             self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"

         self.share_inputs["rope_emb"] = paddle.full(
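On the Python side, pos_emb_type is now set directly where the rope table width is chosen, instead of being inferred afterwards from rope_head_dim (and "NORMAL" is no longer emitted). A small standalone sketch of that selection, mirroring the hunk above; the helper name and the example model-type strings are hypothetical, and the only assumption carried over from the diff is that model types containing "paddleocr" use NEOX-style rope:

def select_rope_config(model_type: str, head_dim: int) -> tuple[int, str]:
    """Return (rope_head_dim, pos_emb_type) the way the hunk above assigns them."""
    if "paddleocr" in model_type:  # neox style = True
        return head_dim, "NEOX"
    # neox style = False: rope tables only cover half the head dim
    return head_dim // 2, "HALF_HEAD_DIM"


# Illustrative model-type strings, not taken from the diff.
assert select_rope_config("paddleocr-vl", 128) == (128, "NEOX")
assert select_rope_config("ernie4_5", 128) == (64, "HALF_HEAD_DIM")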