[Feature][MTP]Support new mtp (#3656)

* update multi-draft-token strategy

* fix format

* support hybrid mtp with ngram speculative decoding method
This commit is contained in:
freeliuzc
2025-08-27 19:38:26 +08:00
committed by GitHub
parent 62659a7a73
commit c753f1fc9e
20 changed files with 501 additions and 579 deletions

View File

@@ -604,7 +604,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -649,6 +649,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -665,8 +679,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
@@ -1083,7 +1099,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1093,6 +1109,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");