mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 12:22:53 +08:00
[MTP]support mtp chunk_prefill_v1 (#4365)
* support mtp chunk_prefill_v1 * fix mtp chunkprefill output * fix mtp chunkprefill output, fix unit test * fix save_output --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -702,8 +702,11 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
|||||||
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
|
const paddle::Tensor& prompt_lens,
|
||||||
int64_t rank_id,
|
int64_t rank_id,
|
||||||
bool save_each_rank);
|
bool save_each_rank,
|
||||||
|
bool skip_prefill);
|
||||||
|
|
||||||
|
|
||||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||||
@@ -712,7 +715,9 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
|||||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||||
const paddle::Tensor &block_tables,
|
const paddle::Tensor &block_tables,
|
||||||
const paddle::Tensor &stop_flags,
|
const paddle::Tensor &stop_flags,
|
||||||
|
const paddle::Tensor &prompt_lens,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &step_seq_lens_decoder,
|
const paddle::Tensor &step_seq_lens_decoder,
|
||||||
const paddle::Tensor &step_draft_tokens,
|
const paddle::Tensor &step_draft_tokens,
|
||||||
|
|||||||
@@ -28,9 +28,12 @@
|
|||||||
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
|
const paddle::Tensor& prompt_lens,
|
||||||
int64_t rank_id,
|
int64_t rank_id,
|
||||||
int msg_queue_id,
|
int msg_queue_id,
|
||||||
int save_each_rank) {
|
int save_each_rank,
|
||||||
|
bool skip_prefill) {
|
||||||
// printf("enter save output");
|
// printf("enter save output");
|
||||||
if (!save_each_rank && rank_id > 0) {
|
if (!save_each_rank && rank_id > 0) {
|
||||||
return;
|
return;
|
||||||
@@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
|||||||
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
|
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
|
||||||
int* accept_num_data = accept_num_cpu.data<int>();
|
int* accept_num_data = accept_num_cpu.data<int>();
|
||||||
|
|
||||||
|
auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||||
|
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
|
||||||
|
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||||
|
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
|
||||||
|
|
||||||
if (const char* inference_msg_queue_id_env_p =
|
if (const char* inference_msg_queue_id_env_p =
|
||||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||||
std::string inference_msg_queue_id_env_str(
|
std::string inference_msg_queue_id_env_str(
|
||||||
@@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
|||||||
msg_sed.mtext[1] = bsz;
|
msg_sed.mtext[1] = bsz;
|
||||||
|
|
||||||
for (int i = 2; i < MAX_BSZ + 2; i++) {
|
for (int i = 2; i < MAX_BSZ + 2; i++) {
|
||||||
if (i - 2 >= bsz) {
|
if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2] < prompt_lens_data[i - 2])) {
|
||||||
msg_sed.mtext[i] = 0;
|
msg_sed.mtext[i] = 0;
|
||||||
} else {
|
} else {
|
||||||
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
|
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
|
||||||
@@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
|||||||
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
|
const paddle::Tensor& prompt_lens,
|
||||||
int64_t rank_id,
|
int64_t rank_id,
|
||||||
bool save_each_rank) {
|
bool save_each_rank,
|
||||||
|
bool skip_prefill) {
|
||||||
SpeculateSaveWithOutputMsg(
|
SpeculateSaveWithOutputMsg(
|
||||||
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
|
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1, save_each_rank, skip_prefill);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
|
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
|
const paddle::Tensor& prompt_lens,
|
||||||
int64_t rank_id,
|
int64_t rank_id,
|
||||||
int msg_queue_id,
|
int msg_queue_id,
|
||||||
bool save_each_rank) {
|
bool save_each_rank,
|
||||||
|
bool skip_prefill) {
|
||||||
SpeculateSaveWithOutputMsg(
|
SpeculateSaveWithOutputMsg(
|
||||||
accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
|
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(speculate_save_output)
|
PD_BUILD_STATIC_OP(speculate_save_output)
|
||||||
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
|
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
|
||||||
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
|
.Attrs({"rank_id: int64_t", "save_each_rank: bool", "skip_prefill: bool"})
|
||||||
.Outputs({"x_out"})
|
.Outputs({"x_out"})
|
||||||
.SetInplaceMap({{"accept_tokens", "x_out"}})
|
.SetInplaceMap({{"accept_tokens", "x_out"}})
|
||||||
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
|
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
|
PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
|
||||||
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
|
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
|
||||||
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
|
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool", "skip_prefill: bool"})
|
||||||
.Outputs({"x_out"})
|
.Outputs({"x_out"})
|
||||||
.SetInplaceMap({{"accept_tokens", "x_out"}})
|
.SetInplaceMap({{"accept_tokens", "x_out"}})
|
||||||
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
|
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ __global__ void speculate_schedula_cache(
|
|||||||
const int64_t *draft_tokens,
|
const int64_t *draft_tokens,
|
||||||
int *block_tables,
|
int *block_tables,
|
||||||
bool *stop_flags,
|
bool *stop_flags,
|
||||||
|
const int64_t* prompt_lens,
|
||||||
int *seq_lens_this_time,
|
int *seq_lens_this_time,
|
||||||
|
int *seq_lens_encoder,
|
||||||
int *seq_lens_decoder,
|
int *seq_lens_decoder,
|
||||||
int *step_seq_lens_decoder,
|
int *step_seq_lens_decoder,
|
||||||
int64_t *step_draft_tokens,
|
int64_t *step_draft_tokens,
|
||||||
@@ -44,23 +46,37 @@ __global__ void speculate_schedula_cache(
|
|||||||
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
||||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||||
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
|
||||||
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
|
||||||
is_block_step[bid] = true;
|
// decoder
|
||||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||||
seq_lens_this_time[bid] = 0;
|
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
||||||
|
is_block_step[bid] = true;
|
||||||
|
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||||
|
seq_lens_this_time[bid] = 0;
|
||||||
|
stop_flags[bid] = true;
|
||||||
|
stop_flag_now_int = 1;
|
||||||
|
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||||
|
seq_lens_decoder[bid] = 0;
|
||||||
|
accept_num[bid] = 0;
|
||||||
|
for (int i = 0; i < accept_tokens_len; i++) {
|
||||||
|
accept_tokens_now[i] = -1;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < draft_tokens_len; i++) {
|
||||||
|
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// prefill
|
||||||
stop_flags[bid] = true;
|
stop_flags[bid] = true;
|
||||||
stop_flag_now_int = 1;
|
seq_lens_this_time[bid] = 0;
|
||||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
|
||||||
seq_lens_decoder[bid] = 0;
|
seq_lens_decoder[bid] = 0;
|
||||||
|
seq_lens_encoder[bid] = 0;
|
||||||
accept_num[bid] = 0;
|
accept_num[bid] = 0;
|
||||||
for (int i = 0; i < accept_tokens_len; i++) {
|
stop_flag_now_int = 1;
|
||||||
accept_tokens_now[i] = -1;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < draft_tokens_len; i++) {
|
|
||||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
stop_flag_now_int = 1;
|
stop_flag_now_int = 1;
|
||||||
}
|
}
|
||||||
@@ -83,7 +99,9 @@ __global__ void speculate_schedula_cache(
|
|||||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||||
const paddle::Tensor &block_tables,
|
const paddle::Tensor &block_tables,
|
||||||
const paddle::Tensor &stop_flags,
|
const paddle::Tensor &stop_flags,
|
||||||
|
const paddle::Tensor &prompt_lens,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &step_seq_lens_decoder,
|
const paddle::Tensor &step_seq_lens_decoder,
|
||||||
const paddle::Tensor &step_draft_tokens,
|
const paddle::Tensor &step_draft_tokens,
|
||||||
@@ -109,7 +127,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
|||||||
draft_tokens.data<int64_t>(),
|
draft_tokens.data<int64_t>(),
|
||||||
const_cast<int *>(block_tables.data<int>()),
|
const_cast<int *>(block_tables.data<int>()),
|
||||||
const_cast<bool *>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
|
prompt_lens.data<int64_t>(),
|
||||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||||
|
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||||
@@ -138,7 +158,9 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
|||||||
.Inputs({"draft_tokens",
|
.Inputs({"draft_tokens",
|
||||||
"block_tables",
|
"block_tables",
|
||||||
"stop_flags",
|
"stop_flags",
|
||||||
|
"prompt_lens",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
|
"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"step_seq_lens_decoder",
|
"step_seq_lens_decoder",
|
||||||
"step_draft_tokens",
|
"step_draft_tokens",
|
||||||
@@ -153,6 +175,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
|||||||
"block_tables_out",
|
"block_tables_out",
|
||||||
"stop_flags_out",
|
"stop_flags_out",
|
||||||
"seq_lens_this_time_out",
|
"seq_lens_this_time_out",
|
||||||
|
"seq_lens_encoder_out",
|
||||||
"seq_lens_decoder_out",
|
"seq_lens_decoder_out",
|
||||||
"step_seq_lens_decoder_out",
|
"step_seq_lens_decoder_out",
|
||||||
"step_draft_tokens_out",
|
"step_draft_tokens_out",
|
||||||
@@ -165,6 +188,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
|||||||
{"block_tables", "block_tables_out"},
|
{"block_tables", "block_tables_out"},
|
||||||
{"stop_flags", "stop_flags_out"},
|
{"stop_flags", "stop_flags_out"},
|
||||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||||
|
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||||
|
|||||||
@@ -20,30 +20,32 @@
|
|||||||
|
|
||||||
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
|
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
|
||||||
const int64_t *accept_tokens,
|
const int64_t *accept_tokens,
|
||||||
const int *accept_num,
|
int *accept_num,
|
||||||
const bool *stop_flags,
|
const bool *stop_flags,
|
||||||
const int *seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int *seq_lens_decoder,
|
int *seq_lens_decoder,
|
||||||
const int64_t *step_idx,
|
const int64_t *step_idx,
|
||||||
int bs,
|
int bs,
|
||||||
int length,
|
int length,
|
||||||
int max_draft_tokens) {
|
int max_draft_tokens) {
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
if (tid < bs && !stop_flags[tid]) {
|
if (tid < bs) {
|
||||||
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
|
if (!stop_flags[tid]) {
|
||||||
const int64_t *accept_tokens_now =
|
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
|
||||||
accept_tokens + tid * max_draft_tokens;
|
const int64_t *accept_tokens_now =
|
||||||
const int seq_len_dec = seq_lens_decoder[tid];
|
accept_tokens + tid * max_draft_tokens;
|
||||||
const int seq_len_enc = seq_lens_encoder[tid];
|
const int seq_len_dec = seq_lens_decoder[tid];
|
||||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
const int seq_len_enc = seq_lens_encoder[tid];
|
||||||
// printf("step_idx[tid] %d\n", step_idx[tid]);
|
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
||||||
if (step_idx[tid] >= 0) {
|
if (step_idx[tid] >= 0) {
|
||||||
for (int i = 0; i < accept_num[tid]; i++) {
|
for (int i = 0; i < accept_num[tid]; i++) {
|
||||||
pre_ids_all_now[step_idx[tid] - i] =
|
pre_ids_all_now[step_idx[tid] - i] =
|
||||||
accept_tokens_now[accept_num[tid] - 1 - i];
|
accept_tokens_now[accept_num[tid] - 1 - i];
|
||||||
// printf("pre_ids_all_now[step_idx[tid] - i] %d \n",
|
}
|
||||||
// pre_ids_all_now[step_idx[tid] - i]);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
accept_num[tid] = 0;
|
||||||
|
seq_lens_decoder[tid] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -67,10 +69,10 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
|||||||
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
|
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
|
||||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||||
accept_tokens.data<int64_t>(),
|
accept_tokens.data<int64_t>(),
|
||||||
accept_num.data<int>(),
|
const_cast<int*>(accept_num.data<int>()),
|
||||||
stop_flags.data<bool>(),
|
stop_flags.data<bool>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_decoder.data<int>(),
|
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||||
step_idx.data<int64_t>(),
|
step_idx.data<int64_t>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
@@ -86,6 +88,9 @@ PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
|
|||||||
"seq_lens_encoder",
|
"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"step_idx"})
|
"step_idx"})
|
||||||
.Outputs({"pre_ids_all_out"})
|
.Outputs({"pre_ids_all_out", "accept_num_out", "seq_lens_decoder_out"})
|
||||||
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
|
.SetInplaceMap({
|
||||||
|
{"pre_ids_all", "pre_ids_all_out"},
|
||||||
|
{"accept_num", "accept_num_out"},
|
||||||
|
{"seq_lens_decoder", "seq_lens_decoder_out"}})
|
||||||
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));
|
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));
|
||||||
|
|||||||
@@ -71,9 +71,6 @@ __global__ void speculate_update(int *seq_lens_encoder,
|
|||||||
}
|
}
|
||||||
draft_tokens[bid * max_draft_tokens] =
|
draft_tokens[bid * max_draft_tokens] =
|
||||||
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
|
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
|
||||||
if (stop_flag_now_int) {
|
|
||||||
seq_lens_decoder[bid] = 0;
|
|
||||||
}
|
|
||||||
} else if (bid >= real_bsz && bid < max_bsz) {
|
} else if (bid >= real_bsz && bid < max_bsz) {
|
||||||
stop_flag_now_int = 1;
|
stop_flag_now_int = 1;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1011,11 +1011,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
speculative_cfg = self.create_speculative_config()
|
speculative_cfg = self.create_speculative_config()
|
||||||
if not self.enable_chunked_prefill:
|
if not self.enable_chunked_prefill:
|
||||||
if (
|
if current_platform.is_cuda() and self.splitwise_role == "mixed":
|
||||||
current_platform.is_cuda()
|
|
||||||
and self.splitwise_role == "mixed"
|
|
||||||
and (speculative_cfg is None or speculative_cfg.method not in ["mtp"])
|
|
||||||
):
|
|
||||||
# default enable chunked prefill
|
# default enable chunked prefill
|
||||||
self.enable_chunked_prefill = True
|
self.enable_chunked_prefill = True
|
||||||
|
|
||||||
@@ -1028,10 +1024,7 @@ class EngineArgs:
|
|||||||
if paddle.is_compiled_with_xpu():
|
if paddle.is_compiled_with_xpu():
|
||||||
self.max_num_batched_tokens = self.max_model_len
|
self.max_num_batched_tokens = self.max_model_len
|
||||||
else:
|
else:
|
||||||
if speculative_cfg is not None and speculative_cfg.method is not None:
|
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
|
||||||
self.max_num_batched_tokens = self.max_model_len
|
|
||||||
else:
|
|
||||||
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
|
|
||||||
else:
|
else:
|
||||||
if self.enable_chunked_prefill:
|
if self.enable_chunked_prefill:
|
||||||
self.max_num_batched_tokens = 2048
|
self.max_num_batched_tokens = 2048
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ else:
|
|||||||
save_output,
|
save_output,
|
||||||
save_output_topk,
|
save_output_topk,
|
||||||
set_stop_value_multi_ends,
|
set_stop_value_multi_ends,
|
||||||
speculate_clear_accept_nums,
|
|
||||||
speculate_get_output_padding_offset,
|
speculate_get_output_padding_offset,
|
||||||
speculate_get_padding_offset,
|
speculate_get_padding_offset,
|
||||||
speculate_get_seq_lens_output,
|
speculate_get_seq_lens_output,
|
||||||
@@ -329,12 +328,13 @@ def post_process_specualate(
|
|||||||
model_output.accept_tokens,
|
model_output.accept_tokens,
|
||||||
model_output.accept_num,
|
model_output.accept_num,
|
||||||
model_output.not_need_stop,
|
model_output.not_need_stop,
|
||||||
|
model_output.seq_lens_decoder,
|
||||||
|
model_output.prompt_lens,
|
||||||
model_output.mp_rank,
|
model_output.mp_rank,
|
||||||
save_each_rank,
|
save_each_rank,
|
||||||
|
envs.ENABLE_V1_KVCACHE_SCHEDULER,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
|
|
||||||
|
|
||||||
# Update pre_ids through accept tokens
|
# Update pre_ids through accept tokens
|
||||||
|
|
||||||
speculate_set_value_by_flags_and_idx(
|
speculate_set_value_by_flags_and_idx(
|
||||||
|
|||||||
@@ -1239,6 +1239,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
|
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
|
||||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||||
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
||||||
|
prompt_lens=self.share_inputs["prompt_lens"],
|
||||||
)
|
)
|
||||||
|
|
||||||
post_process(
|
post_process(
|
||||||
@@ -1579,6 +1580,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
|
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
|
||||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||||
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
||||||
|
prompt_lens=self.share_inputs["prompt_lens"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
|
if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
|
||||||
@@ -1622,7 +1624,9 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["draft_tokens"],
|
self.share_inputs["draft_tokens"],
|
||||||
self.share_inputs["block_tables"],
|
self.share_inputs["block_tables"],
|
||||||
self.share_inputs["stop_flags"],
|
self.share_inputs["stop_flags"],
|
||||||
|
self.share_inputs["prompt_lens"],
|
||||||
self.share_inputs["seq_lens_this_time"],
|
self.share_inputs["seq_lens_this_time"],
|
||||||
|
self.share_inputs["seq_lens_encoder"],
|
||||||
self.share_inputs["seq_lens_decoder"],
|
self.share_inputs["seq_lens_decoder"],
|
||||||
self.share_inputs["step_seq_lens_decoder"],
|
self.share_inputs["step_seq_lens_decoder"],
|
||||||
self.share_inputs["step_draft_tokens"],
|
self.share_inputs["step_draft_tokens"],
|
||||||
|
|||||||
@@ -250,6 +250,11 @@ class ModelOutputData:
|
|||||||
"""
|
"""
|
||||||
stop_seqs_len: paddle.Tensor = None
|
stop_seqs_len: paddle.Tensor = None
|
||||||
|
|
||||||
|
"""
|
||||||
|
the length of input prompt
|
||||||
|
"""
|
||||||
|
prompt_lens: paddle.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelRunnerOutput:
|
class ModelRunnerOutput:
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ def cpu_reference(
|
|||||||
draft_tokens,
|
draft_tokens,
|
||||||
block_tables,
|
block_tables,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
|
prompt_lens,
|
||||||
seq_lens_this_time,
|
seq_lens_this_time,
|
||||||
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
step_seq_lens_decoder,
|
step_seq_lens_decoder,
|
||||||
step_draft_tokens,
|
step_draft_tokens,
|
||||||
@@ -100,7 +102,9 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
|||||||
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
|
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
|
||||||
# stop_flags length is max_bsz, others are real_bsz
|
# stop_flags length is max_bsz, others are real_bsz
|
||||||
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
|
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
|
||||||
|
self.prompt_lens = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int64))
|
||||||
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
|
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
|
||||||
|
self.seq_lens_encoder = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int32))
|
||||||
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
|
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
|
||||||
|
|
||||||
# Will be filled by kernel for the triggering bids only
|
# Will be filled by kernel for the triggering bids only
|
||||||
@@ -128,7 +132,9 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
|||||||
self.np_draft_tokens = self.draft_tokens.numpy().copy()
|
self.np_draft_tokens = self.draft_tokens.numpy().copy()
|
||||||
self.np_block_tables = self.block_tables.numpy().copy()
|
self.np_block_tables = self.block_tables.numpy().copy()
|
||||||
self.np_stop_flags = self.stop_flags.numpy().copy()
|
self.np_stop_flags = self.stop_flags.numpy().copy()
|
||||||
|
self.np_prompt_lens = self.prompt_lens.numpy().copy()
|
||||||
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
|
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
|
||||||
|
self.np_seq_lens_encoder = self.seq_lens_encoder.numpy().copy()
|
||||||
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
|
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
|
||||||
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
|
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
|
||||||
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
|
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
|
||||||
@@ -145,7 +151,9 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
|||||||
self.draft_tokens,
|
self.draft_tokens,
|
||||||
self.block_tables,
|
self.block_tables,
|
||||||
self.stop_flags,
|
self.stop_flags,
|
||||||
|
self.prompt_lens,
|
||||||
self.seq_lens_this_time,
|
self.seq_lens_this_time,
|
||||||
|
self.seq_lens_encoder,
|
||||||
self.seq_lens_decoder,
|
self.seq_lens_decoder,
|
||||||
self.step_seq_lens_decoder,
|
self.step_seq_lens_decoder,
|
||||||
self.step_draft_tokens,
|
self.step_draft_tokens,
|
||||||
@@ -164,7 +172,9 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
|||||||
self.np_draft_tokens,
|
self.np_draft_tokens,
|
||||||
self.np_block_tables,
|
self.np_block_tables,
|
||||||
self.np_stop_flags,
|
self.np_stop_flags,
|
||||||
|
self.prompt_lens,
|
||||||
self.np_seq_lens_this_time,
|
self.np_seq_lens_this_time,
|
||||||
|
self.np_seq_lens_encoder,
|
||||||
self.np_seq_lens_decoder,
|
self.np_seq_lens_decoder,
|
||||||
self.np_step_seq_lens_decoder,
|
self.np_step_seq_lens_decoder,
|
||||||
self.np_step_draft_tokens,
|
self.np_step_draft_tokens,
|
||||||
@@ -212,7 +222,9 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
|||||||
self.draft_tokens,
|
self.draft_tokens,
|
||||||
self.block_tables,
|
self.block_tables,
|
||||||
self.stop_flags,
|
self.stop_flags,
|
||||||
|
self.prompt_lens,
|
||||||
self.seq_lens_this_time,
|
self.seq_lens_this_time,
|
||||||
|
self.seq_lens_encoder,
|
||||||
self.seq_lens_decoder,
|
self.seq_lens_decoder,
|
||||||
self.step_seq_lens_decoder,
|
self.step_seq_lens_decoder,
|
||||||
self.step_draft_tokens,
|
self.step_draft_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user