Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00.
[XPU] bind block_attn kernel with pybind (#4499)
This commit is contained in:
@@ -43,6 +43,36 @@ void GetOutputKVSignal(const paddle::Tensor &x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& prefix_block_tables,
|
||||
const paddle::Tensor& len_info_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& decoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_kv_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_map_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& v_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& k_zeros,
|
||||
const paddle::optional<paddle::Tensor>& v_zeros,
|
||||
const paddle::optional<paddle::Tensor>& shift,
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false);
|
||||
|
||||
std::vector<paddle::Tensor> MoERedundantTopKSelect(
|
||||
const paddle::Tensor& gating_logits,
|
||||
const paddle::Tensor& expert_id_to_ep_rank_array,
|
||||
@@ -327,6 +357,37 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("block_attn",
|
||||
&BlockAttn,
|
||||
py::arg("qkv"),
|
||||
py::arg("key_cache"),
|
||||
py::arg("value_cache"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("rotary_embs"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("prefix_block_tables"),
|
||||
py::arg("len_info_cpu"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("decoder_seq_lod_cpu"),
|
||||
py::arg("encoder_kv_lod_cpu"),
|
||||
py::arg("encoder_batch_map_cpu"),
|
||||
py::arg("decoder_context_len_cpu"),
|
||||
py::arg("decoder_context_len_cache_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("prefix_len_cpu"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
py::arg("v_scales_inv"),
|
||||
py::arg("k_zeros"),
|
||||
py::arg("v_zeros"),
|
||||
py::arg("shift"),
|
||||
py::arg("smooth"),
|
||||
py::arg("kv_signal_data_cpu"),
|
||||
py::arg("cachekv_signal_thread_cpu"),
|
||||
py::arg("pos_emb_type") = "NORMAL",
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention in XPU");
|
||||
m.def("cuda_host_alloc",
|
||||
&custom_xpu_host_alloc,
|
||||
"Allocate pinned memory",
|
||||
|
||||
Reference in New Issue
Block a user