[XPU] support kernel for mtp(base) (#4748)

* [XPU] support kernel for mtp(base)

* [XPU] support kernel for mtp(base)

* format

* format

* format

* fix gather next token

* fix step && add test

* fix

* mv pre/post process

* add adjust batch / gather next token for mtp

* fix code style

* fix mtp kernel name

* fix mtp kernel test

* mv xpu pre/post process

* mv xpu pre/post process
This commit is contained in:
cmcamdy
2025-11-27 15:05:44 +08:00
committed by GitHub
parent e63d715fc3
commit 5a67a6d960
32 changed files with 3618 additions and 972 deletions

View File

@@ -37,13 +37,14 @@ std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_idx,
const paddle::Tensor& decoder_batch_idx,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_idx_cpu,
const paddle::Tensor& decoder_batch_idx_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
@@ -264,7 +265,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens,
const paddle::Tensor& topp,
int max_seq_len,
int verify_window,
bool enable_topp);
bool enable_topp,
bool benchmark_mode,
bool accept_all_drafts);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
@@ -285,21 +288,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill);
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -324,18 +329,19 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& step_idx);
std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& tmp_out, // [token_num, dim_embed]
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map,
const paddle::Tensor& decoder_batch_map,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
int max_bsz);
std::vector<paddle::Tensor> GetImgBoundaries(
const paddle::Tensor& task_input_ids,
@@ -436,6 +442,34 @@ void MTPStepPaddle(
const int block_size,
const int max_draft_tokens);
void SpeculateStepPaddle(
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& ori_seq_lens_encoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor& encoder_block_lens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& step_block_list,
const paddle::Tensor& step_lens,
const paddle::Tensor& recover_block_list,
const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& next_tokens,
const paddle::Tensor& first_token_ids,
const paddle::Tensor& accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
@@ -542,13 +576,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_batch_idx"),
py::arg("decoder_batch_idx"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_batch_idx_cpu"),
py::arg("decoder_batch_idx_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
"adjust batch in XPU");
@@ -620,21 +655,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
py::arg("seq_lens_encoder_record"),
py::arg("seq_lens_decoder_record"),
py::arg("not_need_stop"),
py::arg("is_block_step"),
py::arg("batch_drop"),
py::arg("pre_ids"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_seq_lens_decoder"),
py::arg("base_model_step_idx"),
py::arg("base_model_stop_flags"),
py::arg("base_model_is_block_step"),
py::arg("base_model_draft_tokens"),
py::arg("max_draft_token"),
py::arg("num_model_step"),
py::arg("truncate_first_token"),
py::arg("splitwise_prefill"),
py::arg("kvcache_scheduler_v1"),
"Preprocess data for draft model in speculative decoding");
m.def("draft_model_postprocess",
@@ -727,18 +764,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("gather_next_token",
&GatherNextToken,
py::arg("tmp_out"),
py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_batch_map"),
py::arg("decoder_batch_map"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
py::arg("max_bsz"),
"Gather next token for XPU");
m.def("get_img_boundaries",
@@ -983,6 +1021,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("enable_topp"),
py::arg("benchmark_mode"),
py::arg("accept_all_drafts"),
"Perform speculative verification for decoding");
m.def("speculate_clear_accept_nums",
@@ -1104,6 +1144,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_decoder_block_num"),
"Step paddle function");
m.def("speculate_step_paddle",
&SpeculateStepPaddle,
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("ori_seq_lens_encoder"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"),
py::arg("encoder_block_lens"),
py::arg("is_block_step"),
py::arg("step_block_list"),
py::arg("step_lens"),
py::arg("recover_block_list"),
py::arg("recover_lens"),
py::arg("need_block_list"),
py::arg("need_block_len"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("input_ids"),
py::arg("pre_ids"),
py::arg("step_idx"),
py::arg("next_tokens"),
py::arg("first_token_ids"),
py::arg("accept_num"),
py::arg("block_size"),
py::arg("encoder_decoder_block_num"),
py::arg("max_draft_tokens"),
"Step paddle function");
m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),