mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] support kernel for mtp(base) (#4748)
* [XPU] support kernel for mtp(base) * [XPU] support kernel for mtp(base) * format * format * format * fix gather next token * fix step && add test * fix * mv pre/post process * add adjust batch / gather next token for mtp * fix code style * fix mtp kernel name * fix mtp kernel test * mv xpu pre/post process * mv xpu pre/post process
This commit is contained in:
@@ -37,13 +37,14 @@ std::vector<paddle::Tensor> AdjustBatch(
|
||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& decoder_seq_lod,
|
||||
const paddle::Tensor& encoder_batch_idx,
|
||||
const paddle::Tensor& decoder_batch_idx,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& decoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_idx_cpu,
|
||||
const paddle::Tensor& decoder_batch_idx_cpu,
|
||||
const paddle::Tensor& enc_batch_tensor,
|
||||
const paddle::Tensor& dec_batch_tensor,
|
||||
const paddle::Tensor& len_info_cpu,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
int max_input_length);
|
||||
|
||||
@@ -264,7 +265,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& topp,
|
||||
int max_seq_len,
|
||||
int verify_window,
|
||||
bool enable_topp);
|
||||
bool enable_topp,
|
||||
bool benchmark_mode,
|
||||
bool accept_all_drafts);
|
||||
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
@@ -285,21 +288,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& seq_lens_encoder_record,
|
||||
const paddle::Tensor& seq_lens_decoder_record,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& base_model_is_block_step,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill);
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -324,18 +329,19 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
||||
const paddle::Tensor& step_idx);
|
||||
|
||||
std::vector<paddle::Tensor> GatherNextToken(
|
||||
const paddle::Tensor& tmp_out, // [token_num, dim_embed]
|
||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& decoder_seq_lod,
|
||||
const paddle::Tensor& encoder_batch_map,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& decoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_map_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& enc_batch_tensor,
|
||||
const paddle::Tensor& dec_batch_tensor,
|
||||
const paddle::Tensor& len_info_cpu,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
int max_input_length);
|
||||
int max_bsz);
|
||||
|
||||
std::vector<paddle::Tensor> GetImgBoundaries(
|
||||
const paddle::Tensor& task_input_ids,
|
||||
@@ -436,6 +442,34 @@ void MTPStepPaddle(
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
// Declaration of the SpeculateStepPaddle op. Only the prototype is visible
// here; the definition lives elsewhere. The function is exported to Python as
// `speculate_step_paddle` via m.def(...) in PYBIND11_MODULE(fastdeploy_ops, m),
// where the py::arg names match these parameter names one-to-one and in the
// same order.
//
// All tensors are passed by const reference and the function returns void.
// NOTE(review): presumably results are written in place into some of the
// tensor arguments -- confirm against the definition.
// NOTE(review): the trailing scalars (block_size, encoder_decoder_block_num,
// max_draft_tokens) look like KV-cache block size, reserved block count and
// speculative draft-token limit -- confirm against the definition.
void SpeculateStepPaddle(
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& ori_seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor& encoder_block_lens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& step_block_list,
|
||||
const paddle::Tensor& step_lens,
|
||||
const paddle::Tensor& recover_block_list,
|
||||
const paddle::Tensor& recover_lens,
|
||||
const paddle::Tensor& need_block_list,
|
||||
const paddle::Tensor& need_block_len,
|
||||
const paddle::Tensor& used_list_len,
|
||||
const paddle::Tensor& free_list,
|
||||
const paddle::Tensor& free_list_len,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& first_token_ids,
|
||||
const paddle::Tensor& accept_num,
|
||||
const int block_size,
|
||||
const int encoder_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
@@ -542,13 +576,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("x"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("encoder_seq_lod"),
|
||||
py::arg("decoder_seq_lod"),
|
||||
py::arg("encoder_batch_idx"),
|
||||
py::arg("decoder_batch_idx"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("decoder_seq_lod_cpu"),
|
||||
py::arg("encoder_batch_idx_cpu"),
|
||||
py::arg("decoder_batch_idx_cpu"),
|
||||
py::arg("enc_batch_tensor"),
|
||||
py::arg("dec_batch_tensor"),
|
||||
py::arg("len_info_cpu"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("max_input_length"),
|
||||
"adjust batch in XPU");
|
||||
@@ -620,21 +655,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("step_idx"),
|
||||
py::arg("seq_lens_encoder_record"),
|
||||
py::arg("seq_lens_decoder_record"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("batch_drop"),
|
||||
py::arg("pre_ids"),
|
||||
py::arg("accept_tokens"),
|
||||
py::arg("accept_num"),
|
||||
py::arg("base_model_seq_lens_this_time"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("base_model_seq_lens_decoder"),
|
||||
py::arg("base_model_step_idx"),
|
||||
py::arg("base_model_stop_flags"),
|
||||
py::arg("base_model_is_block_step"),
|
||||
py::arg("base_model_draft_tokens"),
|
||||
py::arg("max_draft_token"),
|
||||
py::arg("num_model_step"),
|
||||
py::arg("truncate_first_token"),
|
||||
py::arg("splitwise_prefill"),
|
||||
py::arg("kvcache_scheduler_v1"),
|
||||
"Preprocess data for draft model in speculative decoding");
|
||||
|
||||
m.def("draft_model_postprocess",
|
||||
@@ -727,18 +764,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("gather_next_token",
|
||||
&GatherNextToken,
|
||||
py::arg("tmp_out"),
|
||||
py::arg("x"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("encoder_seq_lod"),
|
||||
py::arg("decoder_seq_lod"),
|
||||
py::arg("encoder_batch_map"),
|
||||
py::arg("decoder_batch_map"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("decoder_seq_lod_cpu"),
|
||||
py::arg("encoder_batch_map_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("enc_batch_tensor"),
|
||||
py::arg("dec_batch_tensor"),
|
||||
py::arg("len_info_cpu"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("max_input_length"),
|
||||
py::arg("max_bsz"),
|
||||
"Gather next token for XPU");
|
||||
|
||||
m.def("get_img_boundaries",
|
||||
@@ -983,6 +1021,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("max_seq_len"),
|
||||
py::arg("verify_window"),
|
||||
py::arg("enable_topp"),
|
||||
py::arg("benchmark_mode"),
|
||||
py::arg("accept_all_drafts"),
|
||||
"Perform speculative verification for decoding");
|
||||
|
||||
m.def("speculate_clear_accept_nums",
|
||||
@@ -1104,6 +1144,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("encoder_decoder_block_num"),
|
||||
"Step paddle function");
|
||||
|
||||
m.def("speculate_step_paddle",
|
||||
&SpeculateStepPaddle,
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("ori_seq_lens_encoder"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("encoder_block_lens"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("step_block_list"),
|
||||
py::arg("step_lens"),
|
||||
py::arg("recover_block_list"),
|
||||
py::arg("recover_lens"),
|
||||
py::arg("need_block_list"),
|
||||
py::arg("need_block_len"),
|
||||
py::arg("used_list_len"),
|
||||
py::arg("free_list"),
|
||||
py::arg("free_list_len"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("pre_ids"),
|
||||
py::arg("step_idx"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("first_token_ids"),
|
||||
py::arg("accept_num"),
|
||||
py::arg("block_size"),
|
||||
py::arg("encoder_decoder_block_num"),
|
||||
py::arg("max_draft_tokens"),
|
||||
"Step paddle function");
|
||||
|
||||
m.def("text_image_gather_scatter",
|
||||
&TextImageGatherScatter,
|
||||
py::arg("input"),
|
||||
|
||||
Reference in New Issue
Block a user