support mtp in splitewise and scheduler_v1 mode (#4743)

This commit is contained in:
freeliuzc
2025-11-03 10:07:15 +08:00
committed by GitHub
parent b8bf57138f
commit f44f4bafd1
3 changed files with 330 additions and 240 deletions

View File

@@ -17,144 +17,201 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "../speculate_msg.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;
int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
};
void MTPSaveFirstToken(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid: %d. 1: %d. 2: %d.\n", i, (int)x_data[i * x_dim], (int)x_data[i * x_dim + 1]);
#endif
msg_sed.mtext[i + 2] = 2;
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = (int)x_data[i * x_dim];
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = (int)x_data[i * x_dim + 1];
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n", i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[2*i] << " ";
std::cout << " " << (int)x_data[2*i + 1];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid,
&msg_sed,
(2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4, 0)) == -1) {
printf("full msg buffer\n");
}
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank,
bool skip_chunk_prefill) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
auto seq_lens_decoder_cpu =
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
int64_t* step_idx_data = step_idx_cpu.data<int64_t>();
static struct speculate_msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid: %d. 1: %d. 2: %d.\n",
i,
(int)x_data[i * x_dim],
(int)x_data[i * x_dim + 1]);
#endif
if ((skip_chunk_prefill &&
seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
step_idx_data[i] == 0) {
msg_sed.mtext[i + 2] = 0;
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid[%d] skip save mtp output \n", i);
#endif
continue;
} else if (step_idx_data[i] == 1) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid[%d] save mtp tokens \n", i);
#endif
msg_sed.mtext[i + 2] = 2;
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
(int)x_data[i * x_dim];
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
(int)x_data[i * x_dim + 1];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n",
i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[2 * i] << " ";
std::cout << " " << (int)x_data[2 * i + 1];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid,
&msg_sed,
(2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
0)) == -1) {
printf("full msg buffer\n");
}
return;
}
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
bool save_each_rank,
bool skip_chunk_prefill) {
MTPSaveFirstToken(x,
not_need_stop,
seq_lens_decoder,
prompt_lens,
step_idx,
rank_id,
1,
save_each_rank,
skip_chunk_prefill);
}
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank,
bool skip_chunk_prefill) {
MTPSaveFirstToken(x,
not_need_stop,
seq_lens_decoder,
prompt_lens,
step_idx,
rank_id,
msg_queue_id,
save_each_rank,
skip_chunk_prefill);
}
PD_BUILD_STATIC_OP(mtp_save_first_token)
.Inputs({"x", "not_need_stop"})
.Inputs(
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
.Attrs({"rank_id: int64_t",
"save_each_rank: bool"})
"save_each_rank: bool",
"skip_chunk_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Inputs(
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
.Attrs({"rank_id: int64_t",
"msg_queue_id: int",
"save_each_rank: bool",
"skip_chunk_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));

View File

@@ -15,85 +15,94 @@
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
const int64_t* prompt_lens,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
int *step_seq_lens_this_time,
int *accept_num,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq) {
const int bid = threadIdx.x;
int stop_flag_now_int = 0;
if (bid < real_bsz) {
if (!stop_flags[bid]) {
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
__global__ void speculate_schedula_cache(const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
const int64_t *prompt_lens,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
int *step_seq_lens_this_time,
int *accept_num,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq,
const bool prefill_one_step_stop) {
const int bid = threadIdx.x;
int stop_flag_now_int = 0;
if (bid < real_bsz) {
if (!stop_flags[bid]) {
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
int64_t *step_draft_tokens_now =
step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
// decoder
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
// prefill
stop_flags[bid] = true;
seq_lens_this_time[bid] = 0;
seq_lens_decoder[bid] = 0;
seq_lens_encoder[bid] = 0;
accept_num[bid] = 0;
stop_flag_now_int = 1;
}
if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
const int max_possible_block_idx =
(seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
} else {
stop_flag_now_int = 1;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
seq_lens_this_time[bid] = 0;
seq_lens_decoder[bid] = 0;
seq_lens_encoder[bid] = 0;
accept_num[bid] = 0;
stop_flag_now_int = 1;
} else if (max_possible_block_idx < block_num_per_seq &&
block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else if (bid >= real_bsz && bid < max_bsz) {
} else {
// prefill
stop_flags[bid] = true;
seq_lens_this_time[bid] = 0;
seq_lens_decoder[bid] = 0;
seq_lens_encoder[bid] = 0;
accept_num[bid] = 0;
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
}
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
} else {
stop_flag_now_int = 1;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
@@ -113,45 +122,51 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
const int accept_tokens_len = accept_tokens.shape()[1];
const int draft_token_len = draft_tokens.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
const int accept_tokens_len = accept_tokens.shape()[1];
const int draft_token_len = draft_tokens.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512;
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
constexpr int BlockSize = 512;
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize>
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
prompt_lens.data<int64_t>(),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
const_cast<int *>(step_seq_lens_this_time.data<int>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
draft_token_len,
accept_tokens_len,
block_size,
block_num_per_seq,
prefill_one_step_stop);
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
prompt_lens.data<int64_t>(),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
const_cast<int *>(step_seq_lens_this_time.data<int>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
draft_token_len,
accept_tokens_len,
block_size,
block_num_per_seq
);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
@@ -184,17 +199,19 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
"accept_tokens_out",
"is_block_step_out",
"not_need_stop_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
{"accept_num", "accept_num_out"},
{"accept_tokens", "accept_tokens_out"},
{"is_block_step", "is_block_step_out"},
{"not_need_stop", "not_need_stop_out"},})
.SetInplaceMap({
{"draft_tokens", "draft_tokens_out"},
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
{"accept_num", "accept_num_out"},
{"accept_tokens", "accept_tokens_out"},
{"is_block_step", "is_block_step_out"},
{"not_need_stop", "not_need_stop_out"},
})
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));

View File

@@ -77,7 +77,7 @@ class MTPProposer(Proposer):
self.enable_logprob = self.model_config.enable_logprob
# [mixed, prefill, decoder]
self.role = "mixed"
self.role = self.scheduler_config.splitwise_role
self.sampler = MTPSampler(fd_config)
self._init_model_inputs()
@@ -365,6 +365,7 @@ class MTPProposer(Proposer):
)
# self.model_inputs["caches"] = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
self.model_inputs["prompt_lens"] = self.target_model_inputs["prompt_lens"]
self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
@@ -501,9 +502,10 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.model_inputs["is_block_step"][idx : idx + 1] = False
continue
# if has_prefill_task or has_decode_task:
# self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
# TODO(liuzichang): Solve splitewise-p bug to restore
# self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
"""
@@ -704,11 +706,25 @@ class MTPProposer(Proposer):
self.model_inputs["substep"],
)
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
mtp_save_first_token(
self.model_inputs["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["prompt_lens"],
self.model_inputs["step_idx"],
self.local_rank,
self.parallel_config.use_ep,
skip_save,
)
# Ensure only save first token once.
paddle.assign(
paddle.where(
self.model_inputs["stop_flags"],
paddle.zeros_like(self.model_inputs["step_idx"]),
self.model_inputs["step_idx"],
),
self.model_inputs["step_idx"],
)
def _propose(self, step_use_cudagraph: bool = False):