mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] add speculate_get_logits (#5497)
* [XPU] add speculate_step_system_cache
* [XPU] add speculate_step_system_cache
* [XPU] add speculate_get_logits
* delete context
* add ptr check

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
|
||||
const int encoder_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
// Forward declaration of the speculate-get-logits op.
//
// It is registered with pybind11 under the Python name
// `speculate_get_logits`; the keyword arguments of that binding carry the
// same names, in the same order, as the parameters below.
//
// Every tensor is passed by const reference and the function returns void.
// NOTE(review): presumably the results are written in place into one or more
// of these tensors (e.g. `draft_logits`) -- confirm against the
// implementation, which is not visible here.
void SpeculateGetLogits(
    const paddle::Tensor& draft_logits,
    const paddle::Tensor& next_token_num,
    const paddle::Tensor& batch_token_num,
    const paddle::Tensor& cu_next_token_offset,
    const paddle::Tensor& cu_batch_token_offset,
    const paddle::Tensor& logits,
    const paddle::Tensor& first_token_logits,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder);
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("max_draft_tokens"),
|
||||
"Step paddle function");
|
||||
|
||||
// Expose SpeculateGetLogits to Python as `speculate_get_logits`. The py::arg
// names below must stay in the same order as the C++ parameter list so that
// keyword calls map onto the right tensors.
m.def("speculate_get_logits", &SpeculateGetLogits,
    py::arg("draft_logits"),
    py::arg("next_token_num"),
    py::arg("batch_token_num"),
    py::arg("cu_next_token_offset"),
    py::arg("cu_batch_token_offset"),
    py::arg("logits"),
    py::arg("first_token_logits"),
    py::arg("seq_lens_this_time"),
    py::arg("seq_lens_encoder"),
    "speculate get logits function");
|
||||
|
||||
m.def("text_image_gather_scatter",
|
||||
&TextImageGatherScatter,
|
||||
py::arg("input"),
|
||||
|
||||
Reference in New Issue
Block a user