[Feature][MTP]support new speculative decoding method named hybrid mtp with ngram (#3610)

This commit is contained in:
freeliuzc
2025-08-26 14:29:22 +08:00
committed by GitHub
parent 0a0d2959b9
commit 52eda7fdb3
20 changed files with 454 additions and 571 deletions

View File

@@ -614,7 +614,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -659,6 +659,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -675,6 +689,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -1121,7 +1136,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1131,6 +1146,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");