Declare the XPU BlockAttn custom op and expose it to Python as `block_attn`.

The binding names every argument via py::arg and gives the last two
defaults (pos_emb_type = "NORMAL", rope_3d = false), matching the defaults
on the C++ declaration.

NOTE(review): this patch was rebuilt from a copy whose line breaks and
template argument lists had been lost. The `<paddle::Tensor>` arguments and
the leading whitespace of each line are reconstructed -- confirm against
the target file before applying.

diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index e8eea990b..832bdbf69 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -43,6 +43,36 @@ void GetOutputKVSignal(const paddle::Tensor &x,
                        int64_t rank_id,
                        bool wait_flag);
 
+std::vector<paddle::Tensor> BlockAttn(
+    const paddle::Tensor& qkv,
+    const paddle::Tensor& key_cache,
+    const paddle::Tensor& value_cache,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& rotary_embs,
+    const paddle::Tensor& block_tables,
+    const paddle::Tensor& prefix_block_tables,
+    const paddle::Tensor& len_info_cpu,
+    const paddle::Tensor& encoder_seq_lod_cpu,
+    const paddle::Tensor& decoder_seq_lod_cpu,
+    const paddle::Tensor& encoder_kv_lod_cpu,
+    const paddle::Tensor& encoder_batch_map_cpu,
+    const paddle::Tensor& decoder_context_len_cpu,
+    const paddle::Tensor& decoder_context_len_cache_cpu,
+    const paddle::Tensor& decoder_batch_map_cpu,
+    const paddle::Tensor& prefix_len_cpu,
+    const paddle::optional<paddle::Tensor>& k_scales,
+    const paddle::optional<paddle::Tensor>& v_scales,
+    const paddle::optional<paddle::Tensor>& k_scales_inv,
+    const paddle::optional<paddle::Tensor>& v_scales_inv,
+    const paddle::optional<paddle::Tensor>& k_zeros,
+    const paddle::optional<paddle::Tensor>& v_zeros,
+    const paddle::optional<paddle::Tensor>& shift,
+    const paddle::optional<paddle::Tensor>& smooth,
+    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
+    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
+    const std::string &pos_emb_type="NORMAL",
+    bool rope_3d=false);
+
 std::vector<paddle::Tensor> MoERedundantTopKSelect(
     const paddle::Tensor& gating_logits,
     const paddle::Tensor& expert_id_to_ep_rank_array,
@@ -327,6 +357,37 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
     const paddle::Tensor& seq_lens_decoder);
 
 PYBIND11_MODULE(fastdeploy_ops, m) {
+  m.def("block_attn",
+        &BlockAttn,
+        py::arg("qkv"),
+        py::arg("key_cache"),
+        py::arg("value_cache"),
+        py::arg("cum_offsets"),
+        py::arg("rotary_embs"),
+        py::arg("block_tables"),
+        py::arg("prefix_block_tables"),
+        py::arg("len_info_cpu"),
+        py::arg("encoder_seq_lod_cpu"),
+        py::arg("decoder_seq_lod_cpu"),
+        py::arg("encoder_kv_lod_cpu"),
+        py::arg("encoder_batch_map_cpu"),
+        py::arg("decoder_context_len_cpu"),
+        py::arg("decoder_context_len_cache_cpu"),
+        py::arg("decoder_batch_map_cpu"),
+        py::arg("prefix_len_cpu"),
+        py::arg("k_scales"),
+        py::arg("v_scales"),
+        py::arg("k_scales_inv"),
+        py::arg("v_scales_inv"),
+        py::arg("k_zeros"),
+        py::arg("v_zeros"),
+        py::arg("shift"),
+        py::arg("smooth"),
+        py::arg("kv_signal_data_cpu"),
+        py::arg("cachekv_signal_thread_cpu"),
+        py::arg("pos_emb_type") = "NORMAL",
+        py::arg("rope_3d") = false,
+        "block attention in XPU");
   m.def("cuda_host_alloc",
         &custom_xpu_host_alloc,
         "Allocate pinned memory",