support mtp in splitewise and scheduler_v1 mode (#4743)
@@ -17,144 +17,201 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include "../speculate_msg.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6

struct msgdata {
  long mtype;
  // Layout: [0] msg id (negated when stopping), [1] token_num (batch size),
  // [2, 2 + MAX_BSZ) per-query token counts, then MAX_DRAFT_TOKENS token
  // slots per query.
  int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];  // stop_flag, token_num, tokens
};

void MTPSaveFirstToken(const paddle::Tensor& x,
                       const paddle::Tensor& not_need_stop,
                       const paddle::Tensor& seq_lens_decoder,
                       const paddle::Tensor& prompt_lens,
                       const paddle::Tensor& step_idx,
                       int64_t rank_id,
                       int msg_queue_id,
                       bool save_each_rank,
                       bool skip_chunk_prefill) {
  if (!save_each_rank && rank_id > 0) {
    return;
  }
  int x_dim = x.shape()[1];
  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
  int64_t* x_data = x_cpu.data<int64_t>();

  auto seq_lens_decoder_cpu =
      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();

  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();

  auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
  int64_t* step_idx_data = step_idx_cpu.data<int64_t>();

  static struct speculate_msgdata msg_sed;

  if (const char* inference_msg_queue_id_env_p =
          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
    int inference_msg_queue_id_from_env =
        std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
              << inference_msg_queue_id_from_env << std::endl;
#endif
    msg_queue_id = inference_msg_queue_id_from_env;
  }

  static key_t key = ftok("./", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);

  msg_sed.mtype = 1;
  bool not_need_stop_data = not_need_stop.data<bool>()[0];
  int inference_msg_id_from_env = 1;
  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
    std::string inference_msg_id_env_str(inference_msg_id_env_p);
    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
    if (inference_msg_id_from_env == 2) {
      // 2 and -2 are reserved to indicate no output.
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be 2, please use another number.");
    }
    if (inference_msg_id_from_env < 0) {
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be negative, please use another "
          "number.");
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
              << std::endl;
#endif
  } else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
                 "default."
              << std::endl;
#endif
  }
#ifdef SAVE_WITH_OUTPUT_DEBUG
  std::cout << "save_output_key: " << key << std::endl;
  std::cout << "save msgid: " << msgid << std::endl;
#endif
  msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                        : -inference_msg_id_from_env;
  int bsz = x.shape()[0];
  msg_sed.mtext[1] = bsz;
  for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
    printf("bid: %d. 1: %d. 2: %d.\n",
           i,
           (int)x_data[i * x_dim],
           (int)x_data[i * x_dim + 1]);
#endif
    if ((skip_chunk_prefill &&
         seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
        step_idx_data[i] == 0) {
      // Chunked prefill is still consuming the prompt, or no token has been
      // generated yet: publish nothing for this query.
      msg_sed.mtext[i + 2] = 0;
#ifdef SAVE_WITH_OUTPUT_DEBUG
      printf("bid[%d] skip save mtp output \n", i);
#endif
      continue;
    } else if (step_idx_data[i] == 1) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
      printf("bid[%d] save mtp tokens \n", i);
#endif
      msg_sed.mtext[i + 2] = 2;
      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
          (int)x_data[i * x_dim];
      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
          (int)x_data[i * x_dim + 1];
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    printf("mtext[%d]:%d. mtext[%d]:%d. \n",
           i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
           i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
  }

#ifdef SAVE_WITH_OUTPUT_DEBUG
  std::cout << "msg data: ";
  for (int i = 0; i < bsz; i++) {
    std::cout << " " << (int)x_data[2 * i] << " ";
    std::cout << " " << (int)x_data[2 * i + 1];
  }
  std::cout << std::endl;
#endif
  if ((msgsnd(msgid,
              &msg_sed,
              (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
              0)) == -1) {
    printf("full msg buffer\n");
  }
  return;
}

void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
                             const paddle::Tensor& not_need_stop,
                             const paddle::Tensor& seq_lens_decoder,
                             const paddle::Tensor& prompt_lens,
                             const paddle::Tensor& step_idx,
                             int64_t rank_id,
                             bool save_each_rank,
                             bool skip_chunk_prefill) {
  MTPSaveFirstToken(x,
                    not_need_stop,
                    seq_lens_decoder,
                    prompt_lens,
                    step_idx,
                    rank_id,
                    1,
                    save_each_rank,
                    skip_chunk_prefill);
}

