mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature][MTP]support new speculative decoding method named hybrid mtp with ngram (#3610)
This commit is contained in:
@@ -614,7 +614,7 @@ void SpeculateVerify(
|
||||
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
|
||||
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
|
||||
|
||||
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
@@ -659,6 +659,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
void HybridMtpNgram(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &draft_token_num,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const int max_ngram_size,
|
||||
const int min_ngram_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
// MTP
|
||||
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -675,6 +689,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -1121,7 +1136,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
|
||||
|
||||
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
|
||||
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
|
||||
|
||||
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
|
||||
|
||||
@@ -1131,6 +1146,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
||||
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
|
||||
|
||||
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
|
||||
|
@@ -26,6 +26,7 @@ __global__ void process_splitwise_prefill(
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
@@ -36,11 +37,12 @@ __global__ void process_splitwise_prefill(
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len) {
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len) {
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
int64_t not_stop_flag = 0;
|
||||
@@ -93,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
@@ -103,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len) {
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len) {
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
int64_t not_stop_flag = 0;
|
||||
@@ -124,6 +128,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
base_model_draft_tokens + tid * base_model_draft_tokens_len;
|
||||
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
|
||||
const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
|
||||
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
|
||||
#pragma unroll
|
||||
for (int i = 1; i < base_model_draft_tokens_len; i++) {
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
@@ -137,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// 1. first token
|
||||
if (base_model_step_idx_now == 0) {
|
||||
seq_lens_this_time[tid] = 0;
|
||||
not_stop_flag = 0;
|
||||
} else if (seq_lens_encoder[tid] > 0) {
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
@@ -161,34 +164,17 @@ __global__ void draft_model_preprocess_kernel(
|
||||
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
|
||||
} else {
|
||||
// 2: Last base model generated token and first MTP token
|
||||
seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
|
||||
step_idx[tid] -= (base_model_seq_len_this_time - 2);
|
||||
seq_lens_decoder[tid] -= num_model_step - 1;
|
||||
step_idx[tid] -= num_model_step - 1;
|
||||
}
|
||||
for (int i = 0; i < accept_num_now; i++) {
|
||||
draft_tokens_now[i] = accept_tokens_now[i];
|
||||
const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
|
||||
const int64_t accept_token = accept_tokens_now[i];
|
||||
pre_ids_now[pre_id_pos] = accept_token;
|
||||
}
|
||||
seq_lens_this_time[tid] = accept_num_now;
|
||||
}
|
||||
// (liuzichang): Temperary Reserved for debug
|
||||
// else if (accept_num_now <=
|
||||
// max_draft_token) /*Accept partial draft tokens*/ {
|
||||
// // Base Model reject stop
|
||||
// if (stop_flags[tid]) {
|
||||
// stop_flags[tid] = false;
|
||||
// seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
|
||||
// step_idx[tid] = base_model_step_idx[tid];
|
||||
// } else {
|
||||
// seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
|
||||
// step_idx[tid] -= max_draft_token - accept_num_now;
|
||||
// }
|
||||
// int64_t modified_token = accept_tokens_now[accept_num_now - 1];
|
||||
// draft_tokens_now[0] = modified_token;
|
||||
// seq_lens_this_time[tid] = 1;
|
||||
|
||||
// } else /*Accept all draft tokens*/ {
|
||||
// draft_tokens_now[1] = accept_tokens_now[max_draft_token];
|
||||
// seq_lens_this_time[tid] = 2;
|
||||
// }
|
||||
} else {
|
||||
stop_flags[tid] = true;
|
||||
seq_lens_this_time[tid] = 0;
|
||||
@@ -215,6 +201,7 @@ void DispatchRunner(
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
@@ -225,11 +212,12 @@ void DispatchRunner(
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool splitwise_prefill) {
|
||||
constexpr int BlockSize = 512;
|
||||
if (splitwise_prefill) {
|
||||
@@ -244,6 +232,7 @@ void DispatchRunner(
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -254,11 +243,12 @@ void DispatchRunner(
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
max_draft_token,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len);
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
@@ -271,6 +261,7 @@ void DispatchRunner(
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -281,11 +272,12 @@ void DispatchRunner(
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
max_draft_token,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len);
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +292,7 @@ void DispatchTokenMode(
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
@@ -310,11 +303,12 @@ void DispatchTokenMode(
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
if (truncate_first_token) {
|
||||
@@ -329,6 +323,7 @@ void DispatchTokenMode(
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -339,11 +334,12 @@ void DispatchTokenMode(
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
max_draft_token,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
} else {
|
||||
@@ -358,6 +354,7 @@ void DispatchTokenMode(
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -368,11 +365,12 @@ void DispatchTokenMode(
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
max_draft_token,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
}
|
||||
@@ -390,6 +388,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -399,13 +398,14 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& base_model_is_block_step,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
int draft_tokens_len = draft_tokens.shape()[1];
|
||||
int pre_ids_len = pre_ids.shape()[1];
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
constexpr int BlockSize = 512;
|
||||
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
|
||||
@@ -423,6 +423,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
@@ -433,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
max_draft_token,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill);
|
||||
|
||||
@@ -458,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"accept_tokens",
|
||||
"accept_num",
|
||||
"base_model_seq_lens_this_time",
|
||||
@@ -475,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder_out",
|
||||
"step_idx_out",
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out"})
|
||||
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
@@ -485,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_idx", "step_idx_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},
|
||||
{"batch_drop", "batch_drop_out"}})
|
||||
{"batch_drop", "batch_drop_out"},
|
||||
{"pre_ids", "pre_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
|
||||
|
@@ -63,10 +63,9 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
token_this_time = next_tokens_start[seq_len_this_time - 1];
|
||||
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
for (int i = 0; i < seq_len_this_time; ++i) {
|
||||
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
|
||||
}
|
||||
step_idx[tid] += seq_len_this_time;
|
||||
pre_ids_now[step_idx[tid]] = token_this_time;
|
||||
|
||||
|
||||
} else {
|
||||
token_this_time = next_tokens_start[0];
|
||||
|
@@ -49,9 +49,7 @@ __global__ void ComputeOrderKernel(
|
||||
for (int j = 0; j < cur_seq_lens_encoder; j++) {
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
// 2. base model encoder. Base step=0
|
||||
} else if (cur_base_model_seq_lens_encoder != 0) {
|
||||
// 3. New end
|
||||
// 2. Base model stop at last verify-step.
|
||||
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d: base=0. draft !=0 \n", i);
|
||||
|
@@ -0,0 +1,214 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
|
||||
int sum_mixed(const int *value, int num) {
|
||||
int sum_value = 0;
|
||||
for (int i = 0; i <= num; i++) {
|
||||
sum_value += value[i];
|
||||
}
|
||||
return sum_value;
|
||||
}
|
||||
|
||||
void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
|
||||
const int64_t *input_ids_len,
|
||||
const int64_t *pre_ids,
|
||||
const int64_t *step_idx,
|
||||
const int *draft_token_num,
|
||||
int64_t *draft_tokens,
|
||||
int32_t *seq_lens_this_time,
|
||||
int32_t *seq_lens_decoder,
|
||||
int64_t *max_dec_len,
|
||||
int64_t input_ids_stride,
|
||||
int64_t pre_ids_stride,
|
||||
int64_t draft_tokens_stride,
|
||||
int64_t max_batch_size,
|
||||
int max_ngram_size = 3,
|
||||
int min_ngram_size = 1,
|
||||
const int max_draft_tokens = 10) {
|
||||
int threshold = 1024;
|
||||
// dynamic in future
|
||||
char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
|
||||
if (env_var) {
|
||||
threshold = std::stoi(env_var);
|
||||
}
|
||||
int unprocessed_batch_size = 0;
|
||||
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
|
||||
if (seq_lens_decoder[batch_idx] > 0) {
|
||||
unprocessed_batch_size++;
|
||||
}
|
||||
}
|
||||
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
|
||||
const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
|
||||
int max_draft_tokens_query = std::min(static_cast<int64_t>(
|
||||
max_draft_tokens - ori_seq_len_this_time + 1), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
|
||||
|
||||
if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
|
||||
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
|
||||
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
|
||||
const int64_t cur_step_idx = step_idx[batch_idx];
|
||||
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
|
||||
unprocessed_batch_size--;
|
||||
|
||||
auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
|
||||
int left_min_token_num = unprocessed_batch_size;
|
||||
|
||||
if (sum_token_num + max_draft_tokens_query + left_min_token_num > threshold) {
|
||||
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
|
||||
max_draft_tokens_query = std::min(max_draft_tokens_query, tmp_max_draft_tokens);
|
||||
}
|
||||
|
||||
if (sum_token_num + left_min_token_num >= threshold - 1) {
|
||||
continue;
|
||||
}
|
||||
bool match_global = false;
|
||||
// apply ngram_match in input_ids
|
||||
for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size && !match_global; --ngram_size) {
|
||||
// Extract the last n tokens as our search ngram
|
||||
if (cur_step_idx < ngram_size) {
|
||||
continue;
|
||||
}
|
||||
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
|
||||
|
||||
// Iterate through sliding windows of size ngram_size
|
||||
// bool match_input = false;
|
||||
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) {
|
||||
// Check if the current window matches the ngram
|
||||
bool match_local = true;
|
||||
for (int j = 0; j < ngram_size; j++) {
|
||||
if (ngram[j] != cur_input_ids[i + j]) {
|
||||
match_local = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_local) {
|
||||
int64_t start_idx = i + ngram_size;
|
||||
int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
|
||||
if (start_idx >= end_idx)
|
||||
continue;
|
||||
|
||||
int64_t cur_draft_token_num = end_idx - start_idx;
|
||||
|
||||
seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
|
||||
memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
|
||||
// To break the current batch_idx for-loop
|
||||
match_global = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// apply ngram_match in generated tokens
|
||||
if (!match_global) {
|
||||
for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; ++i) {
|
||||
// Check if the current window matches the ngram
|
||||
bool match_local = true;
|
||||
|
||||
for (int j = 0; j < ngram_size; j++) {
|
||||
if (ngram[j] != cur_pre_ids[i + j]) {
|
||||
match_local = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_local) {
|
||||
int64_t start_idx = i + ngram_size;
|
||||
int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_step_idx);
|
||||
|
||||
int64_t cur_draft_token_num = end_idx - start_idx;
|
||||
|
||||
if (start_idx >= end_idx)
|
||||
continue;
|
||||
// printf("match in Output with Ngram_size %d. %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx, end_idx);
|
||||
|
||||
seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
|
||||
memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
|
||||
match_global = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HybridMtpNgram(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &draft_token_num,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const int max_ngram_size,
|
||||
const int min_ngram_size,
|
||||
const int max_draft_tokens) {
|
||||
|
||||
auto input_ids_shape = input_ids.shape();
|
||||
const int64_t input_ids_stride = input_ids_shape[1];
|
||||
|
||||
auto pre_ids_shape = pre_ids.shape();
|
||||
const int64_t pre_ids_stride = pre_ids_shape[1];
|
||||
|
||||
auto draft_tokens_shape = draft_tokens.shape();
|
||||
const int64_t draft_tokens_stride = draft_tokens_shape[1];
|
||||
|
||||
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
|
||||
|
||||
find_candidate_pred_tokens_mixed(input_ids.data<int64_t>(),
|
||||
input_ids_len.data<int64_t>(),
|
||||
pre_ids.data<int64_t>(),
|
||||
step_idx.data<int64_t>(),
|
||||
draft_token_num.data<int>(),
|
||||
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
|
||||
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
|
||||
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
|
||||
input_ids_stride,
|
||||
pre_ids_stride,
|
||||
draft_tokens_stride,
|
||||
max_batch_size,
|
||||
max_ngram_size,
|
||||
min_ngram_size,
|
||||
max_draft_tokens);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
|
||||
.Inputs({"input_ids",
|
||||
"input_ids_len",
|
||||
"pre_ids",
|
||||
"step_idx",
|
||||
"draft_token_num",
|
||||
"draft_tokens",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"max_dec_len"})
|
||||
.Attrs({"max_ngram_size: int", "min_ngram_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
|
||||
.SetKernelFn(PD_KERNEL(HybridMtpNgram))
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});
|
@@ -23,14 +23,7 @@
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct msgdata {
|
||||
int64_t mtype;
|
||||
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
|
||||
2]; // stop_flag, bsz, accept_num*bsz, tokens...
|
||||
};
|
||||
#include "speculate_msg.h"
|
||||
|
||||
void SpeculateGetOutput(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
@@ -54,7 +47,7 @@ void SpeculateGetOutput(const paddle::Tensor& x,
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
|
||||
static struct msgdata msg_rcv;
|
||||
static struct speculate_msgdata msg_rcv;
|
||||
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
|
||||
|
@@ -1,69 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
__global__ void SpeculateHydraSetScoreThresholdKernel(
|
||||
float* threshold,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* accept_num,
|
||||
const int real_bsz,
|
||||
const float default_threshold = 0.3,
|
||||
const float upper_threshold = 0.8,
|
||||
const float lower_threshold = 0.0,
|
||||
const float threshold_step = 0.1,
|
||||
const float threshold_step_fac = 0.5) {
|
||||
for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
|
||||
if (seq_lens_encoder[bid] > 0) {
|
||||
threshold[bid] = default_threshold;
|
||||
} else if (seq_lens_this_time[bid] <= 1) {
|
||||
continue;
|
||||
} else if (accept_num[bid] >= seq_lens_this_time[bid] &&
|
||||
threshold[bid] >
|
||||
lower_threshold + threshold_step * threshold_step_fac) {
|
||||
threshold[bid] -= threshold_step * threshold_step_fac;
|
||||
} else if (accept_num[bid] < seq_lens_this_time[bid] &&
|
||||
threshold[bid] < upper_threshold - threshold_step) {
|
||||
threshold[bid] += threshold_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateHydraSetScoreThreshold(const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& threshold) {
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
|
||||
const int bsz = seq_lens_this_time_shape[0];
|
||||
|
||||
SpeculateHydraSetScoreThresholdKernel<<<1, 256, 0, cu_stream>>>(
|
||||
const_cast<float*>(threshold.data<float>()),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
accept_num.data<int>(),
|
||||
bsz);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_hydra_set_score_threshold)
|
||||
.Inputs(
|
||||
{"seq_lens_this_time", "seq_lens_encoder", "accept_num", "threshold"})
|
||||
.Outputs({"threshold_out"})
|
||||
.SetInplaceMap({{"threshold", "threshold_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateHydraSetScoreThreshold));
|
@@ -1,68 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void hydra_update_this_time(int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const float* topk_scores,
|
||||
const float* score_threshold,
|
||||
int real_bsz,
|
||||
int idx) {
|
||||
int linear_idx = threadIdx.x;
|
||||
// verify and set stop flags
|
||||
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
|
||||
if (seq_lens_encoder[linear_idx] == 0 &&
|
||||
seq_lens_decoder[linear_idx] != 0) {
|
||||
if (topk_scores[linear_idx] > score_threshold[linear_idx] &&
|
||||
seq_lens_this_time[linear_idx] == idx + 1) {
|
||||
seq_lens_this_time[linear_idx]++;
|
||||
}
|
||||
} else if (seq_lens_encoder[linear_idx] == 0 &&
|
||||
seq_lens_decoder[linear_idx] == 0) {
|
||||
seq_lens_this_time[linear_idx] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HydraUpdateThisTime(const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& topk_scores,
|
||||
const paddle::Tensor& score_threshold,
|
||||
const int real_bsz,
|
||||
const int idx) {
|
||||
constexpr int BlockSize = 512;
|
||||
|
||||
hydra_update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
topk_scores.data<float>(),
|
||||
score_threshold.data<float>(),
|
||||
real_bsz,
|
||||
idx);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_hydra_update_seqlens_this_time)
|
||||
.Inputs({"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"topk_scores",
|
||||
"score_threshold"})
|
||||
.Outputs({"seq_lens_this_time_out"})
|
||||
.Attrs({"real_bsz: int", "idx: int"})
|
||||
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
|
||||
.SetKernelFn(PD_KERNEL(HydraUpdateThisTime));
|
@@ -1,149 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "helper.h"
|
||||
|
||||
template <typename T, int VecSize>
|
||||
__global__ void RebuildAppendPaddingKernel(
|
||||
T *out,
|
||||
const T *full_hidden_states,
|
||||
const int *cum_offset,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *output_padding_offset,
|
||||
const int seq_len,
|
||||
const int dim_embed,
|
||||
const size_t elem_nums) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
|
||||
const int out_token_id = i / dim_embed;
|
||||
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
|
||||
const int bi = ori_token_id / seq_len;
|
||||
int seq_id = 0;
|
||||
|
||||
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
|
||||
else if (seq_len_encoder[bi] != 0) {
|
||||
seq_id = seq_len_encoder[bi] - 1;
|
||||
}
|
||||
|
||||
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
|
||||
const int bias_idx = i % dim_embed;
|
||||
|
||||
Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &out[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> DispatchDtype(
|
||||
const paddle::Tensor& full_hidden_states,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const int max_seq_len) {
|
||||
// src: [token_num, dim_embed]
|
||||
// dst: [batch_size, 1, dim_embed]
|
||||
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
|
||||
int dim_embed = full_hidden_states.shape()[1];
|
||||
int output_token_num = output_padding_offset.shape()[0];
|
||||
int elem_nums = output_token_num * dim_embed;
|
||||
constexpr int PackSize = VEC_16B / sizeof(DataType_);
|
||||
assert(elem_nums % PackSize == 0);
|
||||
|
||||
auto out = paddle::full({output_token_num, dim_embed}, 0, full_hidden_states.dtype(), full_hidden_states.place());
|
||||
|
||||
int pack_num = elem_nums / PackSize;
|
||||
const int threads_per_block = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks(pack_num, &grid_size);
|
||||
|
||||
RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
|
||||
cum_offsets.data<int32_t>(),
|
||||
seq_len_encoder.data<int32_t>(),
|
||||
seq_len_decoder.data<int32_t>(),
|
||||
output_padding_offset.data<int32_t>(),
|
||||
max_seq_len,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
return {out};
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> RebuildAppendPadding(
|
||||
const paddle::Tensor& full_hidden_states,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const int max_seq_len) {
|
||||
|
||||
|
||||
switch (full_hidden_states.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return DispatchDtype<paddle::DataType::BFLOAT16>(
|
||||
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return DispatchDtype<paddle::DataType::FLOAT16>(
|
||||
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
|
||||
default:
|
||||
PD_THROW("Unsupported data type.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
|
||||
const std::vector<int64_t>& full_hidden_states_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape,
|
||||
const std::vector<int64_t>& seq_len_encoder_shape,
|
||||
const std::vector<int64_t>& seq_len_decoder_shape,
|
||||
const std::vector<int64_t>& output_padding_offset_shape) {
|
||||
const int64_t output_token_num = output_padding_offset_shape[0];
|
||||
const int64_t dim_embed = full_hidden_states_shape[1];
|
||||
std::vector<int64_t> out_shape = {output_token_num, dim_embed};
|
||||
return {out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
|
||||
const paddle::DataType& full_hidden_states_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype,
|
||||
const paddle::DataType& seq_len_encoder_dtype,
|
||||
const paddle::DataType& seq_len_decoder_dtype,
|
||||
const paddle::DataType& output_padding_offset_dtype) {
|
||||
return {full_hidden_states_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
|
||||
.Inputs({"full_hidden_states",
|
||||
"cum_offsets",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"output_padding_offset"})
|
||||
.Attrs({"max_seq_len: int"})
|
||||
.Outputs({"out"})
|
||||
.SetKernelFn(PD_KERNEL(RebuildAppendPadding))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
|
@@ -23,14 +23,7 @@
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
|
||||
2]; // stop_flag, bsz, tokens
|
||||
};
|
||||
#include "speculate_msg.h"
|
||||
|
||||
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
@@ -62,7 +55,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
||||
#endif
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
static struct msgdata msg_sed;
|
||||
static struct speculate_msgdata msg_sed;
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
|
@@ -15,7 +15,7 @@
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_update_v3(int *seq_lens_encoder,
|
||||
__global__ void speculate_update(int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
bool *not_need_stop,
|
||||
int64_t *draft_tokens,
|
||||
@@ -90,7 +90,7 @@ __global__ void speculate_update_v3(int *seq_lens_encoder,
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
@@ -108,7 +108,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
constexpr int BlockSize = 512;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_update_v3<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
||||
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
@@ -130,7 +130,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_update_v3)
|
||||
PD_BUILD_STATIC_OP(speculate_update)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"not_need_stop",
|
||||
@@ -152,4 +152,4 @@ PD_BUILD_STATIC_OP(speculate_update_v3)
|
||||
{"not_need_stop", "not_need_stop_out"},
|
||||
{"draft_tokens", "draft_tokens_out"},
|
||||
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateUpdateV3));
|
||||
.SetKernelFn(PD_KERNEL(SpeculateUpdate));
|
@@ -1,55 +0,0 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h" // NOLINT
|
||||
|
||||
__global__ void update_this_time(int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
int real_bsz,
|
||||
int value) {
|
||||
int linear_idx = threadIdx.x;
|
||||
// verify and set stop flags
|
||||
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
|
||||
if (seq_lens_encoder[linear_idx] == 0 &&
|
||||
seq_lens_decoder[linear_idx] != 0) {
|
||||
seq_lens_this_time[linear_idx] = value;
|
||||
} else if (seq_lens_encoder[linear_idx] == 0 &&
|
||||
seq_lens_decoder[linear_idx] == 0) {
|
||||
seq_lens_this_time[linear_idx] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateThisTime(const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const int real_bsz,
|
||||
const int value) {
|
||||
constexpr int BlockSize = 512;
|
||||
|
||||
update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
real_bsz,
|
||||
value);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_update_seq_lens_this_time)
|
||||
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
|
||||
.Outputs({"seq_lens_this_time_out"})
|
||||
.Attrs({"real_bsz: int", "value: int"})
|
||||
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateThisTime));
|
@@ -1,146 +0,0 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h" // NOLINT
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_update(int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
bool *not_need_stop,
|
||||
int64_t *draft_tokens,
|
||||
int *actual_draft_token_nums,
|
||||
const int64_t *accept_tokens,
|
||||
const int *accept_num,
|
||||
const bool *stop_flags,
|
||||
const int *seq_lens_this_time,
|
||||
const bool *is_block_step,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens) {
|
||||
const int bid = threadIdx.x;
|
||||
const int accept_num_now = accept_num[bid];
|
||||
int stop_flag_now_int = 0;
|
||||
if (!(is_block_step[bid] || bid >= real_bsz)) {
|
||||
if (stop_flags[bid]) {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
if (seq_lens_encoder[bid] == 0) {
|
||||
seq_lens_decoder[bid] += accept_num_now;
|
||||
}
|
||||
|
||||
if (seq_lens_this_time[bid] > 1 &&
|
||||
seq_lens_encoder[bid] ==
|
||||
0) { // 对于append模式,需要根据接收与否确定是否要降低下次draft
|
||||
// token的数量
|
||||
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
|
||||
if (accept_num_now - 1 == current_actual_draft_token_num) {
|
||||
if (current_actual_draft_token_num + 2 <=
|
||||
max_draft_tokens - 1) {
|
||||
actual_draft_token_nums[bid] =
|
||||
current_actual_draft_token_num + 2;
|
||||
} else if (current_actual_draft_token_num + 1 <=
|
||||
max_draft_tokens - 1) {
|
||||
actual_draft_token_nums[bid] =
|
||||
current_actual_draft_token_num + 1;
|
||||
} else {
|
||||
actual_draft_token_nums[bid] = max_draft_tokens - 1;
|
||||
}
|
||||
} else {
|
||||
actual_draft_token_nums[bid] =
|
||||
actual_draft_token_nums[bid] - 1 >= 1
|
||||
? actual_draft_token_nums[bid] - 1
|
||||
: 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (seq_lens_encoder[bid] != 0) {
|
||||
seq_lens_decoder[bid] += seq_lens_encoder[bid];
|
||||
seq_lens_encoder[bid] = 0;
|
||||
}
|
||||
draft_tokens[bid * max_draft_tokens] =
|
||||
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
|
||||
if (stop_flag_now_int) {
|
||||
seq_lens_decoder[bid] = 0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum < real_bsz;
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateUpdateV2(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &actual_draft_token_nums,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
auto max_draft_tokens = draft_tokens.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(actual_draft_token_nums.data<int>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
stop_flags.data<bool>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
is_block_step.data<bool>(),
|
||||
real_bsz,
|
||||
max_draft_tokens);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_update_v2)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"not_need_stop",
|
||||
"draft_tokens",
|
||||
"actual_draft_token_nums",
|
||||
"accept_tokens",
|
||||
"accept_num",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"is_block_step"})
|
||||
.Outputs({"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"not_need_stop_out",
|
||||
"draft_tokens_out",
|
||||
"actual_draft_token_nums_out"})
|
||||
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},
|
||||
{"draft_tokens", "draft_tokens_out"},
|
||||
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateUpdateV2));
|
@@ -349,16 +349,24 @@ class SpeculativeConfig:
|
||||
self,
|
||||
args,
|
||||
):
|
||||
# speculative method, choose in [None, "ngram_match", "mtp"]
|
||||
self.method_list = ["ngram_match", "mtp"]
|
||||
self.mtp_strategy_list = ["default", "with_ngram"]
|
||||
|
||||
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
|
||||
self.method: Optional[str] = None
|
||||
# mtp strategy in mtp-method
|
||||
self.mtp_strategy = "default"
|
||||
# the max length of speculative tokens
|
||||
self.num_speculative_tokens: int = 1
|
||||
# the model runner step of draft model/mtp...
|
||||
self.num_model_steps: int = 1
|
||||
# the max length of candidate tokens for speculative method
|
||||
self.max_candidate_len: int = 5
|
||||
# the max length of verify window for speculative method
|
||||
self.verify_window: int = 2
|
||||
# ngram match
|
||||
self.max_ngram_size: int = 5
|
||||
self.min_ngram_size: int = 2
|
||||
# model for mtp/eagle/draft_model
|
||||
self.model: Optional[str] = None
|
||||
# quantization of model
|
||||
@@ -445,6 +453,33 @@ class SpeculativeConfig:
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info("=============================================================")
|
||||
|
||||
def check_legality_parameters(
|
||||
self,
|
||||
) -> None:
|
||||
"""Check the legality of parameters passed in from the command line"""
|
||||
if self.method is not None:
|
||||
assert (
|
||||
self.method in self.method_list
|
||||
), f"speculative method only support {self.method_list} now, but get {self.method}."
|
||||
|
||||
assert (
|
||||
self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
|
||||
), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
|
||||
assert (
|
||||
self.num_model_steps >= 1 and self.num_model_steps <= 5
|
||||
), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
|
||||
|
||||
if self.method in ["mtp", "hybrid_mtp_ngram"]:
|
||||
if self.num_speculative_tokens < self.num_model_steps:
|
||||
logger.warning(
|
||||
f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
|
||||
)
|
||||
self.num_speculative_tokens = self.num_model_steps
|
||||
|
||||
assert (
|
||||
self.mtp_strategy in self.mtp_strategy_list
|
||||
), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
|
@@ -248,6 +248,7 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
|
||||
self.num_layers = fd_config.model_config.num_hidden_layers
|
||||
self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
|
||||
self.norm = fd_config.speculative_config.sharing_model.ernie.norm
|
||||
|
||||
self.layers = nn.LayerList(
|
||||
[
|
||||
@@ -318,6 +319,8 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@@ -68,7 +68,7 @@ else:
|
||||
speculate_set_value_by_flags_and_idx,
|
||||
speculate_step_paddle,
|
||||
speculate_step_system_cache,
|
||||
speculate_update_v3,
|
||||
speculate_update,
|
||||
step_paddle,
|
||||
step_system_cache,
|
||||
update_inputs,
|
||||
@@ -308,7 +308,7 @@ def post_process_normal(
|
||||
|
||||
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
|
||||
""""""
|
||||
speculate_update_v3(
|
||||
speculate_update(
|
||||
model_output.seq_lens_encoder,
|
||||
model_output.seq_lens_decoder,
|
||||
model_output.not_need_stop,
|
||||
|
@@ -252,12 +252,13 @@ class TokenProcessor:
|
||||
|
||||
def _compute_speculative_status(self):
|
||||
# TODO(liuzichang): Supplement more statistics
|
||||
interval = 50
|
||||
interval = 10
|
||||
if self.speculative_stats_step % interval == 0:
|
||||
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
|
||||
spec_logger.info(
|
||||
f"Speculate global accept ratio(Accept draft_tokens/Generated tokens): {accept_ratio}"
|
||||
f" total step: {self.total_step}. total output token num: {self.number_of_output_tokens}"
|
||||
f" avarage accept len: {self.number_of_output_tokens / self.total_step}"
|
||||
)
|
||||
|
||||
if self.cfg.speculative_config.method in ["mtp"]:
|
||||
|
@@ -45,6 +45,10 @@ class Proposer(ABC):
|
||||
self.max_model_len = self.parallel_config.max_model_len
|
||||
self.speculative_method = self.speculative_config.method
|
||||
self.max_draft_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.num_model_steps = self.speculative_config.num_model_steps
|
||||
|
||||
self.max_ngram_size = self.speculative_config.max_ngram_size
|
||||
self.min_ngram_size = self.speculative_config.min_ngram_size
|
||||
|
||||
spec_logger.info(f"Speculate config: {self.speculative_config}")
|
||||
|
||||
|
@@ -35,6 +35,7 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
hybrid_mtp_ngram,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
@@ -57,6 +58,8 @@ class MTPProposer(Proposer):
|
||||
self._update_cfg(main_model)
|
||||
self._load_model()
|
||||
self.main_model_inputs = main_model_inputs
|
||||
self.mtp_strategy = self.speculative_config.mtp_strategy
|
||||
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
|
||||
|
||||
# [mixed, prefill, decoder]
|
||||
self.role = "mixed"
|
||||
@@ -336,10 +339,11 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
|
||||
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
|
||||
if self.max_draft_token_num > 1:
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.full_like(
|
||||
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
||||
)
|
||||
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
|
||||
"""
|
||||
@@ -364,6 +368,7 @@ class MTPProposer(Proposer):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
length = len(request.prompt_token_ids)
|
||||
self.input_ids_len[idx] = length
|
||||
|
||||
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
|
||||
length = len(request.prompt_token_ids)
|
||||
@@ -460,6 +465,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.main_model_inputs["accept_tokens"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
@@ -469,7 +475,7 @@ class MTPProposer(Proposer):
|
||||
self.main_model_inputs["stop_flags"],
|
||||
self.main_model_inputs["is_block_step"],
|
||||
self.main_model_inputs["draft_tokens"],
|
||||
self.max_draft_token_num,
|
||||
self.num_model_steps,
|
||||
self.speculative_method in ["eagle", "mtp"],
|
||||
self.role == "prefill",
|
||||
)
|
||||
@@ -483,7 +489,7 @@ class MTPProposer(Proposer):
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.max_draft_token_num,
|
||||
self.num_model_steps,
|
||||
)
|
||||
if isinstance(target_hidden_states, list):
|
||||
target_hidden_states = target_hidden_states[0]
|
||||
@@ -523,7 +529,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Main process for MTP inference
|
||||
"""
|
||||
for substep in range(self.max_draft_token_num):
|
||||
for substep in range(self.num_model_steps):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
@@ -542,6 +548,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
)
|
||||
|
||||
# Initialize forward meta data
|
||||
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
|
||||
self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
|
||||
@@ -567,7 +574,7 @@ class MTPProposer(Proposer):
|
||||
eos_token_ids=self.model_inputs["eos_token_id"],
|
||||
)
|
||||
|
||||
if self.max_draft_token_num > 1:
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
|
||||
|
||||
model_output = self.model(
|
||||
@@ -601,7 +608,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
if substep != self.max_draft_token_num - 1:
|
||||
if substep != self.num_model_steps - 1:
|
||||
target_hidden_states = self._get_self_hidden_states(hidden_states)
|
||||
|
||||
def _get_self_hidden_states(self, hidden_states):
|
||||
@@ -673,11 +680,37 @@ class MTPProposer(Proposer):
|
||||
self.max_draft_token_num,
|
||||
)
|
||||
|
||||
def _extend_draft_token_with_ngram_match(self):
|
||||
# TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency
|
||||
device = paddle.CUDAPinnedPlace()
|
||||
|
||||
draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
|
||||
seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
|
||||
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
|
||||
hybrid_mtp_ngram(
|
||||
self.model_inputs["input_ids"]._copy_to(device, True),
|
||||
self.input_ids_len,
|
||||
self.model_inputs["pre_ids"]._copy_to(device, True),
|
||||
self.model_inputs["step_idx"].cpu(),
|
||||
self.main_model_inputs["actual_draft_token_num"].cpu(),
|
||||
draft_tokens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
self.model_inputs["max_dec_len"].cpu(),
|
||||
self.max_ngram_size,
|
||||
self.min_ngram_size,
|
||||
self.max_draft_token_num,
|
||||
)
|
||||
self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
|
||||
self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
|
||||
def _run_impl(self, full_hidden_states):
|
||||
""""""
|
||||
target_hidden_states = self._prepare_inputs(full_hidden_states)
|
||||
self._propose(target_hidden_states=target_hidden_states)
|
||||
self._update_status()
|
||||
if self.hybrid_mode:
|
||||
self._extend_draft_token_with_ngram_match()
|
||||
|
||||
def is_chunk_prefill_enabled(self):
|
||||
""""""
|
||||
|
75
tests/operators/test_hybrid_mtp_ngram.py
Normal file
75
tests/operators/test_hybrid_mtp_ngram.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import hybrid_mtp_ngram
|
||||
|
||||
|
||||
class TestNgramMatchMixed(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.max_bsz = 2
|
||||
self.max_draft_tokens = 5
|
||||
self.max_len = 32
|
||||
self.max_dec_len = 10
|
||||
self.max_ngram_size = 5
|
||||
self.min_ngram_size = 2
|
||||
|
||||
# 初始化输入 tensor
|
||||
self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
|
||||
self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
|
||||
self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
|
||||
self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
|
||||
self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
|
||||
self.draft_tokens = paddle.full(
|
||||
shape=[self.max_bsz, self.max_draft_tokens + 1],
|
||||
fill_value=-1,
|
||||
dtype="int64",
|
||||
).cpu()
|
||||
self.seq_lens_this_time = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
|
||||
self.seq_lens_decoder = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
|
||||
self.max_dec_len = paddle.full(
|
||||
shape=[self.max_bsz, 1],
|
||||
fill_value=self.max_dec_len,
|
||||
dtype="int64",
|
||||
).cpu()
|
||||
|
||||
# 设置具体数据
|
||||
self.input_ids[:, :10] = np.arange(0, 10)
|
||||
self.input_ids_len[:] = 10
|
||||
pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int32")
|
||||
self.pre_ids[:, : pre_ids_np.shape[0]] = pre_ids_np
|
||||
self.step_idx[:] = 8
|
||||
|
||||
self.draft_token_num[:] = 5
|
||||
self.draft_tokens[:, :2] = np.array([8, 7])
|
||||
self.seq_lens_this_time[:] = 2
|
||||
self.seq_lens_decoder[:] = 12
|
||||
self.max_dec_len[:] = 512
|
||||
|
||||
# 期望结果
|
||||
self.ref_seq_lens_this_time = np.array([[6], [6]], dtype="int32")
|
||||
self.ref_draft_tokens = np.array([[8, 7, 6, 10, 9, 8], [8, 7, 6, 10, 9, 8]], dtype="int64")
|
||||
|
||||
def test_ngram_match_mixed(self):
|
||||
hybrid_mtp_ngram(
|
||||
self.input_ids,
|
||||
self.input_ids_len,
|
||||
self.pre_ids,
|
||||
self.step_idx,
|
||||
self.draft_token_num,
|
||||
self.draft_tokens,
|
||||
self.seq_lens_this_time,
|
||||
self.seq_lens_decoder,
|
||||
self.max_dec_len,
|
||||
self.max_ngram_size,
|
||||
self.min_ngram_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time)
|
||||
np.testing.assert_allclose(self.draft_tokens.numpy(), self.ref_draft_tokens)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user