[XPU] bind block_attn kernel with pybind (#4499)

This commit is contained in:
Lucas
2025-10-21 10:58:52 +08:00
committed by GitHub
parent d85ef5352a
commit 99564349a7

View File

@@ -43,6 +43,36 @@ void GetOutputKVSignal(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag);
// Block attention op for XPU devices. Consumes the fused QKV tensor plus
// paged KV-cache layout metadata; the *_cpu tensors are host-side lookup
// tables (sequence LODs, batch maps, context lengths). Optional
// scale/zero-point tensors presumably enable a quantized KV cache, and
// shift/smooth look like smooth-quant parameters — semantics inferred from
// names; confirm against the XPU kernel implementation.
// Returns the attention output tensor(s); defined in the XPU ops library.
// Fix: unified reference-qualifier spacing ("const std::string&" instead of
// "const std::string &") and "= false" spacing to match every other
// parameter in this declaration. Token-identical signature, callers unaffected.
std::vector<paddle::Tensor> BlockAttn(
    const paddle::Tensor& qkv,
    const paddle::Tensor& key_cache,
    const paddle::Tensor& value_cache,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& rotary_embs,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& prefix_block_tables,
    const paddle::Tensor& len_info_cpu,
    const paddle::Tensor& encoder_seq_lod_cpu,
    const paddle::Tensor& decoder_seq_lod_cpu,
    const paddle::Tensor& encoder_kv_lod_cpu,
    const paddle::Tensor& encoder_batch_map_cpu,
    const paddle::Tensor& decoder_context_len_cpu,
    const paddle::Tensor& decoder_context_len_cache_cpu,
    const paddle::Tensor& decoder_batch_map_cpu,
    const paddle::Tensor& prefix_len_cpu,
    const paddle::optional<paddle::Tensor>& k_scales,
    const paddle::optional<paddle::Tensor>& v_scales,
    const paddle::optional<paddle::Tensor>& k_scales_inv,
    const paddle::optional<paddle::Tensor>& v_scales_inv,
    const paddle::optional<paddle::Tensor>& k_zeros,
    const paddle::optional<paddle::Tensor>& v_zeros,
    const paddle::optional<paddle::Tensor>& shift,
    const paddle::optional<paddle::Tensor>& smooth,
    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
    const std::string& pos_emb_type = "NORMAL",
    bool rope_3d = false);
std::vector<paddle::Tensor> MoERedundantTopKSelect(
const paddle::Tensor& gating_logits,
const paddle::Tensor& expert_id_to_ep_rank_array,
@@ -327,6 +357,37 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_decoder);
PYBIND11_MODULE(fastdeploy_ops, m) {
// Expose the XPU BlockAttn kernel to Python as fastdeploy_ops.block_attn.
// The py::arg list mirrors the C++ declaration order exactly, so Python
// callers can pass every argument by keyword. pybind11 accepts the
// docstring ("block attention in XPU") anywhere among the extra arguments,
// including after the py::arg entries as done here.
m.def("block_attn",
&BlockAttn,
py::arg("qkv"),
py::arg("key_cache"),
py::arg("value_cache"),
py::arg("cum_offsets"),
py::arg("rotary_embs"),
py::arg("block_tables"),
py::arg("prefix_block_tables"),
py::arg("len_info_cpu"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_kv_lod_cpu"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_context_len_cpu"),
py::arg("decoder_context_len_cache_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("prefix_len_cpu"),
// NOTE(review): the optional tensors below carry no pybind defaults, so
// Python callers must pass them explicitly (None is accepted for a
// paddle::optional argument) — confirm this matches the Python wrapper.
py::arg("k_scales"),
py::arg("v_scales"),
py::arg("k_scales_inv"),
py::arg("v_scales_inv"),
py::arg("k_zeros"),
py::arg("v_zeros"),
py::arg("shift"),
py::arg("smooth"),
py::arg("kv_signal_data_cpu"),
py::arg("cachekv_signal_thread_cpu"),
// Defaults mirror the C++ declaration's defaults for these two parameters.
py::arg("pos_emb_type") = "NORMAL",
py::arg("rope_3d") = false,
"block attention in XPU");
m.def("cuda_host_alloc",
&custom_xpu_host_alloc,
"Allocate pinned memory",