void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
                              const paddle::Tensor& not_need_stop,
                              const paddle::Tensor& seq_lens_decoder,
                              const paddle::Tensor& prompt_lens,
                              const paddle::Tensor& step_idx,
                              int64_t rank_id,
                              int msg_queue_id,
                              bool save_each_rank,
                              bool skip_chunk_prefill) {
  MTPSaveFirstToken(x,
                    not_need_stop,
                    seq_lens_decoder,
                    prompt_lens,
                    step_idx,
                    rank_id,
                    msg_queue_id,
                    save_each_rank,
                    skip_chunk_prefill);
}

PD_BUILD_STATIC_OP(mtp_save_first_token)
    .Inputs(
        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
    .Attrs({"rank_id: int64_t",
            "save_each_rank: bool",
            "skip_chunk_prefill: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));

PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
    .Inputs(
        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
    .Attrs({"rank_id: int64_t",
            "msg_queue_id: int",
            "save_each_rank: bool",
            "skip_chunk_prefill: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));

@@ -15,85 +15,94 @@
#include "helper.h"

template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(const int64_t *draft_tokens,
                                         int *block_tables,
                                         bool *stop_flags,
                                         const int64_t *prompt_lens,
                                         int *seq_lens_this_time,
                                         int *seq_lens_encoder,
                                         int *seq_lens_decoder,
                                         int *step_seq_lens_decoder,
                                         int64_t *step_draft_tokens,
                                         int *step_seq_lens_this_time,
                                         int *accept_num,
                                         int64_t *accept_tokens,
                                         bool *is_block_step,
                                         bool *not_need_stop,
                                         const int64_t *stop_nums,
                                         const int real_bsz,
                                         const int max_bsz,
                                         const int max_next_step_tokens,
                                         const int draft_tokens_len,
                                         const int accept_tokens_len,
                                         const int block_size,
                                         const int block_num_per_seq,
                                         const bool prefill_one_step_stop) {
  const int bid = threadIdx.x;
  int stop_flag_now_int = 0;
  if (bid < real_bsz) {
    if (!stop_flags[bid]) {
      const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
      int64_t *step_draft_tokens_now =
          step_draft_tokens + bid * draft_tokens_len;
      int *block_table_now = block_tables + bid * block_num_per_seq;
      int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;

      if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
        // decoder
        const int max_possible_block_idx =
            (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;

        if (prefill_one_step_stop) {
          stop_flags[bid] = true;
          seq_lens_this_time[bid] = 0;
          seq_lens_decoder[bid] = 0;
          seq_lens_encoder[bid] = 0;
          accept_num[bid] = 0;
          stop_flag_now_int = 1;
        } else if (max_possible_block_idx < block_num_per_seq &&
                   block_table_now[max_possible_block_idx] == -1) {
          is_block_step[bid] = true;
          step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
          seq_lens_this_time[bid] = 0;
          stop_flags[bid] = true;
          stop_flag_now_int = 1;
          step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
          seq_lens_decoder[bid] = 0;
          accept_num[bid] = 0;
          for (int i = 0; i < accept_tokens_len; i++) {
            accept_tokens_now[i] = -1;
          }
          for (int i = 0; i < draft_tokens_len; i++) {
            step_draft_tokens_now[i] = draft_tokens_now[i];
          }
        }
      } else {
        // prefill
        stop_flags[bid] = true;
        seq_lens_this_time[bid] = 0;
        seq_lens_decoder[bid] = 0;
        seq_lens_encoder[bid] = 0;
        accept_num[bid] = 0;
        stop_flag_now_int = 1;
      }
    } else {
      stop_flag_now_int = 1;
    }
  } else if (bid >= real_bsz && bid < max_bsz) {
    stop_flag_now_int = 1;
  }
  __syncthreads();
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // printf("stop_flag_now_int %d \n", stop_flag_now_int);
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

  if (threadIdx.x == 0) {
    // printf("stop_sum %d \n", stop_sum);
    not_need_stop[0] = stop_sum < stop_nums[0];
  }
}

void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
@@ -113,45 +122,51 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
                            const paddle::Tensor &stop_nums,
                            const int block_size,
                            const int max_draft_tokens) {
  const int real_bsz = seq_lens_this_time.shape()[0];
  const int max_bsz = stop_flags.shape()[0];
  const int accept_tokens_len = accept_tokens.shape()[1];
  const int draft_token_len = draft_tokens.shape()[1];
  const int block_num_per_seq = block_tables.shape()[1];

  constexpr int BlockSize = 512;
  const int max_next_step_tokens = 2 * max_draft_tokens + 2;
  bool prefill_one_step_stop = false;
  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }
  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
  speculate_schedula_cache<BlockSize>
      <<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
          draft_tokens.data<int64_t>(),
          const_cast<int *>(block_tables.data<int>()),
          const_cast<bool *>(stop_flags.data<bool>()),
          prompt_lens.data<int64_t>(),
          const_cast<int *>(seq_lens_this_time.data<int>()),
          const_cast<int *>(seq_lens_encoder.data<int>()),
          const_cast<int *>(seq_lens_decoder.data<int>()),
          const_cast<int *>(step_seq_lens_decoder.data<int>()),
          const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
          const_cast<int *>(step_seq_lens_this_time.data<int>()),
          const_cast<int *>(accept_num.data<int>()),
          const_cast<int64_t *>(accept_tokens.data<int64_t>()),
          const_cast<bool *>(is_block_step.data<bool>()),
          const_cast<bool *>(not_need_stop_gpu.data<bool>()),
          stop_nums.data<int64_t>(),
          real_bsz,
          max_bsz,
          max_next_step_tokens,
          draft_token_len,
          accept_tokens_len,
          block_size,
          block_num_per_seq,
          prefill_one_step_stop);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), true);
  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(speculate_schedule_cache)
@@ -184,17 +199,19 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
                  "accept_tokens_out",
                  "is_block_step_out",
                  "not_need_stop_out"})
    .SetInplaceMap({
        {"draft_tokens", "draft_tokens_out"},
        {"block_tables", "block_tables_out"},
        {"stop_flags", "stop_flags_out"},
        {"seq_lens_this_time", "seq_lens_this_time_out"},
        {"seq_lens_encoder", "seq_lens_encoder_out"},
        {"seq_lens_decoder", "seq_lens_decoder_out"},
        {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
        {"step_draft_tokens", "step_draft_tokens_out"},
        {"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
        {"accept_num", "accept_num_out"},
        {"accept_tokens", "accept_tokens_out"},
        {"is_block_step", "is_block_step_out"},
        {"not_need_stop", "not_need_stop_out"},
    })
    .SetKernelFn(PD_KERNEL(SpeculateScheduleCache));

@@ -77,7 +77,7 @@ class MTPProposer(Proposer):
        self.enable_logprob = self.model_config.enable_logprob

        # [mixed, prefill, decoder]
        self.role = self.scheduler_config.splitwise_role

        self.sampler = MTPSampler(fd_config)
        self._init_model_inputs()

@@ -365,6 +365,7 @@ class MTPProposer(Proposer):
        )
        # self.model_inputs["caches"] = self.cache_kvs
        # Inherit generation hyperparameters from the main model for consistency
        self.model_inputs["prompt_lens"] = self.target_model_inputs["prompt_lens"]
        self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
        self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
        self.model_inputs["temperature"] = self.target_model_inputs["temperature"]

@@ -501,9 +502,10 @@ class MTPProposer(Proposer):
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["is_block_step"][idx : idx + 1] = False
                continue
        # if has_prefill_task or has_decode_task:
        #     self.model_inputs["not_need_stop"][0] = True

        # TODO(liuzichang): Solve splitwise-P bug to restore
        # self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
        """
@@ -704,11 +706,25 @@ class MTPProposer(Proposer):
            self.model_inputs["substep"],
        )
        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
            skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
            mtp_save_first_token(
                self.model_inputs["base_model_draft_tokens"],
                self.model_inputs["not_need_stop"],
                self.model_inputs["seq_lens_decoder"],
                self.model_inputs["prompt_lens"],
                self.model_inputs["step_idx"],
                self.local_rank,
                self.parallel_config.use_ep,
                skip_save,
            )
            # Ensure the first token is saved only once: zero step_idx for
            # stopped queries so they fail the step_idx == 1 gate next time.
            paddle.assign(
                paddle.where(
                    self.model_inputs["stop_flags"],
                    paddle.zeros_like(self.model_inputs["step_idx"]),
                    self.model_inputs["step_idx"],
                ),
                self.model_inputs["step_idx"],
            )

    def _propose(self, step_use_cudagraph: bool = False):