[XPU] add speculate_get_logits (#5497)

* [XPU] add speculate_step_system_cache

* [XPU] add speculate_step_system_cache

* [XPU] add speculate_get_logits

* delete context

* add ptr check

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
RuohengMa
2025-12-12 15:38:30 +08:00
committed by GitHub
parent 888c4b992d
commit 12c76f8137
6 changed files with 603 additions and 0 deletions

View File

@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
// Forward declaration of the speculate-get-logits op; the definition lives in
// another translation unit. It is registered with Python as
// `speculate_get_logits` in the PYBIND11_MODULE(fastdeploy_ops, m) block of
// this file, using keyword names identical to the parameter names below.
//
// All nine arguments are paddle::Tensor objects passed by const reference and
// the function returns void.
// NOTE(review): since nothing is returned, results are presumably written in
// place into one or more of these tensors (e.g. draft_logits and the
// *_num / cu_*_offset tensors) -- confirm against the implementation before
// relying on which arguments are inputs and which are outputs.
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Step paddle function");
// Expose SpeculateGetLogits to Python as fastdeploy_ops.speculate_get_logits.
// The py::arg names are listed in the same order as the C++ parameters, so the
// op can be called with either positional or keyword arguments; the trailing
// string literal is the Python-side docstring.
m.def("speculate_get_logits",
&SpeculateGetLogits,
py::arg("draft_logits"),
py::arg("next_token_num"),
py::arg("batch_token_num"),
py::arg("cu_next_token_offset"),
py::arg("cu_batch_token_offset"),
py::arg("logits"),
py::arg("first_token_logits"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
"speculate get logits function");
m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),