mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
[MTP] optimize mtp infer speed (#2840)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled.
@@ -518,6 +518,213 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
 
 void free_shared_buffer(int64_t buffer);
 
+// speculative decoding Kernel
+std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& draft_tokens,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& token_num,
+    const paddle::Tensor& seq_len,
+    const paddle::Tensor& seq_lens_encoder);
+
+std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder);
+
+std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
+    const paddle::Tensor& output_cum_offsets_tmp,
+    const paddle::Tensor& out_token_num,
+    const paddle::Tensor& seq_lens_output,
+    const int max_seq_len);
+
+void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
+                                 const paddle::Tensor &logits,
+                                 const paddle::Tensor &penalty_scores,
+                                 const paddle::Tensor &frequency_scores,
+                                 const paddle::Tensor &presence_scores,
+                                 const paddle::Tensor &temperatures,
+                                 const paddle::Tensor &bad_tokens,
+                                 const paddle::Tensor &cur_len,
+                                 const paddle::Tensor &min_len,
+                                 const paddle::Tensor &eos_token_id,
+                                 const paddle::Tensor &seq_lens_this_time,
+                                 const paddle::Tensor &output_padding_offset,
+                                 const paddle::Tensor &output_cum_offsets,
+                                 const int max_seq_len);
+
+void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
+                               const paddle::Tensor &accept_num,
+                               const paddle::Tensor &pre_ids,
+                               const paddle::Tensor &step_idx,
+                               const paddle::Tensor &stop_flags,
+                               const paddle::Tensor &seq_lens,
+                               const paddle::Tensor &stop_seqs,
+                               const paddle::Tensor &stop_seqs_len,
+                               const paddle::Tensor &end_ids);
+
+void SpeculateVerify(
+    const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
+    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
+    const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &output_cum_offsets,
+    const paddle::Tensor &actual_candidate_len,
+    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
+
+void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
+                       const paddle::Tensor &seq_lens_decoder,
+                       const paddle::Tensor &not_need_stop,
+                       const paddle::Tensor &draft_tokens,
+                       const paddle::Tensor &actual_draft_token_nums,
+                       const paddle::Tensor &accept_tokens,
+                       const paddle::Tensor &accept_num,
+                       const paddle::Tensor &stop_flags,
+                       const paddle::Tensor &seq_lens_this_time,
+                       const paddle::Tensor &is_block_step,
+                       const paddle::Tensor &stop_nums);
+
+void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
+                                    const paddle::Tensor &accept_tokens,
+                                    const paddle::Tensor &accept_num,
+                                    const paddle::Tensor &stop_flags,
+                                    const paddle::Tensor &seq_lens_this_time,
+                                    const paddle::Tensor &seq_lens_encoder,
+                                    const paddle::Tensor &seq_lens_decoder,
+                                    const paddle::Tensor &step_idx);
+
+void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
+                                      const paddle::Tensor& accept_num,
+                                      const paddle::Tensor& not_need_stop,
+                                      int64_t rank_id,
+                                      bool save_each_rank);
+
+void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
+                              const paddle::Tensor& seq_lens_decoder);
+
+void NgramMatch(const paddle::Tensor &input_ids,
+                const paddle::Tensor &input_ids_len,
+                const paddle::Tensor &pre_ids,
+                const paddle::Tensor &step_idx,
+                const paddle::Tensor &draft_token_num,
+                const paddle::Tensor &draft_tokens,
+                const paddle::Tensor &seq_lens_this_time,
+                const paddle::Tensor &seq_lens_encoder,
+                const paddle::Tensor &seq_lens_decoder,
+                const paddle::Tensor &max_dec_len,
+                const int max_ngram_size,
+                const int max_draft_tokens);
+
+// MTP
+void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
+                           const paddle::Tensor& base_model_seq_lens_this_time,
+                           const paddle::Tensor& base_model_seq_lens_encoder,
+                           const paddle::Tensor& base_model_stop_flags);
+
+void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
+                          const paddle::Tensor& input_ids,
+                          const paddle::Tensor& stop_flags,
+                          const paddle::Tensor& seq_lens_this_time,
+                          const paddle::Tensor& seq_lens_encoder,
+                          const paddle::Tensor& seq_lens_decoder,
+                          const paddle::Tensor& step_idx,
+                          const paddle::Tensor& not_need_stop,
+                          const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& accept_tokens,
+                          const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_encoder,
+                          const paddle::Tensor& base_model_seq_lens_decoder,
+                          const paddle::Tensor& base_model_step_idx,
+                          const paddle::Tensor& base_model_stop_flags,
+                          const paddle::Tensor& base_model_is_block_step,
+                          const paddle::Tensor& base_model_draft_tokens,
+                          const int max_draft_token,
+                          const bool truncate_first_token,
+                          const bool splitwise_prefill);
+
+void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
+                      const paddle::Tensor& draft_tokens,
+                      const paddle::Tensor& pre_ids,
+                      const paddle::Tensor& seq_lens_this_time,
+                      const paddle::Tensor& seq_lens_encoder,
+                      const paddle::Tensor& seq_lens_decoder,
+                      const paddle::Tensor& step_idx,
+                      const paddle::Tensor& output_cum_offsets,
+                      const paddle::Tensor& stop_flags,
+                      const paddle::Tensor& not_need_stop,
+                      const paddle::Tensor& max_dec_len,
+                      const paddle::Tensor& end_ids,
+                      const paddle::Tensor& base_model_draft_tokens,
+                      const int max_seq_len,
+                      const int substep);
+
+std::vector<paddle::Tensor> EagleGetHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& accept_nums,
+    const paddle::Tensor& base_model_seq_lens_this_time,
+    const paddle::Tensor& base_model_seq_lens_encoder,
+    const int actual_draft_token_num);
+
+void MTPStepPaddle(
+    const paddle::Tensor &base_model_stop_flags,
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &batch_drop,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const int block_size,
+    const int max_draft_tokens);
+
+void SpeculateStepPaddle(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
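The declarations above only fix signatures. As a conceptual aid (not code from this PR), here is a minimal numpy model of what SpeculateVerify's accept outputs mean under greedy verification: accept the longest prefix of draft tokens the target model agrees with, plus one bonus token from the target itself. The helper name and toy shapes are invented; the real kernel additionally handles top-p sampling (enable_topp, topp), a verify_window, and packed batches via output_cum_offsets.

import numpy as np

def greedy_verify(draft_tokens, target_logits):
    """Accept the longest prefix of draft tokens the target model agrees with,
    then always emit one bonus token chosen by the target itself."""
    target_choice = target_logits.argmax(axis=-1)  # greedy target decisions
    n_accept = 0
    for tok in draft_tokens:
        if tok != target_choice[n_accept]:
            break
        n_accept += 1
    # accept_num would be n_accept + 1: matched draft tokens plus the bonus token
    accepted = list(draft_tokens[:n_accept]) + [int(target_choice[n_accept])]
    return np.array(accepted)

draft = np.array([5, 9, 2])
logits = np.random.randn(4, 32)      # one extra row for the bonus position
logits[0, 5] += 100.0                # target agrees with draft[0]
logits[1, 9] += 100.0                # target agrees with draft[1]
print(greedy_verify(draft, logits))  # -> [5 9 t] for some target token t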
@@ -687,9 +894,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  * append_attn/get_block_shape_and_split_kv_block.cu
  * get_block_shape_and_split_kv_block
  */
-  // m.def("f_get_block_shape_and_split_kv_block",
-  //       &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
-  //       function");
+  m.def("get_block_shape_and_split_kv_block",
+        &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");
 
 /**
  * get_padding_offset.cu
@@ -747,7 +953,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         "text_image_gather_scatter function");
-
   m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
 
   m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
 
   m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
@@ -786,7 +991,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
         "dynamic_per_token_scaled_fp8_quant function",
         py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
-
   m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
 
   m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
@@ -802,11 +1006,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
         py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
         py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
-
   m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
         py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
-
   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
 #endif
@@ -830,4 +1032,39 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
 
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
+
+  // speculative decoding Kernel
+  m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function");
+
+  m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function");
+
+  m.def("speculate_get_output_padding_offset", &SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function");
+
+  m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function");
+
+  m.def("speculate_set_stop_value_multi_seqs", &SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function");
+
+  m.def("speculate_verify", &SpeculateVerify, "speculate_verify function");
+
+  m.def("speculate_update_v3", &SpeculateUpdateV3, "speculate_update_v3 function");
+
+  m.def("speculate_set_value_by_flags_and_idx", &SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
+
+  m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function");
+
+  m.def("speculate_clear_accept_nums", &SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
+
+  m.def("ngram_match", &NgramMatch, "ngram_match function");
+
+  m.def("draft_model_postprocess", &DraftModelPostprocess, "draft_model_postprocess function");
+
+  m.def("draft_model_preprocess", &DraftModelPreprocess, "draft_model_preprocess function");
+
+  m.def("draft_model_update", &DraftModelUpdate, "draft_model_update function");
+
+  m.def("eagle_get_hidden_states", &EagleGetHiddenStates, "eagle_get_hidden_states function");
+
+  m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");
+
+  m.def("speculate_step_paddle", &SpeculateStepPaddle, "speculate_step_paddle function");
 }
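After this block, the new kernels are plain attributes of the fastdeploy_ops extension. A quick smoke test, assuming a CUDA build of FastDeploy that compiled these ops (the import path is the one this PR itself uses in apply_speculative_penalty_multi_scores):

from fastdeploy.model_executor.ops.gpu import (
    speculate_verify,
    draft_model_preprocess,
    mtp_step_paddle,
)

# pybind11 functions carry their registration docstring; the last line is
# the string passed to m.def above, e.g. "speculate_verify function".
for op in (speculate_verify, draft_model_preprocess, mtp_step_paddle):
    print(op.__name__, "-", op.__doc__.splitlines()[-1])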
@@ -246,7 +246,7 @@ void token_penalty_multi_scores_kernel(
       max_seq_len);
 }
 
-void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
+void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                              const paddle::Tensor &logits,
                              const paddle::Tensor &penalty_scores,
                              const paddle::Tensor &frequency_scores,
@@ -338,4 +338,4 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
     .Outputs({"logits_out"})
     .Attrs({"max_seq_len: int"})
     .SetInplaceMap({{"logits", "logits_out"}})
-    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
+    .SetKernelFn(PD_KERNEL(SpecTokenPenaltyMultiScores));
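The Spec prefix presumably keeps this C++ function from colliding with the dense sampling path's TokenPenaltyMultiScores when both translation units are linked into fastdeploy_ops; the registered op name is unchanged, so Python call sites keep working. Hypothetically, both paths can then be imported side by side (get_token_penalty_multi_scores is the assumed dense op name, not shown in this diff):

from fastdeploy.model_executor.ops.gpu import (
    get_token_penalty_multi_scores,            # dense sampling path (assumed name)
    speculate_get_token_penalty_multi_scores,  # speculative path, kernel renamed above
)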
@@ -266,18 +266,6 @@ void SpeculateVerify(
             seed++;
             offset++;
 
-        auto err = cudaDeviceSynchronize();
-        if (err != 0) {
-            printf("err %d\n", err);
-        }
-
-        err = cudaGetLastError();
-
-        if (err != 0) {
-            printf("err %d\n", err);
-        }
-
-        // printf("inited curand\n");
     bool use_topk = false;
     char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
     if (env_var) {
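The deleted block synchronized the entire device and polled for errors on every verify call; on a decode path built from short kernels, that forces the host to stall once per step. A rough way to feel the cost (not from this PR; assumes a CUDA build of Paddle, with paddle.device.cuda.synchronize standing in for the removed cudaDeviceSynchronize):

import time
import paddle

def run(steps, sync_each_step):
    y = paddle.zeros([1024])
    paddle.device.cuda.synchronize()          # clean start
    t0 = time.perf_counter()
    for _ in range(steps):
        y = y + 1.0                           # one short kernel, like a decode step
        if sync_each_step:
            paddle.device.cuda.synchronize()  # what the deleted code effectively did
    paddle.device.cuda.synchronize()          # a single sync at the end suffices
    return time.perf_counter() - t0

print("sync once at end:", run(2000, False))
print("sync every step :", run(2000, True))

Keeping error polling out of the steady-state loop, or gating it behind an environment variable the way this file already gates SPECULATE_VERIFY_USE_TOPK, preserves debuggability without the per-step stall.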
@@ -123,7 +123,7 @@ def apply_speculative_penalty_multi_scores(
         from fastdeploy.model_executor.ops.gpu import \
             speculate_get_token_penalty_multi_scores
 
-        logits = speculate_get_token_penalty_multi_scores(
+        speculate_get_token_penalty_multi_scores(
             pre_token_ids,
             logits,
             repetition_penalties,
@@ -141,5 +141,5 @@ def apply_speculative_penalty_multi_scores(
         )
     else:
         raise NotImplementedError()
-
+    # inplace
    return logits
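Dropping the "logits =" binding is safe because the op is registered with .SetInplaceMap({{"logits", "logits_out"}}) (see the PD_BUILD_STATIC_OP block above): the kernel writes into the logits tensor it is handed, so the wrapper only needs to return the same object. A minimal plain-Python model of that contract (fake_penalty_op is invented for illustration):

import paddle

def fake_penalty_op(logits):
    # Stand-in for speculate_get_token_penalty_multi_scores: mutate the
    # caller's tensor in place instead of producing a new one.
    logits.scale_(0.5)      # paddle's in-place scale, no value returned

logits = paddle.ones([2, 4])
fake_penalty_op(logits)     # no assignment of a return value needed
print(logits)               # already halved: the caller sees the update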
@@ -101,6 +101,8 @@ def pre_process(
         seq_lens_encoder,
         seq_lens_decoder,
     )
+    if isinstance(seq_lens_output, list):
+        seq_lens_output = seq_lens_output[0]
     output_token_num = paddle.sum(seq_lens_output)
     output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output)
     output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
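The new isinstance guard normalizes custom-op outputs that arrive as a one-element list. For the surrounding lines, a toy numpy version of what the cumsum is for (shapes invented for illustration): each sequence owns a max_len-sized slot in the padded layout, and the cumulative padding tells the packing step how far each sequence's tokens shift left.

import numpy as np

max_len = 8
seq_lens_output = np.array([3, 1, 5])          # tokens each sequence will emit
output_token_num = seq_lens_output.sum()       # 9 packed tokens in total
pad = max_len - seq_lens_output                # unused slots per sequence
output_cum_offsets_tmp = np.cumsum(pad)        # [5, 12, 15]
# Roughly: sequence i starts at i*max_len in the padded layout, and at
# i*max_len minus the padding accumulated before it once packed -- the
# per-token bookkeeping speculate_get_output_padding_offset produces.
print(output_token_num, output_cum_offsets_tmp)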
@@ -497,6 +497,8 @@ class MTPProposer(Proposer):
             self.main_model_inputs["seq_lens_encoder"],
             self.max_draft_token_num,
         )
+        if isinstance(target_hidden_states, list):
+            target_hidden_states = target_hidden_states[0]
 
         return target_hidden_states
 
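Same normalization as in pre_process above: depending on how a compiled custom op is invoked, it can return [Tensor] rather than Tensor, so both call sites unwrap defensively. As a reusable sketch (the PR inlines the two-line version; unwrap_op_output is hypothetical):

from typing import List, Union
import paddle

def unwrap_op_output(out: Union[paddle.Tensor, List[paddle.Tensor]]) -> paddle.Tensor:
    """Custom ops sometimes hand back a one-element list; normalize to a tensor."""
    return out[0] if isinstance(out, list) else out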